一文讲清 PyTorch 中损失函数(Loss Function)

186 阅读5分钟

我们用生活化比喻 + 场景分类 + 代码示例 + 一句话口诀,向初学者彻底讲清楚:


🎯 PyTorch 中的损失函数(Loss Function)—— 通俗易懂完全指南

💡 一句话总结:损失函数 = 衡量“模型预测”和“真实答案”差多远的尺子。差得越远,损失越大,模型就要越努力学习!


一、生活化比喻:教练打分 🎯

想象你是一个射箭运动员,每次射箭后:

  • 🎯 命中靶心 → 教练打 0 分(完美!)
  • 🎯 偏离 10cm → 教练打 50 分(有进步空间)
  • 🎯 射到隔壁靶 → 教练打 100 分(太差了!)

损失函数 = 教练的打分规则
它告诉模型:“你这次预测差多少?差得越多,分数越高 → 你要调整参数,让下次分数变低!”


二、损失函数的核心作用

  1. 衡量模型预测质量 → 数值化“错误程度”
  2. 指导反向传播 → 损失越大,梯度越大,参数调整越猛
  3. 优化目标 → 训练就是让损失函数值越来越小!

📌 没有损失函数,模型就不知道“往哪个方向改进”!


三、常用损失函数分类与使用场景

我们按任务类型分类讲解:


🧩 1. 回归任务(预测连续值)→ 用 MSELoss, L1Loss

🎯 场景:
  • 预测房价、温度、股票价格
  • 输出是一个或多个连续数值
nn.MSELoss(均方误差)—— 最常用!

“差值平方的平均” —— 对大误差惩罚更重!

import torch
import torch.nn as nn

loss_fn = nn.MSELoss()

y_pred = torch.tensor([2.5, 0.8, 1.2])
y_true = torch.tensor([3.0, 1.0, 1.0])

loss = loss_fn(y_pred, y_true)
print(loss)  # tensor(0.1033) → 越小越好!

# 手动计算:((2.5-3)² + (0.8-1)² + (1.2-1)²) / 3 = (0.25+0.04+0.04)/3 ≈ 0.11
nn.L1Loss(平均绝对误差)—— 对异常值更鲁棒

“差值绝对值的平均” —— 惩罚更均匀

loss_fn = nn.L1Loss()
loss = loss_fn(y_pred, y_true)
print(loss)  # tensor(0.3000) → (0.5 + 0.2 + 0.2) / 3 = 0.3

📌 选择建议

  • 默认用 MSELoss
  • 数据有异常值 → 用 L1Loss

🧩 2. 分类任务(预测类别)→ 用 CrossEntropyLoss, BCELoss

🎯 场景:
  • 图像分类(猫/狗/鸟)
  • 文本分类(正面/负面评论)
  • 输出是类别概率 or 类别标签

nn.CrossEntropyLoss —— 多分类首选!🔥

输入:模型原始输出(logits,未经过 softmax)
目标:真实类别索引(整数,不是 one-hot!)

loss_fn = nn.CrossEntropyLoss()

# 模拟:3个样本,4个类别
logits = torch.tensor([
    [2.0, 1.0, 0.1, 0.5],  # 样本1的预测分数
    [0.5, 2.0, 0.3, 0.1],  # 样本2
    [0.1, 0.2, 3.0, 0.4]   # 样本3
])

# 真实标签:样本1属于第0类,样本2属于第1类,样本3属于第2类
labels = torch.tensor([0, 1, 2])

loss = loss_fn(logits, labels)
print(loss)  # tensor(0.5929)

# ✅ 内部自动做:log_softmax + NLLLoss → 数值稳定!

📌 重要特点

  • 自动加 softmax → 你不用手动加!
  • 输入是 logits(原始分数),不是概率!
  • 目标用类别索引(0, 1, 2...),不是 one-hot!

nn.BCELoss(二元交叉熵)—— 二分类 or 多标签

用于:输出是概率值(0~1之间),目标也是概率(或0/1)

loss_fn = nn.BCELoss()

# 预测概率(必须经过 sigmoid)
pred_probs = torch.tensor([0.9, 0.3, 0.8])
true_labels = torch.tensor([1.0, 0.0, 1.0])  # 也可以是 [1, 0, 1]

loss = loss_fn(pred_probs, true_labels)
print(loss)  # tensor(0.2781)

⚠️ 注意:输入必须是概率 → 通常前面加 sigmoid


nn.BCEWithLogitsLoss —— BCELoss + sigmoid 合体版(推荐!)

输入:原始分数(logits)→ 自动加 sigmoid + BCELoss → 数值更稳定!

loss_fn = nn.BCEWithLogitsLoss()

logits = torch.tensor([2.0, -1.0, 1.5])  # 原始分数
labels = torch.tensor([1.0, 0.0, 1.0])

loss = loss_fn(logits, labels)
print(loss)  # tensor(0.2781) ← 和上面结果一致!

📌 选择建议

  • 多分类 → CrossEntropyLoss
  • 二分类/多标签 → BCEWithLogitsLoss(推荐)或 BCELoss(需手动 sigmoid)

🧩 3. 其他常用损失函数

nn.NLLLoss(负对数似然)—— 通常和 LogSoftmax 搭配
# 一般不用单独用,CrossEntropyLoss 已包含
log_probs = torch.log_softmax(logits, dim=1)
loss = nn.NLLLoss()(log_probs, labels)
nn.SmoothL1Loss —— 回归任务,对异常值更鲁棒(Faster R-CNN 用)

结合了 MSE 和 L1 的优点


四、在训练循环中的标准用法

import torch
import torch.nn as nn
import torch.optim as optim

# 定义模型、损失函数、优化器
model = MyModel()
criterion = nn.CrossEntropyLoss()  # 根据任务选
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):
    for x, y_true in dataloader:
        optimizer.zero_grad()           # 1. 清空梯度
        
        y_pred = model(x)               # 2. 前向传播
        loss = criterion(y_pred, y_true) # 3. 计算损失 ← 关键!
        
        loss.backward()                 # 4. 反向传播
        optimizer.step()                # 5. 更新参数

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

五、初学者常见误区 & 注意事项

❌ 误区1:分类任务手动加 softmax + 用 NLLLoss

# ❌ 不推荐(除非特殊需求)
output = model(x)
probs = F.softmax(output, dim=1)
loss = F.nll_loss(torch.log(probs), labels)

# ✅ 推荐:直接用 CrossEntropyLoss
loss = F.cross_entropy(output, labels)  # 内部自动处理,数值更稳定!

❌ 误区2:BCELoss 输入没加 sigmoid

logits = model(x)
# ❌ 错误!BCELoss 要求输入是概率
loss = nn.BCELoss()(logits, labels)  # 可能报错或结果异常

# ✅ 正确1:手动加 sigmoid
loss = nn.BCELoss()(torch.sigmoid(logits), labels)

# ✅ 正确2:用 BCEWithLogitsLoss(推荐)
loss = nn.BCEWithLogitsLoss()(logits, labels)

❌ 误区3:目标格式错误

# ❌ CrossEntropyLoss 不能用 one-hot!
labels_onehot = torch.tensor([[1,0,0], [0,1,0]])  # 错!

# ✅ 要用类别索引
labels = torch.tensor([0, 1])  # 对!

六、总结:损失函数选择速查表

任务类型推荐损失函数输入要求目标要求
回归(连续值)MSELoss(默认)预测值真实值
L1Loss(抗异常值)预测值真实值
多分类CrossEntropyLoss原始分数(logits)类别索引(int)
二分类/多标签BCEWithLogitsLoss原始分数(logits)0/1 或概率
BCELoss概率值(0~1)0/1 或概率

🎁 给初学者的黄金口诀:

“回归用 MSE,分类用 CrossEntropy,
二分类选 BCEWithLogits,
输入输出要看清,
梯度下降靠它行!”


🧠 动手小练习:

  1. MSELoss 计算预测 [1.2, 3.1] 和真实 [1.0, 3.0] 的损失
  2. CrossEntropyLoss 计算 logits [[2.0, 1.0]] 和标签 [0] 的损失
  3. BCEWithLogitsLoss 计算 logits [1.5] 和标签 [1.0] 的损失

(答案在最后 👇)


🎉 恭喜你!现在你已经掌握了损失函数的核心原理和实战用法!
它是连接“模型预测”和“参数更新”的桥梁 —— 没有它,深度学习就转不起来!


小练习答案

  1. MSELoss: ((1.2-1.0)**2 + (3.1-3.0)**2) / 2 = (0.04 + 0.01)/2 = 0.025
  2. CrossEntropyLoss: -log(softmax([2,1])[0]) = -log(e²/(e²+e¹)) ≈ 0.313
  3. BCEWithLogitsLoss: = -[1*log(σ(1.5)) + (1-1)*log(1-σ(1.5))] = -log(σ(1.5)) ≈ 0.2014
    (其中 σ(1.5) = 1/(1+e^{-1.5}) ≈ 0.8176)