本文已参与「新人创作礼」活动,一起开启掘金创作之路。
完整版也可查看我的知乎: Gumbel softmax trick (快速理解附代码)
(一)目的
在深度学习中,对某一个离散随机变量进行采样,并且又要保证采样过程是可导的(因为要用梯度下降进行优化,并且用BP进行权重更新),那么就可以用Gumbel softmax trick。属于重参数技巧(re-parameterization)的一种。
首先我们要介绍,什么是Gumbel distribution,然后再介绍怎么用到梯度下降中,最后用pytorch实现它。
(二)什么是Gumbel distribution?
一种极值分布(或者叫做Fisher-Tippett extreme value distributions),顾名思义就是用来研究极值(极大值,或者极小值)的一种概率分布形式。和别的一些分布形式一样,给定一个描述分布的公式,然后再给定公式中的某些参数,那么就确定了这个分布。
(本质就是想用数学语言或公式来逼近或解释现实世界观察到的现象,比如自然界很多现象可以用正太分布来描述,自然地,也存在一些自然现象,要用极值分布来描述。)
举个简单的例子: 某高中三年级一共有16个班,现在从每个班级里面抽出30人(假设全为男生或者全为女生),那么现在总共有16组30人的样本。如果看每组样本里面的身高分布大概率是服从正态分布的。现在,从每组样本里面挑出身高最高的人,将这些人再组成一个新样本集合,也就是现在这个样本集合有16个人,那么这16个人的样本集合就是服从的Gumbel 分布(极大值 Gumbel distribution,当然也有极小值 Gumbel distribution)。
下面定义极大值的Gumbel distribution。
CDF:
PDF:
where .
标准Gumbel 分布:
即,, 则CDF为: