Gumbel-Softmax揭秘:探索随机噪声在离散优化中的应用 (1)

245 阅读2分钟

#金石计划征文活动

实际上,Gumbel-Softmax 算法并不直接选择概率最高的值作为最终输出,而是采用了一种被称为“Gumbel-max 技巧”的方法来近似达到这一目的。该算法主要用于生成离散分布的梯度,从而支持神经网络中的反向传播过程。

借助 Gumbel-Softmax,我们能够获得一种连续的近似方法,这使得即使在需要做出离散选择的情况下,也能够维持反向传播的有效性。这意味着,通过运用 Gumbel-Softmax 技术,我们可以从离散分布中抽样,同时确保梯度可以顺利传递,进而优化模型的学习过程。

Softmax 函数

首先,Softmax 函数将一组数值转换为概率分布。对于给定的数据 (data = [1, 2, 3, 4]),Softmax 函数的计算公式为:

pi=exij=1nexjp_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}

其中 ( xi ) 是输入数据中的第 i 个元素。

对于data = [1, 2, 3, 4],计算得到的Softmax概率 ( p ) 为: p = [0.03, 0.08, 0.24, 0.65]

(注意:这里的概率值是近似值,实际计算时可能会有轻微的差异)

Gumbel噪声

Gumbel噪声是一种从Gumbel分布中采样得到的噪声,它在机器学习和特别是深度学习中有着重要的应用,尤其是在处理离散随机变量和近似离散分布的梯度时。Gumbel分布是一种极值分布,常用于模拟一组独立同分布的随机变量中的最小值或最大值。

Gumbel分布

Gumbel分布是由德国数学家恩斯特·爱德华·库贝尔(Ernst Eduard Gumbel)提出的,它是一种连续概率分布,常用于描述极端事件。Gumbel分布的概率密度函数(PDF)和累积分布函数(CDF)如下:

概率密度函数 (PDF):

f(x;μ,β)=1βexp((xμβ+exμβ))f(x; \mu, \beta) = \frac{1}{\beta} \exp \left( -\left( \frac{x - \mu}{\beta} + e^{-\frac{x - \mu}{\beta}} \right) \right)

其中,μ\mu 是位置参数,β\beta 是尺度参数。

累积分布函数 (CDF):

F(x;μ,β)=exp(exμβ)F(x; \mu, \beta) = \exp \left( -e^{-\frac{x - \mu}{\beta}} \right)

Gumbel噪声的生成

在实际应用中,我们通常使用Gumbel分布的CDF的逆函数来生成Gumbel噪声。具体步骤如下:

  • 生成均匀分布的随机数:从均匀分布 ( U(0,1) ) 中采样一个随机数 ( u )。
  • 应用Gumbel分布的逆CDF:使用公式g=log(log(u)) g = -\log(-\log(u)) 来生成Gumbel噪声。这个公式实际上是Gumbel分布的CDF的逆函数。

Gumbel-Softmax 函数

Gumbel-Softmax 函数通过引入 Gumbel 噪声来近似离散分布。其计算公式为:

yi=exp((log(pi)+gi)/τ)k=1Kexp((log(pk)+gk)/τ)y_i = \frac{\exp((\log(p_i) + g_i) / \tau)}{\sum_{k=1}^{K} \exp((\log(p_k) + g_k) / \tau)}

其中 g i是从Gumbel分布中采样的噪声, tau是温度参数。

Gumbel-Softmax 的实现步骤
Softmax 计算:首先计算Softmax 概率。

Gumbel 噪声:为每个概率值添加 Gumbel 噪声。

Gumbel-Softmax 计算:计算 Gumbel-Softmax 值。

选择最大值:在Gumbel-Softmax值中选择最大的值。

示例代码
下面是一个 Python 示例,使用 PyTorch 实现 Gumbel-Softmax:

def test_gumbel_softmax():
    """Gumbel-Softmax 分布"""
    import torch
    import numpy as np

    def gumbel_softmax(logits, temperature=1.0):
        # 使用 torch.rand_like(logits) 生成一个与输入 logits 形状相同的张量,其中的元素是从均匀分布U(0,1)中采样得到的
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-10))
        y = logits + gumbel_noise
        y = torch.softmax(y / temperature, dim=-1)
        return y

    data = torch.tensor([1.0, 2.0, 3.0, 4.0])
    # 利用公式将数据转化成概率 p(i) = exp(z_i) / sum_j(exp(z_j))
    probs = torch.softmax(data, dim=0)
    gumbel_probs = gumbel_softmax(data, temperature=1.0)
    print("Softmax probabilities:", probs)
    print("Gumbel-Softmax probabilities:", gumbel_probs)

    # 选择 Gumbel-Softmax 中概率最大的值
    max_data, max_index = torch.max(gumbel_probs, dim=0)
    print(f"Selected index:{max_index.item()},Selected data:{ max_data.item()}")


if __name__ == "__main__":
    test_gumbel_softmax()

结果解释

  • Softmax 概率:计算得到的概率分布。
  • Gumbel-Softmax 概率:在 Softmax 概率基础上添加 Gumbel 噪声并重新计算。
  • 选择最大值:从 Gumbel-Softmax 概率中选择概率最大的值的索引。

这个过程并不是简单地选择概率最大的值,而是通过引入噪声来近似实现这一点,使得在梯度下降过程中可以处理离散变量。