Softmax 与 交叉熵损失

129 阅读2分钟

本文档详细解析了多分类任务中 Softmax交叉熵损失 (Cross-Entropy Loss) 的数学原理、配合机制以及工程实现中的数值稳定性问题。


1. 核心理论:从 Logits 到 Loss

假设神经网络最后一层的原始输出(未归一化)为向量 o\mathbf{o},其中第 ii 个节点的输出记为 oio_i(工程上常称为 Logits)。

1.1 Softmax 函数:归一化

Softmax 的核心作用是将任意实数域的输出 oio_i 映射为概率分布 yiy_i

公式: yi=Softmax(oi)=eoij=1Ceojy_i = \text{Softmax}(o_i) = \frac{e^{o_i}}{\sum_{j=1}^{C} e^{o_j}}

  • 非负性eoi>0e^{o_i} > 0,保证概率为正。
  • 归一化yi=1\sum y_i = 1,符合概率定义。
  • 差异放大:指数函数会拉大数值之间的差距,使“强者恒强”。

1.2 交叉熵损失 (Cross-Entropy Loss)

交叉熵用于衡量“预测概率分布”与“真实标签分布”之间的差异。

公式: L=i=1Ctilog(yi)L = - \sum_{i=1}^{C} t_i \cdot \log(y_i)

其中 tt 是真实标签的 One-hot 编码向量(正确类别为 1,其余为 0)。因此,对于单个样本,公式简化为:

L=log(ycorrect)L = - \log(y_{\text{correct}})

即:我们只关心模型对“正确类别”预测了多大的概率。

1.3 为什么它们是“黄金搭档”?

当我们将 Softmax 和 Cross-Entropy 结合在一起对 oio_i 求导时,会得到非常优雅的梯度形式:

Loi=yiti\frac{\partial L}{\partial o_i} = y_i - t_i

物理含义: 梯度等于 (预测概率) - (真实标签)

  • 这种线性的梯度特性避免了均方误差(MSE)在分类任务中可能遇到的梯度消失问题。
  • 误差越大,梯度越大,模型参数更新越快。

2. 工程挑战:数值稳定性 (Numerical Stability)

在实际工程落地时,直接按照数学公式计算 Softmax 会遇到严重问题。

2.1 上溢问题

指数函数 exe^x 增长极快。在标准的 float32 浮点数系统中:

  • e1002.6×1043e^{100} \approx 2.6 \times 10^{43}
  • oi>88o_i > 88,则 eoiinfe^{o_i} \to \text{inf} (无穷大)。

一旦分子或分母出现 inf,计算结果就会变成 NaN (Not a Number),导致训练崩溃。

2.2 解决方案:减去最大值 (The Max Trick)

利用 Softmax 的 平移不变性,我们在分子分母的指数中同时减去输入向量的最大值 M=max(o)M = \max(\mathbf{o})

推导: Softmax(oi)=eoieoj=eoieM(eoj)eM=eoiMeojM\text{Softmax}(o_i) = \frac{e^{o_i}}{\sum e^{o_j}} = \frac{e^{o_i} \cdot e^{-M}}{(\sum e^{o_j}) \cdot e^{-M}} = \frac{e^{o_i - M}}{\sum e^{o_j - M}}

优势:

  1. 最大的指数项变为 eMM=e0=1e^{M-M} = e^0 = 1
  2. 其余所有项的指数部分均为负数或零,结果在 (0,1](0, 1] 之间。
  3. 彻底解决了上溢问题。

2.3 下溢问题

即便解决了上溢,如果在计算 Loss 时先算 Softmax 再算 Log,还可能遇到 下溢

如果某个类别的 oio_i 非常小(负绝对值很大),经过 Softmax 后 yiy_i 可能会极其接近 0。 在浮点数精度受限的情况下,计算机可能直接将 yiy_i 截断为 0。 随后计算 Loss 时: L=log(0)infL = -\log(0) \to \text{inf} 这会导致训练梯度爆炸或 Loss 变为无穷大。

2.4 解决方案:Log-Sum-Exp

核心思想: 不要分步计算 log(Softmax)\log(\text{Softmax}),而是将其合并推导,转化为一个原子操作。

数学推导:

log(Softmax(oi))=log(eoijeoj)=log(eoi)log(jeoj)=oilog(jeoj)\begin{aligned} \log(\text{Softmax}(o_i)) &= \log\left( \frac{e^{o_i}}{\sum_{j} e^{o_j}} \right) \\ &= \log(e^{o_i}) - \log\left( \sum_{j} e^{o_j} \right) \\ &= o_i - \log\left( \sum_{j} e^{o_j} \right) \end{aligned}

这里的 Log-Sum-Exp 部分: log(eoj)\log(\sum e^{o_j}) 同样再次使用 Max Trick 来保证这一步的稳定性: log(jeoj)=log(jeojMeM)=M+log(jeojM)\log\left( \sum_{j} e^{o_j} \right) = \log\left( \sum_{j} e^{o_j - M} \cdot e^M \right) = M + \log\left( \sum_{j} e^{o_j - M} \right)

2.5 最终结论

通过合并计算,我们完全避免了直接计算概率 yiy_i 这一步,而是直接通过 Logits 计算 Log-Probability。 公式变为: LogSoftmax(oi)=oiMlog(eojM)\text{LogSoftmax}(o_i) = o_i - M - \log\left( \sum e^{o_j - M} \right) 在这个公式中,所有中间数值都被限制在安全范围内,既不会上溢也不会下溢。


3. 代码演示

import torch
from torch import nn

# 模拟数据
# 2 个样本,4 个类别
outputs = torch.tensor([ # 模型未归一化的输出(logits)
    [-1.2, -0.2, 0.8, 1.3], # 未溢出数据
    [-102.6, 85.7, 87.4, 90.2]])  # 溢出数据
targets = torch.tensor([1, 0])  # 真实标签

## 上溢版本 softmax
def unsafe_softmax(X):
    exp_X = torch.exp(X)
    partition = exp_X.sum(1, keepdim=True)
    return exp_X / partition

unsafe_so = unsafe_softmax(outputs)
unsafe_so

输出: 由于exp溢出,第二行数据计算错误
tensor([[0.0429, 0.1167, 0.3173, 0.5231], [0.0000, 0.0000, 0.0000, nan]])

## 安全版本 sotfmax
def safe_softmax(X):
    M = X.max(1, keepdim=True).values
    XM = X - M
    exp_X = torch.exp(XM)
    partition = exp_X.sum(1, keepdim=True)
    return exp_X / partition

safe_so = safe_softmax(outputs)
safe_so

输出: softmax有平移不变形,计算结果正确
tensor([[0.0429, 0.1167, 0.3173, 0.5231], [0.0000, 0.0104, 0.0567, 0.9329]])

## 下溢版本 Cross Entropy Loss
## -102.6 下溢,并且为标签数据。注意,非标签数据存在下溢不影响结果
## 因为softmax > 0, 所以非标签数据的 P(o) * log(softmax(o)) 恒等于0
def unsafe_cnl(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

unsafe_cy = unsafe_cnl(safe_so, targets)
unsafe_cy

输出: -102.6计算了log(0),导致结果错误
tensor([2.1480, inf])

## 安全版本 Cross Entropy Loss
## Loss = -LogSoftmax = M - Oi + log(sum(exp(Oj - M)))
def safe_cnl(X, y):
    M = X.max(1, keepdim=True).values
    true_x = X[range(X.size(0)), y]
    shift_x = torch.exp(X - M)
    log_sum = torch.log(shift_x.sum(1))
    return M.reshape(y.shape) - true_x + log_sum

safe_cy = safe_cnl(outputs, targets)
safe_cy

输出:
tensor([ 2.1480, 192.8694])

## pytorch 自带溢出处理
## 直接使用 nn.CrossEntropyLoss
loss_fn_none = nn.CrossEntropyLoss(reduction='none')
loss_none = loss_fn_none(outputs, targets)
loss_none

输出:
tensor([ 2.1480, 192.8694])