Gumbel Softmax trick快速理解(附pytorch实现)(一)

1,505 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


完整版也可查看我的知乎: Gumbel softmax trick (快速理解附代码)

(一)目的

在深度学习中,对某一个离散随机变量XX进行采样,并且又要保证采样过程是可导的(因为要用梯度下降进行优化,并且用BP进行权重更新),那么就可以用Gumbel softmax trick。属于重参数技巧(re-parameterization)的一种。

首先我们要介绍,什么是Gumbel distribution,然后再介绍怎么用到梯度下降中,最后用pytorch实现它。

(二)什么是Gumbel distribution?

一种极值分布(或者叫做Fisher-Tippett extreme value distributions),顾名思义就是用来研究极值(极大值,或者极小值)的一种概率分布形式。和别的一些分布形式一样,给定一个描述分布的公式,然后再给定公式中的某些参数,那么就确定了这个分布。

(本质就是想用数学语言或公式来逼近或解释现实世界观察到的现象,比如自然界很多现象可以用正太分布来描述,自然地,也存在一些自然现象,要用极值分布来描述。)

举个简单的例子: 某高中三年级一共有16个班,现在从每个班级里面抽出30人(假设全为男生或者全为女生),那么现在总共有16组30人的样本。如果看每组样本里面的身高分布大概率是服从正态分布的。现在,从每组样本里面挑出身高最高的人,将这些人再组成一个新样本集合,也就是现在这个样本集合有16个人,那么这16个人的样本集合就是服从的Gumbel 分布(极大值 Gumbel distribution,当然也有极小值 Gumbel distribution)。

下面定义极大值的Gumbel distribution。

CDF:

F(x;μ,β)=ee(xμ)βF(x;\mu,\beta)=e^{-e^{- \frac {(x-\mu)}{\beta}}}

PDF:

f(x)=Fx=1βe(z+ez),f(x)=\frac{\partial F}{\partial x}=\frac{1}{\beta}e^{-(z+e^{-z})},

where z=xμβz=\frac{x-\mu}{\beta}.

标准Gumbel 分布:

即,μ=0,β=1\mu=0, \beta=1, 则CDF为:

F(x;μ=0,β=1)=F(x)=ee(x)F(x;\mu=0,\beta=1)=F(x)=e^{-e^{-(x)}}

函数图像:

在这里插入图片描述

接下一篇(二)