Gumbel Softmax trick快速理解(附pytorch实现)(二)

1,262 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


接上一篇(一)

(三)什么是Gumbel softmax trick?

Gumbel分布描述了自然界或者说人造的某种数据(其实也是自然界吧,毕竟人也是自然的一部分。)的极值分布的 “规律”(分布其实只是认识”规律“的一种方式)。所以自然地,我们之所以会用到Gumbel分布,就是因为我们要处理的数据中,存在极值分布(~废话)。

考虑如下场景:

对一个离散随机变量X\mathbf{X}进行采样,随机变量的取值范围为{1,2,...,K}\{1,2,...,K\}。首先要知道随机变量的分布函数,这里假设用MLP学习一个K维的向量:hRK\mathbf{h} \in \mathbb{R}^K

(假如是直接做inference的话,不考虑概率意义,那么我们直接取这个向量元素最大值的下标当做预测的离散变量值就可以了,即,Xi=argmaxihiX_i = \arg\max_i h_i.,但我们希望的是预测的离散变量具有概率意义,或者说得到的多个预测值的经验分布符合理论的概率分布。否则的话就是deterministic的,会导致某些小概率的变量值根本取不到,进而影响后续的任务。)

所以,我们需要赋予概率意义。通常,我们可以用softmax函数作用到h\mathbf{h}求得一个符合概率意义的新概率向量,即:

pi=softmax(h,hi)=exp(hi)iexp(hi).p_i=softmax(h,h_i)=\frac{exp(h_i)}{\sum_i exp(h_i)}.

这样我们就获得了各个离散取值的概率分布p[0,1]K\mathbf{p} \in [0,1]^K,其中pi=Pr{Xi=i}p_i=Pr\{X_i=i\}。这里p\mathbf{p}是一个在K维simplex中的一个向量。

到这里,我们得到了XX的概率分布,如果要直接得到离散变量,直接取Xi=argmaxipiX_i = \arg\max_i p_i即可。(注意,这里每次inference的时候,取了最大值,是不是和Gumbel分布的含义很像了。)

问题是我们需要的是采样,也就是生成的多个样本的频率分布要符合其理论的概率分布。另外,可以开始考虑,是否能够将求导采样这两个操作解耦。

如果知道一些reparameterization trick的技巧,很容易想到,我们只需要将p\mathbf{p}加上一个要学习的参数无关(即无需进行求导)的某个随机变量g\mathbf{g},那么采样过程就可以通过g\mathbf{g}进行(曲线救国了算是),这样做相当于把求导采样解耦了。这里,只需要保证结合后的分布,与原分布p\mathbf{p}相等或近似即可。

接下来就是与服从Gumbel分布的随机变量g\mathbf{g}结合:

Xi=argmaxi(log(pi)+gi)X_i = \arg \max_i (\log(p_i) + g_i)

这里gig_i是一个提前采样好的标准Gumbel分布序列。通过这种方法,理论上可以证明,这个新随机变量的分布函数和原分布函数相等。证明见:。。

但这样的问题在于argmax()\arg\max ()不可导,导致无法使用梯度下降来更新参数。所以一种办法是将随机变量的取值从1,...,K{1,...,K}变为用一个K维的one_hot向量编码来表示。比如,本来取Xi=iX_i=i,现如果用one_hot来表示的话,就是Xi=(0,...,1,...,0)X_i = (0,...,1,...,0),也就是第ii个下标的值为1,其它都为0,我们记第ii个下标的值为yiy_i。那么我们就可以用softmax函数来近似这个one_hot向量:

yi=exp((log(pi)+gi)/τ)k=1Kexp((log(pk)+gk)/τ)y_i = \frac{\exp((\log(p_i) + g_i)/\tau)}{\sum_{k=1}^K\exp((\log(p_k) + g_k)/\tau)}

这里的τ\tau被叫做温度系数,或者说是一个缩放因子。一般来说,τ<1\tau < 1,想象一下以ee为底的指数分布图像,可以发现,如果τ\tau越小,指数的值e(x/τ)越大,简记x=log(pi)+gie^{(x/\tau)}越大,简记x=log(p_i)+g_i。也就是说,这个τ\tau存在的意义就是让本来大的xx越大,所以会导致yiy_i越接近1,并且ji,yj\forall j \neq i, y_j会接近0,所以XiX_i就更接近一个one_hot表示。 图片来源于文章[7]:图片来源于文章[7]