一文讲清 nn.Embedding ​嵌入字典

106 阅读5分钟

我们继续用生活化的例子,向初学者讲清楚 PyTorch 中 torch.nn.Embedding 的原理、应用场景和使用方法。


🎯 一句话总结:

Embedding 是一个“查字典”的操作:把离散的整数编号(如单词ID、用户ID、商品ID),映射成连续的、有意义的向量(叫“嵌入向量”),让神经网络能理解它们。


📚 一、用“学生选课系统”打比方 —— 形象理解 Embedding

假设你是学校教务系统的AI助手,要预测“学生喜欢什么课”。

你有这些数据:

  • 学生编号:1~1000(张三=1,李四=2,王五=3...)
  • 课程编号:1~50(数学=1,语文=2,编程=3...)

但神经网络不能直接吃“编号123”,它需要“有意义的向量”。

✅ 于是你给每个学生和课程都配一个“性格/特征向量”:

  • 张三(ID=1) → 向量 [0.8, -0.2, 0.5] (代表:喜欢理科、讨厌背诵、动手能力强)
  • 数学(ID=1) → 向量 [0.9, -0.1, 0.7] (代表:理科、少背诵、重逻辑)

这样,你就可以用“向量相似度”来预测张三是否喜欢数学 → 两个向量点积大,说明匹配!

🧠 Embedding 层,就是这个“编号→向量”的自动查表+学习系统!


⚙️ 二、Embedding 的数学本质

  • 输入:一个整数 i(比如学生ID=3)
  • 输出:一个向量 embedding[i](比如 [0.1, -0.3, 0.6])
  • 这个向量表是可训练的参数矩阵,形状是 (num_embeddings, embedding_dim)

比如:

embedding = nn.Embedding(num_embeddings=1000, embedding_dim=64)
# 表示有1000个词/ID,每个映射成64维向量
# 内部就是一个形状为 [1000, 64] 的矩阵

当你输入 input_ids = [3, 8, 15],它就返回:

output = embedding(input_ids)  # shape: [3, 64]
# 第0行 = 第3个向量,第1行 = 第8个向量,第2行 = 第15个向量

🌍 三、为什么用 Embedding?—— 应用场景

  1. 自然语言处理(NLP):单词 → 词向量(Word Embedding)

    • “猫” → [0.2, 0.8, -0.1, ...]
    • “狗” → [0.3, 0.7, -0.2, ...] (和猫很接近)
  2. 推荐系统:用户ID、商品ID → 向量

    • 用户123喜欢科幻 → 向量靠近“科幻电影”
    • 商品456是科幻片 → 向量也靠近“科幻”
  3. 分类特征编码:代替 One-Hot,更高效、能表达语义

📌 举个实际例子:

你想训练一个模型,输入“用户ID + 电影ID”,输出“用户是否会点击”。
用 Embedding 把用户和电影都变成向量,拼接后输入全连接层,效果远好于直接用ID数字!


💻 四、PyTorch 中怎么用?—— 代码示例(学生选课版)

我们用“学生ID → 学生兴趣向量”来演示:

import torch
import torch.nn as nn

# 假设有5个学生(ID 0~4),每个学生映射成3维兴趣向量
student_embedding = nn.Embedding(num_embeddings=5, embedding_dim=3)

# 查看内部“字典”(初始是随机值)
print("Embedding 矩阵(5个学生,每人3维):")
print(student_embedding.weight.data)
print("形状:", student_embedding.weight.shape)  # torch.Size([5, 3])

# 输入:一批学生ID(比如 [0, 2, 4] 代表张三、李四、王五)
student_ids = torch.tensor([0, 2, 4])

# 查表:获取对应的学生兴趣向量
student_vectors = student_embedding(student_ids)

print("\n输入学生ID:", student_ids)
print("输出向量:")
print(student_vectors)
print("输出形状:", student_vectors.shape)  # torch.Size([3, 3])

📌 输出示例(随机初始化):

Embedding 矩阵(5个学生,每人3维):
tensor([[-0.5832, -0.6224,  0.9778],
        [-0.3370,  0.1267,  0.3845],
        [ 0.9107,  0.5623, -0.1234],
        [ 0.2211, -0.7890,  0.4567],
        [-0.3456,  0.8901, -0.2345]])
形状: torch.Size([5, 3])

输入学生ID: tensor([0, 2, 4])
输出向量:
tensor([[-0.5832, -0.6224,  0.9778],
        [ 0.9107,  0.5623, -0.1234],
        [-0.3456,  0.8901, -0.2345]], grad_fn=<EmbeddingBackward0>)
输出形状: torch.Size([3, 3])

🔍 说明:

  • ID=0 → 取第0行向量
  • ID=2 → 取第2行向量
  • ID=4 → 取第4行向量
  • 这些向量在训练过程中会被优化器更新!

🧠 五、训练时发生了什么?

假设你用这个 Embedding 做推荐:

# 伪代码:预测学生对课程的喜好
student_id = torch.tensor([1])
course_id = torch.tensor([3])

student_vec = student_embedding(student_id)   # [1, 3]
course_vec = course_embedding(course_id)      # [1, 3]

# 计算相似度(比如点积)
similarity = torch.sum(student_vec * course_vec, dim=1)  # [1]

loss = criterion(similarity, label)  # 比如 label=1 表示喜欢
loss.backward()  # 反向传播 → 更新 student_embedding 和 course_embedding 的向量!

👉 经过训练后:

  • 喜欢数学的学生 → 向量靠近“数学”向量
  • 喜欢语文的学生 → 向量靠近“语文”向量
  • 甚至能发现:“喜欢数学的学生”和“喜欢物理的学生”向量很接近!

这就是“语义空间”的自动学习!


✅ 六、初学者常见问题

  1. 输入必须是整数吗?
    → 是的!Embedding 输入必须是 LongTensor(整数索引),不能是浮点数。

  2. ID 能超出范围吗?
    → 不能!如果你定义 num_embeddings=5,那合法ID是 0~4,输入5会报错。

  3. Embedding 和 One-Hot 有什么区别?

    • One-Hot:1000个类 → 1000维稀疏向量(只有一个1)
    • Embedding:1000个类 → 64维稠密向量(全部是小数,有语义)
    • Embedding 更省空间、能表达相似性、效果更好!
  4. 在哪里用?

    • NLP:放在模型最前面,把单词ID转成词向量
    • 推荐系统:把用户ID、物品ID转成向量
    • 图神经网络:节点ID → 节点向量

🎉 七、总结口诀(方便记忆):

“编号查字典,向量有内涵,语义自动学,模型更会算!”


📚 举一反三:

  • nn.Embedding(10000, 128) → 1万个词,每个词128维(如BERT词向量)
  • nn.Embedding(100, 64) → 100个用户,每人64维兴趣向量
  • nn.Embedding(50, 16) → 50种商品,每种16维特征

✅ 最后提醒:

Embedding 层是“可训练参数”,记得在优化器中包含它!它不是固定查表,而是会随着任务不断优化语义表示的“活字典”!


希望这个“学生选课”的例子,让你彻底理解了 Embedding!它就像神经网络的“翻译官”,把人类世界的离散符号,翻译成AI能理解的数学语言 🌟