Gumbel Softmax trick快速理解(附pytorch实现)(三)

1,280 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


接上一篇(二)

我们可以简单回顾上两篇当中对基本的Gumbel分布,已经重参数化的技巧进行简单的回顾。Gumbel分布,即,极大极小分布,简单说就是,每个班级男生或者女生的分布是正态分布,那么把每个班级的高个子拿出来,组成的样本的分布是Gumbel极大分布。重参数技巧,简单说就是,将求导过程和采样过程解偶。

(四)如何生成Gumbel分布的样本

最后一步,就是如何生成Gumbel 分布的样本,即,如何产生gig_i

这里使用最常见的一种方法也就是inverse CDF method。先求出Gumbel的CDF函数F(x;μ,β)F(x;\mu,\beta)的反函数x=F1(y;μ,β)=μβln(lny)x = F^{-1}(y;\mu,\beta)=\mu - \beta \ln(- \ln y)(根据CDF的公式:y=F(x;μ,β)y=F(x;\mu,\beta),把y和x反过来表示就可),然后只要生成yUniform(0,1)y \sim Uniform(0,1)的均匀分布的序列,那么相应的xx就服从Gumbel分布,xGumbel(μ,β)x \sim Gumbel(\mu, \beta),也即,xx的CDF函数为原来的F(x)F(x)证明如下:

P(F1(y)x)=P(yF(x))=0F(x)pdf(y)dy=0F(x)1dy=F(x)P(F^{-1}(y) \leq x)= P(y \leq F(x))=\int_0^{F(x)}pdf(y)dy=\int_0^{F(x)}1dy=F(x)

到这里我们就可以通过以上的公式进行采样了。

(五)pytorch实现

下面用pytorch实现一下上面描述的采样过程。

# Gumbel softmax trick:

import torch
import torch.nn.functional as F
import numpy as np

def inverse_gumbel_cdf(y, mu, beta):
    return mu - beta * np.log(-np.log(y))

def gumbel_softmax_sampling(h, mu=0, beta=1, tau=0.1):
    """
    h : (N x K) tensor. Assume we need to sample a NxK tensor, each row is an independent r.v.
    """
    shape_h = h.shape
    p = F.softmax(h, dim=1)
    y = torch.rand(shape_h) + 1e-25  # ensure all y is positive.
    g = inverse_gumbel_cdf(y, mu, beta)
    x = torch.log(p) + g  # samples follow Gumbel distribution.
    # using softmax to generate one_hot vector:
    x = x/tau
    x = F.softmax(x, dim=1)  # now, the x approximates a one_hot vector.
    return x

N = 10  # 假设 有N个独立的离散变量需要采样
K = 3   # 假设 每个离散变量有3个取值
h = torch.rand((N, K))  # 假设 h是由一个神经网络输出的tensor。

mu = 0
beta = 1
tau = 0.1

samples = gumbel_softmax_sampling(h, mu, beta, tau)

References

  1. mathworld.wolfram.com/GumbelDistr…
  2. www.itl.nist.gov/div898/hand…
  3. en.wikipedia.org/wiki/Fisher…
  4. en.wikipedia.org/wiki/Gumbel…
  5. www.cnblogs.com/initial-h/p…
  6. arxiv.org/pdf/1611.04…
  7. arxiv.org/abs/1611.01…