一文讲清 nn.functional.One-Hot 编码

66 阅读4分钟

本文向初学者讲清楚:什么是 One-Hot 编码?它在 torch 中如何实现?原理、应用场景和使用方法是什么?

依然用生活化例子 👇


📚 一、用“班级点名册”打比方 —— 形象理解 One-Hot

假设你是班主任,班上有 5 个学生:

学号姓名
0张三
1李四
2王五
3赵六
4小美

现在校长问你:“今天谁迟到了?” 你说:“学号 2 的王五迟到了。”

但 AI 系统看不懂“学号2”,它需要你用“开关信号”表示:

✅ 你画了一个“点名表”,每人一列,迟到的打 1,其他人打 0:

[0, 0, 1, 0, 0]
 ↑ ↑ ↑ ↑ ↑
张 李 王 赵 小
三 四 五 六 美

👉 这就是 One-Hot 编码“只有一位是1,其他全是0”,用来表示“哪个类别被选中”。


⚙️ 二、One-Hot 编码的数学原理

  • 输入:一个整数 i(类别索引,从0开始)
  • 输出:一个长度为 num_classes 的向量,只有第 i 位是 1,其余是 0

例子:

  • 类别总数 = 5
  • 输入 = 2 → 输出 = [0, 0, 1, 0, 0]
  • 输入 = 0 → 输出 = [1, 0, 0, 0, 0]

🧠 它把“离散类别”变成“稀疏向量”,便于神经网络处理(比如做分类任务的标签)。


🌍 三、为什么用 One-Hot?—— 应用场景

  1. 分类任务的标签编码(最常见!)

    • 你有3类:猫=0,狗=1,鸟=2
    • 标签“狗” → [0, 1, 0]
    • 配合 CrossEntropyLossNLLLoss 使用(注意:有些损失函数不需要one-hot!)
  2. 代替原始整数输入(当不想用 Embedding 时)

    • 比如:星期几(0~6)→ 变成7维one-hot,避免模型误以为“6比0大”
  3. 特征工程中表示离散变量

    • 性别:男=0,女=1 → [1,0][0,1]
    • 颜色:红=0,绿=1,蓝=2 → [1,0,0], [0,1,0], [0,0,1]

📌 举个实际例子:

你想训练一个模型识别手写数字(0~9),输出层有10个神经元。
标签“5”不能直接输入5,而要变成 [0,0,0,0,0,1,0,0,0,0],这样第5个神经元输出高就表示预测正确!


💻 四、PyTorch 中怎么用?—— 代码示例(班级点名版)

import torch
import torch.nn.functional as F

# 假设:班上有5个学生,学号 0~4
# 今天迟到的学生学号是:2(王五)和 0(张三)
student_ids = torch.tensor([2, 0])  # shape: [2]

print("迟到的学生学号:", student_ids)

# 转成 One-Hot 编码,总类别数 = 5
one_hot = F.one_hot(student_ids, num_classes=5)

print("\nOne-Hot 编码结果:")
print(one_hot)
print("形状:", one_hot.shape)  # [2, 5]

📌 输出:

迟到的学生学号: tensor([2, 0])

One-Hot 编码结果:
tensor([[0, 0, 1, 0, 0],
        [1, 0, 0, 0, 0]])
形状: torch.Size([2, 5])

🔍 说明:

  • 第一行:学号2 → 第2位是1
  • 第二行:学号0 → 第0位是1
  • 每行是一个学生的 one-hot 表示

🧠 五、One-Hot vs Embedding 对比(初学者必看!)

特性One-HotEmbedding
输入类别ID(整数)类别ID(整数)
输出稀疏向量(很长,只有一个1)稠密向量(较短,都是小数)
维度= 类别总数(如1000类→1000维)自定义(如1000类→64维)
是否可训练❌ 固定编码✅ 向量会学习优化
是否表达语义❌ 只是“谁被选中”,无语义关系✅ “猫”和“狗”向量会靠近
适用场景标签、小类别特征大类别、需语义、推荐/NLP

👉 简单说:

  • 类别少(<10),用 One-Hot 没问题
  • 类别多(用户ID、单词ID),用 Embedding 更高效、更智能

✅ 六、初学者常见误区 & 注意事项

  1. CrossEntropyLoss 不需要 One-Hot!
    很多人搞错!PyTorch 的 nn.CrossEntropyLoss 期望的 target 是 原始类别整数,不是 one-hot!

    # ✅ 正确用法
    loss_fn = nn.CrossEntropyLoss()
    output = model(x)        # shape: [batch, num_classes]
    target = torch.tensor([2, 0, 1])  # 直接传类别ID!
    loss = loss_fn(output, target)
    
    # ❌ 错误用法(除非你用BCELoss)
    target_onehot = F.one_hot(target, num_classes=3).float()
    loss = loss_fn(output, target_onehot)  # 会报错!
    

    如果你真想用 one-hot + 交叉熵,要用 BCEWithLogitsLoss 或手动实现。

  2. One-Hot 是“表示方法”,不是“层”
    它没有可训练参数,只是数据预处理或标签转换。

  3. 维度爆炸问题
    10000个商品 → one-hot 就是10000维 → 占内存、难训练 → 这时请用 Embedding!


🎉 七、总结口诀(方便记忆):

“一位热情似火,其余冷若冰霜;分类标签常用,类别多了别上!”


📚 举一反三:

  • 手写数字(0~9)→ 10维 one-hot
  • 星期几(0~6)→ 7维 one-hot
  • 性别(0=男,1=女)→ 2维 one-hot

✅ 最后提醒:

One-Hot 是基础但重要的编码方式,尤其在做多分类任务标签时非常常用。但记住:不是所有损失函数都吃它!CrossEntropyLoss 吃的是原始整数标签!


希望这个“班级点名”的例子,让你彻底理解了 One-Hot 编码!它就像“举手点名”——谁被点到谁举手(=1),其他人放下手(=0)✋