本文已参与「新人创作礼」活动,一起开启掘金创作之路。
(三)什么是Gumbel softmax trick?
Gumbel分布描述了自然界或者说人造的某种数据(其实也是自然界吧,毕竟人也是自然的一部分。)的极值分布的 “规律”(分布其实只是认识”规律“的一种方式)。所以自然地,我们之所以会用到Gumbel分布,就是因为我们要处理的数据中,存在极值分布(~废话)。
考虑如下场景:
对一个离散随机变量进行采样,随机变量的取值范围为。首先要知道随机变量的分布函数,这里假设用MLP学习一个K维的向量:。
(假如是直接做inference的话,不考虑概率意义,那么我们直接取这个向量元素最大值的下标当做预测的离散变量值就可以了,即,.,但我们希望的是预测的离散变量具有概率意义,或者说得到的多个预测值的经验分布符合理论的概率分布。否则的话就是deterministic的,会导致某些小概率的变量值根本取不到,进而影响后续的任务。)
所以,我们需要赋予概率意义。通常,我们可以用softmax函数作用到求得一个符合概率意义的新概率向量,即:
这样我们就获得了各个离散取值的概率分布,其中。这里是一个在K维simplex中的一个向量。
到这里,我们得到了的概率分布,如果要直接得到离散变量,直接取即可。(注意,这里每次inference的时候,取了最大值,是不是和Gumbel分布的含义很像了。)
问题是我们需要的是采样,也就是生成的多个样本的频率分布要符合其理论的概率分布。另外,可以开始考虑,是否能够将求导和采样这两个操作解耦。
如果知道一些reparameterization trick的技巧,很容易想到,我们只需要将加上一个要学习的参数无关(即无需进行求导)的某个随机变量,那么采样过程就可以通过进行(曲线救国了算是),这样做相当于把求导与采样解耦了。这里,只需要保证结合后的分布,与原分布相等或近似即可。
接下来就是与服从Gumbel分布的随机变量结合:
这里是一个提前采样好的标准Gumbel分布序列。通过这种方法,理论上可以证明,这个新随机变量的分布函数和原分布函数相等。证明见:。。
但这样的问题在于不可导,导致无法使用梯度下降来更新参数。所以一种办法是将随机变量的取值从变为用一个K维的one_hot向量编码来表示。比如,本来取,现如果用one_hot来表示的话,就是,也就是第个下标的值为1,其它都为0,我们记第个下标的值为。那么我们就可以用softmax函数来近似这个one_hot向量:
这里的被叫做温度系数,或者说是一个缩放因子。一般来说,,想象一下以为底的指数分布图像,可以发现,如果越小,指数的值。也就是说,这个存在的意义就是让本来大的越大,所以会导致越接近1,并且会接近0,所以就更接近一个one_hot表示。
图片来源于文章[7]: