本文已参与「新人创作礼」活动,一起开启掘金创作之路。
我们可以简单回顾上两篇当中对基本的Gumbel分布,已经重参数化的技巧进行简单的回顾。Gumbel分布,即,极大极小分布,简单说就是,每个班级男生或者女生的分布是正态分布,那么把每个班级的高个子拿出来,组成的样本的分布是Gumbel极大分布。重参数技巧,简单说就是,将求导过程和采样过程解偶。
(四)如何生成Gumbel分布的样本
最后一步,就是如何生成Gumbel 分布的样本,即,如何产生。
这里使用最常见的一种方法也就是inverse CDF method。先求出Gumbel的CDF函数的反函数(根据CDF的公式:,把y和x反过来表示就可),然后只要生成的均匀分布的序列,那么相应的就服从Gumbel分布,,也即,的CDF函数为原来的。证明如下:
到这里我们就可以通过以上的公式进行采样了。
(五)pytorch实现
下面用pytorch实现一下上面描述的采样过程。
# Gumbel softmax trick:
import torch
import torch.nn.functional as F
import numpy as np
def inverse_gumbel_cdf(y, mu, beta):
return mu - beta * np.log(-np.log(y))
def gumbel_softmax_sampling(h, mu=0, beta=1, tau=0.1):
"""
h : (N x K) tensor. Assume we need to sample a NxK tensor, each row is an independent r.v.
"""
shape_h = h.shape
p = F.softmax(h, dim=1)
y = torch.rand(shape_h) + 1e-25 # ensure all y is positive.
g = inverse_gumbel_cdf(y, mu, beta)
x = torch.log(p) + g # samples follow Gumbel distribution.
# using softmax to generate one_hot vector:
x = x/tau
x = F.softmax(x, dim=1) # now, the x approximates a one_hot vector.
return x
N = 10 # 假设 有N个独立的离散变量需要采样
K = 3 # 假设 每个离散变量有3个取值
h = torch.rand((N, K)) # 假设 h是由一个神经网络输出的tensor。
mu = 0
beta = 1
tau = 0.1
samples = gumbel_softmax_sampling(h, mu, beta, tau)