05-从隐藏向量到文字:LM Head如何输出"下一个词"?

14 阅读3分钟

回顾:大模型的完整流程

在前面的章节中,我们学习了Transformer的各个组件。现在让我们回顾一下完整流程:

输入:"今天天气"(Tokenization + Embedding)Token表示:XRn×dmodel(位置编码)加入位置:X+PE(多层Transformer)Layer 1:Attention + MLP + Residual + LNLayer 2:Attention + MLP + Residual + LNLayer N:Attention + MLP + Residual + LN最终隐藏状态:HRn×dmodel\begin{aligned} &\text{输入:} \quad \text{"今天天气"} \\ &\quad \downarrow \text{(Tokenization + Embedding)} \\ &\text{Token表示:} \quad X \in \mathbb{R}^{n \times d_{\text{model}}} \\ &\quad \downarrow \text{(位置编码)} \\ &\text{加入位置:} \quad X + \text{PE} \\ &\quad \downarrow \text{(多层Transformer)} \\ &\text{Layer 1:} \quad \text{Attention + MLP + Residual + LN} \\ &\text{Layer 2:} \quad \text{Attention + MLP + Residual + LN} \\ &\quad \vdots \\ &\text{Layer N:} \quad \text{Attention + MLP + Residual + LN} \\ &\quad \downarrow \\ &\text{最终隐藏状态:} \quad H \in \mathbb{R}^{n \times d_{\text{model}}} \end{aligned}

问题来了HH 是一个连续的向量(768维或4096维),但我们需要输出的是具体的文字(如"很好"、"不错")。

如何从连续向量变成离散的词?这就是LM Head的作用!

LM Head:语言模型的"输出层"

LM Head(Language Model Head)是Transformer的最后一层,它的作用非常明确:

将Transformer输出的隐藏向量映射到词表空间,为每个词计算概率,然后选择最可能的下一个词

LM Head的结构

LM Head通常就是一个简单的线性层(不带激活函数):

logits=HWlm+blm\text{logits} = H \cdot W_{\text{lm}} + b_{\text{lm}}

参数解释

  • HRn×dmodelH \in \mathbb{R}^{n \times d_{\text{model}}}:Transformer最后一层的输出(每个Token的隐藏表示)
  • WlmRdmodel×VW_{\text{lm}} \in \mathbb{R}^{d_{\text{model}} \times V}:LM Head的权重矩阵
  • blmRVb_{\text{lm}} \in \mathbb{R}^{V}:偏置向量(很多模型不使用偏置)
  • VV:词表大小(Vocabulary size),如50257(GPT-2)、32000(LLaMA)
  • logitsRn×V\text{logits} \in \mathbb{R}^{n \times V}:每个位置对所有词的"得分"(未归一化)

关键点

  • 输入:dmodeld_{\text{model}} 维的连续向量(如768维)
  • 输出:VV 维的分数向量(如32000维),每一维对应词表中的一个词

维度变化示例

假设 GPT-2 模型(dmodel=768d_{\text{model}}=768V=50257V=50257):

H:(n,768)(最后一层输出)Wlm:(768,50257)(LM Head权重)logits:(n,50257)(每个位置对所有词的分数)\begin{aligned} H &: (n, 768) \quad \text{(最后一层输出)} \\ W_{\text{lm}} &: (768, 50257) \quad \text{(LM Head权重)} \\ \text{logits} &: (n, 50257) \quad \text{(每个位置对所有词的分数)} \end{aligned}

对于最后一个位置(预测下一个词):

hlast:(768,)(最后一个Token的表示)logitslast=hlastWlm:(50257,)\begin{aligned} h_{\text{last}} &: (768,) \quad \text{(最后一个Token的表示)} \\ \text{logits}_{\text{last}} &= h_{\text{last}} \cdot W_{\text{lm}} : (50257,) \\ \end{aligned}

这个50257维的向量,每一维代表词表中对应词的"得分":

logitslast=[s1,s2,s3,,s50257]\text{logits}_{\text{last}} = [s_1, s_2, s_3, \ldots, s_{50257}]
  • s1s_1:词表中第1个词(如 <pad>)的得分
  • s2s_2:词表中第2个词(如 <unk>)的得分
  • s100s_{100}:词表中第100个词(如 "the")的得分
  • ...

得分越高,表示这个词越可能是下一个词。

从Logits到概率:Softmax归一化

Logits只是"得分",不是概率(可以是负数、可以很大)。我们需要将它们转换为概率分布

P(wi)=esij=1Vesj=softmax(logits)iP(w_i) = \frac{e^{s_i}}{\sum_{j=1}^{V} e^{s_j}} = \text{softmax}(\text{logits})_i

性质

  • 0P(wi)10 \leq P(w_i) \leq 1(每个概率在0-1之间)
  • i=1VP(wi)=1\sum_{i=1}^{V} P(w_i) = 1(所有概率加起来等于1)

具体例子

假设最后一个位置的logits(简化为5个词):

logitslast=[2.3,1.5,3.8,0.5,1.2]\text{logits}_{\text{last}} = [2.3, 1.5, 3.8, 0.5, 1.2]

对应词表中的5个词:["很", "好", "不错", "真", "差"]

应用Softmax

P("很")=e2.3e2.3+e1.5+e3.8+e0.5+e1.2=9.979.97+4.48+44.70+1.65+3.32=0.155P("好")=e1.564.12=0.070P("不错")=e3.864.12=0.697(最高!)P("真")=e0.564.12=0.026P("差")=e1.264.12=0.052\begin{aligned} P(\text{"很"}) &= \frac{e^{2.3}}{e^{2.3} + e^{1.5} + e^{3.8} + e^{0.5} + e^{1.2}} = \frac{9.97}{9.97 + 4.48 + 44.70 + 1.65 + 3.32} = 0.155 \\ P(\text{"好"}) &= \frac{e^{1.5}}{64.12} = 0.070 \\ P(\text{"不错"}) &= \frac{e^{3.8}}{64.12} = 0.697 \quad \text{(最高!)} \\ P(\text{"真"}) &= \frac{e^{0.5}}{64.12} = 0.026 \\ P(\text{"差"}) &= \frac{e^{1.2}}{64.12} = 0.052 \end{aligned}

概率分布:

Logit概率
2.315.5%
1.57.0%
不错3.869.7%
0.52.6%
1.25.2%

"不错"得分最高,概率最大,很可能被选为下一个词。

采样策略:如何选择下一个词?

有了概率分布后,如何选择下一个词?有多种策略:

1. Greedy Decoding(贪心解码)

最简单的方法:直接选概率最高的词

wnext=argmaxiP(wi)w_{\text{next}} = \arg\max_{i} P(w_i)

优点

  • 简单、确定性(每次输出相同)
  • 速度快

缺点

  • 输出单调、缺乏多样性
  • 容易陷入重复("我觉得我觉得我觉得...")
  • 可能错过全局最优解

代码

# logits: (vocab_size,)
probs = torch.softmax(logits, dim=-1)
next_token = torch.argmax(probs)  # 选择概率最大的

2. Random Sampling(随机采样)

按概率分布随机采样

wnextP(w)=softmax(logits)w_{\text{next}} \sim P(w) = \text{softmax}(\text{logits})

概率高的词更可能被选中,但不是绝对的。

优点

  • 输出多样化
  • 可以探索不同的生成路径

缺点

  • 有时会采样到低概率的"坏词"
  • 输出质量不稳定

代码

probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)  # 按概率采样

3. Temperature Sampling(温度采样)

在softmax之前,用温度参数 TT 缩放logits:

P(wi)=esi/Tjesj/TP(w_i) = \frac{e^{s_i / T}}{\sum_{j} e^{s_j / T}}

参数解释

  • T=1T = 1:标准softmax(不改变)
  • T<1T < 1(如0.5):"降温",概率分布更陡峭,偏向高概率词
  • T>1T > 1(如1.5):"升温",概率分布更平缓,增加多样性

直观理解

假设原始logits:[2.3,1.5,3.8,0.5,1.2][2.3, 1.5, 3.8, 0.5, 1.2]

低温 T=0.5T=0.5

scaled_logits=[4.6,3.0,7.6,1.0,2.4]\text{scaled\_logits} = [4.6, 3.0, 7.6, 1.0, 2.4]

Softmax后:

原始概率T=0.5T=0.5概率
15.5%3.8%
7.0%0.8%
不错69.7%94.2% ⬆️
2.6%0.1%
5.2%0.5%

高概率词("不错")的概率被进一步放大!

高温 T=1.5T=1.5

scaled_logits=[1.53,1.0,2.53,0.33,0.8]\text{scaled\_logits} = [1.53, 1.0, 2.53, 0.33, 0.8]

Softmax后:

原始概率T=1.5T=1.5概率
15.5%21.2% ⬆️
7.0%12.6% ⬆️
不错69.7%58.1% ⬇️
2.6%6.4% ⬆️
5.2%10.3% ⬆️

概率分布更均匀,其他词的机会增加!

使用场景

  • T<1T < 1:需要确定性、准确性的任务(如翻译、摘要)
  • T=1T = 1:平衡点
  • T>1T > 1:需要创意、多样性的任务(如故事生成、头脑风暴)

代码

temperature = 0.8
logits_scaled = logits / temperature
probs = torch.softmax(logits_scaled, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)

4. Top-K Sampling

只从概率最高的K个词中采样

  1. 对概率排序,保留前K个词
  2. 其余词的概率设为0
  3. 重新归一化
  4. 从这K个词中按概率采样

举例K=3K=3):

原始概率:

概率
不错69.7%
15.5%
7.0%
5.2%
2.6%

保留Top-3,重新归一化:

新概率
不错69.7% / (69.7+15.5+7.0) = 75.6%
15.5% / 92.2% = 16.8%
7.0% / 92.2% = 7.6%
0%
0%

优点

  • 过滤掉明显不合适的低概率词
  • 保持一定多样性

缺点

  • K是固定的,不够灵活
  • 有时Top-K之外还有合理的词

代码

top_k = 50
# 获取top-k的索引和值
top_k_probs, top_k_indices = torch.topk(probs, top_k)
# 重新归一化
top_k_probs = top_k_probs / top_k_probs.sum()
# 从top-k中采样
next_token_idx = torch.multinomial(top_k_probs, num_samples=1)
next_token = top_k_indices[next_token_idx]

5. Top-P Sampling(Nucleus Sampling)

动态选择最小的词集合,使得累积概率达到P

  1. 对概率从高到低排序
  2. 累加概率,直到达到阈值P(如0.9)
  3. 只保留这些词
  4. 重新归一化并采样

举例P=0.9P=0.9):

概率累积概率
不错69.7%69.7%
15.5%85.2%
7.0%92.2% ✅ 达到90%
5.2%97.4%
2.6%100%

保留前3个词(累积概率刚好超过90%):

新概率
不错69.7% / 92.2% = 75.6%
15.5% / 92.2% = 16.8%
7.0% / 92.2% = 7.6%

优点

  • 自适应:概率分布陡峭时,选择少数词;平缓时,选择更多词
  • 更灵活than Top-K
  • 实践效果好

缺点

  • 计算稍复杂

代码

top_p = 0.9
# 降序排序
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
# 计算累积概率
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
# 找到累积概率>top_p的位置
mask = cumsum_probs > top_p
# 保留第一个超过top_p的词(确保至少有一个)
mask[1:] = mask[:-1].clone()
mask[0] = False
# 过滤
sorted_probs[mask] = 0
# 重新归一化
sorted_probs = sorted_probs / sorted_probs.sum()
# 采样
next_token_idx = torch.multinomial(sorted_probs, num_samples=1)
next_token = sorted_indices[next_token_idx]

采样策略对比

策略多样性质量稳定性计算复杂度适用场景
Greedy翻译、问答
Random不推荐
Temperature可调通用
Top-K通用
Top-P自适应推荐⭐

实践中的组合

通常会结合多种策略

# Temperature + Top-P(最常用)
temperature = 0.8
top_p = 0.9

logits_scaled = logits / temperature
probs = torch.softmax(logits_scaled, dim=-1)

# 应用Top-P过滤
sorted_probs, sorted_indices = torch.sort(probs, descending=True)
cumsum_probs = torch.cumsum(sorted_probs, dim=-1)
mask = cumsum_probs > top_p
mask[1:] = mask[:-1].clone()
mask[0] = False
sorted_probs[mask] = 0
sorted_probs = sorted_probs / sorted_probs.sum()

# 采样
next_token = sorted_indices[torch.multinomial(sorted_probs, 1)]

Embedding权重共享(Weight Tying)

LM Head的权重矩阵 WlmRdmodel×VW_{\text{lm}} \in \mathbb{R}^{d_{\text{model}} \times V} 非常大!

对于GPT-3(dmodel=12288d_{\text{model}}=12288V=50257V=50257):

参数量=12288×50257=617,155,776617M\text{参数量} = 12288 \times 50257 = 617{,}155{,}776 \approx 617M

这占了模型总参数的很大一部分!

Token Embedding:从词到向量的学习过程

在讨论权重共享之前,我们先理解Token Embedding层是如何训练的

什么是Token Embedding?

Token Embedding层是模型的第一层,它的作用是将离散的Token ID转换为连续的向量:

EtokenRV×dmodelE_{\text{token}} \in \mathbb{R}^{V \times d_{\text{model}}}

工作原理

对于输入的Token ID(如1234),Embedding层就是一个**查表(lookup)**操作:

token_id=1234vector=Etoken[1234,:]Rdmodel\text{token\_id} = 1234 \quad \Rightarrow \quad \text{vector} = E_{\text{token}}[1234, :] \in \mathbb{R}^{d_{\text{model}}}

这个向量就是Token 1234的表示。

举例(简化为5维向量):

假设词表有3个词:

Token IDTokenEmbedding向量
0"今天"[0.12, -0.34, 0.56, 0.23, -0.45]
1"天气"[0.87, 0.21, -0.32, 0.54, 0.11]
2"很好"[-0.23, 0.67, 0.89, -0.12, 0.34]

输入序列:"今天 天气"(Token IDs: [0, 1])

Etoken[0]=[0.12,0.34,0.56,0.23,0.45]("今天"的向量)Etoken[1]=[0.87,0.21,0.32,0.54,0.11]("天气"的向量)\begin{aligned} E_{\text{token}}[0] &= [0.12, -0.34, 0.56, 0.23, -0.45] \quad \text{("今天"的向量)} \\ E_{\text{token}}[1] &= [0.87, 0.21, -0.32, 0.54, 0.11] \quad \text{("天气"的向量)} \end{aligned}

Token Embedding 需要训练吗?

答案:绝对需要!

Token Embedding矩阵 EtokenE_{\text{token}}可学习的参数,和权重矩阵 WQW_QW1W_1 等完全一样,通过梯度下降训练。

1. 初始化

训练开始前,Embedding矩阵需要随机初始化:

Etoken[i,:]N(0,σ2)E_{\text{token}}[i, :] \sim \mathcal{N}(0, \sigma^2)

通常使用:

  • 正态分布初始化σ=0.02\sigma = 0.02(GPT系列)
  • 均匀分布初始化U(3/dmodel,3/dmodel)\mathcal{U}(-\sqrt{3/d_{\text{model}}}, \sqrt{3/d_{\text{model}}})

初始状态:每个词的向量是随机的,完全没有语义!

import torch
import torch.nn as nn

vocab_size = 50257
d_model = 768

# 创建Embedding层
token_embedding = nn.Embedding(vocab_size, d_model)

# 查看初始化后的值
print("Token 0 的初始embedding:", token_embedding.weight[0][:5])
# 输出:tensor([-0.0134,  0.0089, -0.0156,  0.0201, -0.0178])

print("Token 1 的初始embedding:", token_embedding.weight[1][:5])
# 输出:tensor([ 0.0167, -0.0145,  0.0123, -0.0098,  0.0134])

# 完全是随机值,没有任何语义!

2. 前向传播

在前向传播中,Embedding层将Token IDs转换为向量:

X=Etoken[token_ids]X = E_{\text{token}}[\text{token\_ids}]

举例

# 输入序列:[1234, 5678, 9012]
token_ids = torch.tensor([1234, 5678, 9012])

# 查表得到embedding
X = token_embedding(token_ids)  # shape: (3, 768)

# X[0] = E_token[1234, :]
# X[1] = E_token[5678, :]
# X[2] = E_token[9012, :]

这些向量会通过Transformer层,最终产生输出。

3. 反向传播

当损失函数的梯度反向传播时,会传到Embedding层:

LEtoken[i]=LX[j](如果token_ids[j] = i)\frac{\partial L}{\partial E_{\text{token}}[i]} = \frac{\partial L}{\partial X[j]} \quad \text{(如果token\_ids[j] = i)}

关键点

  • 只有出现在输入序列中的Token的embedding会收到梯度
  • 没出现的Token的embedding在这个batch中保持不变

举例

假设输入序列是 [1234, 5678, 9012]

  • Etoken[1234]E_{\text{token}}[1234] 会收到梯度 LX[0]\frac{\partial L}{\partial X[0]}
  • Etoken[5678]E_{\text{token}}[5678] 会收到梯度 LX[1]\frac{\partial L}{\partial X[1]}
  • Etoken[9012]E_{\text{token}}[9012] 会收到梯度 LX[2]\frac{\partial L}{\partial X[2]}
  • 其他49999个Token的embedding在这个batch中不更新

4. 参数更新

使用优化器更新Embedding矩阵:

Etoken[i]Etoken[i]ηLEtoken[i]E_{\text{token}}[i] \leftarrow E_{\text{token}}[i] - \eta \cdot \frac{\partial L}{\partial E_{\text{token}}[i]}

和其他参数完全一样的训练过程!

5. 训练后的语义

经过大量数据的训练,Embedding矩阵会学到有意义的语义表示:

相似词的向量会接近

# 训练后(示意)
E_token["国王"] ≈ [0.23, 0.56, -0.12, ..., 0.45]
E_token["女王"] ≈ [0.25, 0.54, -0.10, ..., 0.43]  # 很接近!

E_token["男人"] ≈ [0.67, 0.21, -0.34, ..., 0.12]
E_token["女人"] ≈ [0.69, 0.19, -0.32, ..., 0.10]  # 很接近!

# 著名的关系:king - man + woman ≈ queen

计算余弦相似度

similarity("国王","女王")=Etoken["国王"]Etoken["女王"]Etoken["国王"]Etoken["女王"]0.85\text{similarity}(\text{"国王"}, \text{"女王"}) = \frac{E_{\text{token}}[\text{"国王"}] \cdot E_{\text{token}}[\text{"女王"}]}{\|E_{\text{token}}[\text{"国王"}]\| \cdot \|E_{\text{token}}[\text{"女王"}]\|} \approx 0.85
similarity("国王","苹果")0.12(不相关)\text{similarity}(\text{"国王"}, \text{"苹果"}) \approx 0.12 \quad \text{(不相关)}

6. Embedding的参数量

对于GPT-3(V=50257V=50257dmodel=12288d_{\text{model}}=12288):

Embedding参数量=V×dmodel=50257×12288=617,155,776617M\text{Embedding参数量} = V \times d_{\text{model}} = 50257 \times 12288 = 617{,}155{,}776 \approx 617M

和LM Head的参数量一样大!(如果不共享权重的话)

7. 稀疏更新的效率

由于每个batch只更新出现的Token,Embedding层的更新是稀疏的

  • 总Token数:50257
  • 每个batch出现的Token:~100-1000
  • 更新比例:<2%

这也是为什么Embedding训练需要大量数据——需要让每个Token都有足够的训练机会。

8. 与位置编码的区别

Token Embedding(可学习):

  • 表示"这是什么词"
  • 通过训练学习
  • 每个Token有独立的向量

位置编码(一般是固定的):

  • 表示"词在哪个位置"
  • 可以是固定公式(Sinusoidal)或可学习参数
  • 不同位置有不同的向量

两者相加:

Xinput=Etoken[token_ids]+PE[positions]X_{\text{input}} = E_{\text{token}}[\text{token\_ids}] + \text{PE}[\text{positions}]

完整的训练流程图

训练前:
E_token = 随机初始化

第1个epoch:
输入:"今天天气很好" → [1234, 5678, 9012, 4567]
  ↓
E_token[1234], E_token[5678], E_token[9012], E_token[4567] 被使用
  ↓
通过Transformer → 计算Loss
  ↓
反向传播 → 这4个Token的embedding收到梯度
  ↓
E_token[1234], E_token[5678], E_token[9012], E_token[4567] 被更新
(其他49,999个Token不变)

第2个epoch:
输入:"天气真不错" → [5678, 3456, 7890, 2345]
  ↓
又有4个Token的embedding被更新
...

经过数百万个样本:
  ↓
所有Token都被更新过很多次
  ↓
E_token 学到了丰富的语义表示!

总结:Token Embedding的训练

特性Token Embedding权重矩阵 (W₁, W₂)
是否可学习✅ 是✅ 是
初始化方式正态分布(σ=0.02)He/Xavier初始化
训练方式梯度下降(稀疏更新)梯度下降(稠密更新)
参数量V×dmodelV \times d_{\text{model}}din×doutd_{\text{in}} \times d_{\text{out}}
语义训练后学到词义训练后学到变换规则

关键点

  • Token Embedding 不是预定义的,而是从随机初始化开始训练的
  • 它是模型参数的重要组成部分(占比可达20%+)
  • 训练后会自动学到语义相似性(相似词的向量接近)
  • 每个batch只更新出现的Token(稀疏更新)

什么是权重共享?

现在我们理解了Token Embedding也是训练出来的,让我们看权重共享的概念。

模型的第一层是Token Embedding层

EtokenRV×dmodelE_{\text{token}} \in \mathbb{R}^{V \times d_{\text{model}}}

它将词表中的每个词映射到 dmodeld_{\text{model}} 维向量。

观察

  • Embedding矩阵:(V,dmodel)(V, d_{\text{model}})
  • LM Head矩阵:(dmodel,V)(d_{\text{model}}, V)

两者是转置关系

权重共享(Weight Tying):让LM Head直接使用Embedding矩阵的转置:

Wlm=EtokenTW_{\text{lm}} = E_{\text{token}}^T

为什么可以共享?

直观理解

  • Embedding:词 → 向量("猫" → [0.1,0.5,...,0.3][0.1, 0.5, ..., 0.3]
  • LM Head:向量 → 词([0.1,0.5,...,0.3][0.1, 0.5, ..., 0.3] → "猫")

它们在做相反的事情!如果一个词的embedding向量是 vv,那么当隐藏状态接近 vv 时,应该输出这个词。

数学上

logiti=hWlm[:,i]=hEtoken[i,:]T=hei\text{logit}_i = h \cdot W_{\text{lm}}[:, i] = h \cdot E_{\text{token}}[i, :]^T = h \cdot e_i

即:词 wiw_i 的logit等于隐藏状态 hh 与该词的embedding eie_i 的点积(相似度)。

越相似,logit越大,该词越可能被选中!

权重共享的优缺点

优点

  1. 大幅减少参数量

    • 不共享:Embedding参数 + LM Head参数
    • 共享:只有Embedding参数
    • 节省:617M617M 参数(对于GPT-3)
  2. 理论优雅

    • Embedding和LM Head在语义空间中对称
    • 鼓励模型学习一致的表示
  3. 正则化效果

    • 相当于对两个矩阵施加了约束
    • 可能提高泛化能力

缺点

  1. 灵活性降低

    • Embedding和LM Head被强制对称
    • 可能限制表达能力
  2. 实践中效果不总是最好

    • 小模型上效果好(参数少,需要正则化)
    • 大模型上效果不明显(参数足够,不需要强约束)

现代模型的选择

模型是否共享原因
BERT✅ 共享小模型(110M-340M),节省参数
GPT-2✅ 共享中等模型(117M-1.5B),节省参数
GPT-3❌ 不共享大模型(175B),参数足够
LLaMA❌ 不共享大模型(7B-65B),追求性能
T5❌ 不共享编码器-解码器架构,更复杂

趋势:随着模型规模增大,越来越多的模型选择不共享,以获得更大的灵活性和表达能力。

代码实现

不共享权重

class TransformerLM(nn.Module):
    def __init__(self, vocab_size=50257, d_model=768):
        super().__init__()
        # Token Embedding
        self.token_embedding = nn.Embedding(vocab_size, d_model)

        # Transformer layers
        self.transformer = nn.ModuleList([...])

        # LM Head(独立的权重)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, input_ids):
        x = self.token_embedding(input_ids)  # (batch, seq, d_model)

        # 通过Transformer
        for layer in self.transformer:
            x = layer(x)

        # LM Head
        logits = self.lm_head(x)  # (batch, seq, vocab_size)
        return logits

共享权重

class TransformerLM_Tied(nn.Module):
    def __init__(self, vocab_size=50257, d_model=768):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.ModuleList([...])

        # LM Head使用Embedding的转置
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # 关键:权重绑定
        self.lm_head.weight = self.token_embedding.weight

    def forward(self, input_ids):
        x = self.token_embedding(input_ids)

        for layer in self.transformer:
            x = layer(x)

        # LM Head使用共享的权重
        logits = self.lm_head(x)  # 实际上是 x @ token_embedding.weight.T
        return logits

手动实现共享

# 更显式的写法
def forward(self, input_ids):
    x = self.token_embedding(input_ids)  # (batch, seq, d_model)

    for layer in self.transformer:
        x = layer(x)

    # 手动使用embedding权重的转置
    logits = torch.matmul(x, self.token_embedding.weight.T)  # (batch, seq, vocab)
    return logits

完整的生成流程

让我们把所有内容串起来,看一个完整的文本生成例子:

输入:"今天天气"

目标:生成下一个词

步骤1:Tokenization

"今天天气"  [1234, 5678, 9012]

步骤2:Embedding + 位置编码

X=Etoken[[1234,5678,9012]]+PEX = E_{\text{token}}[[1234, 5678, 9012]] + \text{PE}
XR3×768X \in \mathbb{R}^{3 \times 768}

步骤3:通过Transformer层

X1=TransformerLayer1(X)X2=TransformerLayer2(X1)H=TransformerLayer12(X11)\begin{aligned} X_1 &= \text{TransformerLayer}_1(X) \\ X_2 &= \text{TransformerLayer}_2(X_1) \\ &\vdots \\ H &= \text{TransformerLayer}_{12}(X_{11}) \end{aligned}
HR3×768H \in \mathbb{R}^{3 \times 768}

步骤4:取最后一个位置

hlast=H[2,:]R768h_{\text{last}} = H[2, :] \in \mathbb{R}^{768}

这个向量包含了"今天天气"后面应该接什么的所有信息。

步骤5:LM Head映射到词表

logits=hlastWlmR50257\text{logits} = h_{\text{last}} \cdot W_{\text{lm}} \in \mathbb{R}^{50257}

假设结果(简化):

logits = {
    "很": 2.3,
    "好": 1.5,
    "不错": 3.8,
    "真": 0.5,
    "差": 1.2,
    ...
}

步骤6:Softmax归一化

probs=softmax(logits)\text{probs} = \text{softmax}(\text{logits})
probs = {
    "很": 15.5%,
    "好": 7.0%,
    "不错": 69.7%,  # 最高
    "真": 2.6%,
    "差": 5.2%,
    ...
}

步骤7:采样(Top-P,p=0.9)

# 累积概率达到90%的词:["不错", "很", "好"]
# 重新归一化后按概率采样
next_token = sample(["不错", "很", "好"], probs=[0.756, 0.168, 0.076])
# 结果:next_token = "不错"

步骤8:输出

"今天天气" + "不错" = "今天天气不错"

继续生成(自回归)

如果要继续生成:

  1. 将"不错"加入输入序列
  2. 重复步骤2-7
  3. 生成下一个词(如",")
  4. 继续迭代...

最终可能生成:

"今天天气不错,适合出去玩。"

LM Head的参数量和计算量

参数量

LM Head参数量=dmodel×V\text{LM Head参数量} = d_{\text{model}} \times V
模型dmodeld_{\text{model}}VVLM Head参数量占总参数比例
BERT-Base76830,52223M21%
GPT-276850,25739M26%
GPT-312,28850,257617M0.35%
LLaMA-7B4,09632,000131M1.9%
LLaMA-65B8,19232,000262M0.4%

观察

  • 小模型:LM Head占比很大(20%+)
  • 大模型:LM Head占比很小(<2%)

原因:LM Head的参数量与词表大小 VV 成正比,与模型深度无关。大模型通过增加层数和宽度来扩大规模,但词表大小基本不变,所以LM Head的占比下降。

计算量

每次生成一个Token的计算量:

计算量=dmodel×V\text{计算量} = d_{\text{model}} \times V

对于LLaMA-7B(dmodel=4096d_{\text{model}}=4096V=32000V=32000):

计算量=4096×32000=131,072,000131M FLOPs\text{计算量} = 4096 \times 32000 = 131{,}072{,}000 \approx 131M \text{ FLOPs}

对比

  • 一个Transformer层的MLP:2×4096×11008×2180M2 \times 4096 \times 11008 \times 2 \approx 180M FLOPs
  • LM Head与一个MLP层的计算量相当

虽然参数占比小,但计算量不可忽略!

小结

  1. LM Head的作用

    • 将Transformer输出的连续向量映射到词表空间
    • 为每个词计算logit(得分)
    • 通过softmax转换为概率分布
  2. 核心公式

    logits=HWlmRn×VP(w)=softmax(logits)\begin{aligned} \text{logits} &= H \cdot W_{\text{lm}} \in \mathbb{R}^{n \times V} \\ P(w) &= \text{softmax}(\text{logits}) \end{aligned}
  3. 采样策略

    • Greedy:选最大概率(确定性,缺乏多样性)
    • Temperature:调整概率分布的陡峭程度
    • Top-K:只从前K个词中采样
    • Top-P:自适应选择词集合(推荐⭐)
  4. 权重共享

    • 小模型:常用权重共享,节省参数
    • 大模型:常用独立权重,追求性能
    • 公式:Wlm=EtokenTW_{\text{lm}} = E_{\text{token}}^T
  5. 参数量分析

    • 小模型中占比大(20%+)
    • 大模型中占比小(<2%)
    • 但计算量不可忽略
  6. 完整流程

    • Tokenization → Embedding → Transformer → LM Head → Softmax → Sampling → Token

LM Head看似简单(就是一个线性层),但它是连接模型内部表示和外部文字的关键桥梁,没有它,再强大的Transformer也无法输出一个字!