交叉熵损失到底在“惩罚”什么?

70 阅读1分钟

场景:教 AI 认动物

假设你在训练一个 AI,让它从图片中识别三种动物:猫、狗、鸟

你给它一张真实是“猫” 的图片,AI 输出了三个“信心分数”:

动物AI 的信心(概率)
0.7
0.2
0.1

看起来不错!但你怎么量化它“做得好不好”? 这时候,就需要损失函数来打分。


一、最朴素的想法:看“正确类别的分数”有多高

既然真实答案是“猫”,那我们就只关心 AI 对“猫”的信心 pcatp_{\text{cat}}

  • 如果 pcat=0.99p_{\text{cat}} = 0.99 → 几乎肯定,很好!
  • 如果 pcat=0.1p_{\text{cat}} = 0.1 → 几乎否定,很差!

所以,损失应该随着 pcatp_{\text{cat}} 增大而减小

但怎么设计这个“减小”关系? 线性?平方?还是……


二、为什么用 log(p)-\log(p) ?——来自信息论的直觉

想象你在考试:

  • 如果你蒙对了(比如瞎猜选 A,结果真是 A),老师会觉得你运气好,但不值得表扬
  • 如果你非常确定地答对了(说“A,我 99% 确信”),老师会认为你真懂。

反过来:

  • 如果你非常确定地答错了(说“绝对是 B”,结果是 A),老师会认为你错得离谱,要重点批评!

关键原则越自信还越错,惩罚越重;越不确定但对了,奖励有限。

而数学上,log(p)-\log(p) 完美符合这个直觉:

正确类别的概率 pp损失 log(p)-\log(p)解读
0.99≈ 0.01几乎没惩罚
0.5≈ 0.69中等惩罚
0.1≈ 2.30严重惩罚
0.01≈ 4.60极其严重!

📌 注意:当 p0p \to 0log(p)-\log(p) \to \infty —— 绝不允许模型“完全否认正确答案”

这就是负对数损失(Negative Log Loss) 的核心思想。


三、交叉熵 = 负对数损失(在分类任务中)

在多分类问题中,我们通常用 one-hot 编码表示真实标签:

  • 真实是“猫” → 标签向量为 [1,0,0][1, 0, 0]
  • 真实是“狗” → [0,1,0][0, 1, 0]

而模型输出是一个概率分布,比如 [0.7,0.2,0.1][0.7, 0.2, 0.1]

🔢 交叉熵是怎么算的?——手把手演示

公式Loss=i=1Kyilog(y^i)\text{Loss} = -\sum_{i=1}^K y_i \log(\hat{y}_i)

其中:

  • KK 是类别总数(这里是 3);
  • yiy_i 是真实标签(one-hot 向量);
  • y^i\hat{y}_i 是模型预测的概率。

👉 举个具体例子:

  • 真实标签(猫):   y=[1, 0, 0]y = [1,\ 0,\ 0]
  • 模型预测概率:  y^=[0.7, 0.2, 0.1]\hat{y} = [0.7,\ 0.2,\ 0.1]

代入公式:

第 1 步:写出求和项

Loss=(y1log(y^1)+y2log(y^2)+y3log(y^3))\text{Loss} = -\big( y_1 \cdot \log(\hat{y}1) + y_2 \cdot \log(\hat{y}2) + y_3 \cdot \log(\hat{y}_3) \big)

第 2 步:代入数值

=(1log(0.7)+0log(0.2)+0log(0.1))= -\big( 1 \cdot \log(0.7) + 0 \cdot \log(0.2) + 0 \cdot \log(0.1) \big)

第 3 步:简化(0 乘任何数为 0)

=(log(0.7)+0+0)=log(0.7)= -\big( \log(0.7) + 0 + 0 \big) = -\log(0.7)

第 4 步:计算结果(以自然对数为例)

log(0.7)0.357Loss(0.357)=0.357\log(0.7) \approx -0.357 \quad \Rightarrow \quad \text{Loss} \approx -(-0.357) = 0.357

你看:只有真实类别那一项起作用,其他都是 0!

再试一个更差的例子:

  • 预测:y^=[0.1, 0.8, 0.1]\hat{y} = [0.1,\ 0.8,\ 0.1](把猫错认成狗)
  • 真实:y=[1, 0, 0]y = [1,\ 0,\ 0]

计算: Loss=log(0.1)2.302\text{Loss} = -\log(0.1) \approx 2.302

→ 损失大得多!说明模型犯了高自信错误,被狠狠惩罚。

💡 记住:交叉熵的计算,本质上就是取模型对“正确答案”的预测概率,然后算它的负对数


四、PyTorch 里的一行代码,到底做了什么?

loss = F.cross_entropy(logits, labels)

这行代码背后发生了什么?

  1. logits 是网络最后一层的原始输出(比如 [2.1, 0.5, -1.0]),还不是概率;
  2. PyTorch 先用 Softmax 把 logits 转成概率: y^i=elogitijelogitj\hat{y}_i = \frac{e^{\text{logit}_i}}{\sum_j e^{\text{logit}_j}} → 得到像 [0.7, 0.2, 0.1] 这样的分布;
  3. 再取出真实类别对应的概率 y^true\hat{y}_{\text{true}}
  4. 最后计算 log(y^true)-\log(\hat{y}_{\text{true}}) 作为损失。

整个过程,目标只有一个:让模型对正确答案给出更高的概率


五、为什么不用“1 - 正确概率”?

你可能会想:既然只关心正确类别的概率 pp,为什么不直接用:

Loss=1p?\text{Loss} = 1 - p \quad ?

这看起来更简单!

但问题在于:它对“自信错误”的惩罚不够狠

pp(正确概率)1p1 - plog(p)-\log(p)
0.90.10.11
0.10.92.30
  • 1p1-p:错得很离谱(p=0.1p=0.1)只罚 0.9;
  • log(p)-\log(p):同样情况罚 2.3,超过两倍

而现实中,一个过度自信的错误模型比犹豫的模型更危险(比如医疗诊断、自动驾驶)。 所以,我们需要一个非线性放大惩罚的函数——log(p)\boldsymbol{-\log(p)} 正好满足。


六、总结:三句话记住交叉熵

  1. 交叉熵损失 = log(模型对正确答案的信心)-\log(\text{模型对正确答案的信心})
  2. 模型越自信还越错,损失爆炸式增长
  3. 训练目标:让正确答案的概率尽可能接近 1

你不需要懂“似然”“分布”“信息熵”,只要理解:

损失函数是在教模型:别瞎猜,要诚实,更要准确。