本文深度剖析Transformer的核心机制——自注意力(Self-Attention)和多头注意力(Multi-Head Attention)。通过数学推导、可视化图表和PyTorch代码实现,详细讲解QKV矩阵计算、注意力分数、缩放点积注意力等关键技术。涵盖Transformer Block完整结构、残差连接、层归一化等工程实践要点,是理解现代大语言模型架构的必读教程。
一、Transformer架构全景
1.1 Transformer诞生的背景
2017年Google发表的《Attention Is All You Need》彻底改变了NLP领域。其核心创新:完全抛弃RNN/CNN,只用注意力机制。
graph TB
subgraph 传统方法
A["RNN/LSTM"] --> B["串行计算<br/>梯度消失<br/>难以并行"]
end
subgraph Transformer革命
C["Self-Attention"] --> D["并行计算<br/>长距离依赖<br/>可扩展性强"]
end
B -.2017年突破.-> D
style B fill:#ffcdd2
style D fill:#a5d6a7
1.2 完整架构图
graph TB
subgraph Encoder[编码器 N×]
Input1[输入嵌入] --> PE1[位置编码]
PE1 --> MHA1[多头注意力]
MHA1 --> Add1[Add & Norm]
Add1 --> FFN1[前馈网络]
FFN1 --> Add2[Add & Norm]
end
subgraph Decoder[解码器 N×]
Input2[输出嵌入] --> PE2[位置编码]
PE2 --> MHA2[Masked<br/>多头注意力]
MHA2 --> Add3[Add & Norm]
Add3 --> MHA3[交叉注意力]
MHA3 --> Add4[Add & Norm]
Add4 --> FFN2[前馈网络]
FFN2 --> Add5[Add & Norm]
end
Add2 --> MHA3
Add5 --> Linear[线性层]
Linear --> Softmax[Softmax]
Softmax --> Output[输出概率]
style MHA1 fill:#fff59d
style MHA2 fill:#fff59d
style MHA3 fill:#fff59d
style Output fill:#a5d6a7
1.3 三大核心组件
| 组件 | 作用 | 关键技术 |
|---|---|---|
| Self-Attention | 捕捉序列内部依赖 | QKV矩阵、缩放点积 |
| Multi-Head Attention | 多视角信息提取 | 多个注意力头并行 |
| Feed-Forward Network | 非线性变换 | 两层全连接+激活 |
二、自注意力机制(Self-Attention)详解
2.1 核心思想
Self-Attention的本质:让序列中的每个元素都能"看到"其他所有元素,并决定关注程度。
graph LR
subgraph 输入句子
W1["我(I)"]
W2["爱(love)"]
W3["AI"]
end
subgraph 自注意力计算
W2 -->|0.2| W1
W2 -->|0.5| W2
W2 -->|0.3| W3
end
Result["增强表示:<br/>爱' = 0.2×我 + 0.5×爱 + 0.3×AI"]
W2 --> Result
style W2 fill:#fff59d
style Result fill:#a5d6a7
2.2 QKV三剑客:Query、Key、Value
核心概念:
- Query (Q): "我要找什么?"(查询向量)
- Key (K): "我是什么?"(键向量)
- Value (V): "我有什么信息?"(值向量)
类比理解:就像图书馆检索系统
sequenceDiagram
participant 你的查询
participant 图书索引
participant 图书内容
你的查询->>图书索引: Query: "机器学习"
图书索引->>图书索引: 计算相关度(Key匹配)
Note over 图书索引: 《深度学习》: 0.9<br/>《数据结构》: 0.2<br/>《Python》: 0.6
图书索引->>图书内容: 按权重提取Value
图书内容->>你的查询: 返回加权信息
2.3 数学公式推导
Step 1: 生成QKV矩阵
对于输入序列 (n个词,每个d维):
graph LR
X["输入矩阵 X<br/>[n × d]"] --> WQ["权重 W^Q<br/>[d × d_k]"]
X --> WK["权重 W^K<br/>[d × d_k]"]
X --> WV["权重 W^V<br/>[d × d_v]"]
WQ --> Q["Query<br/>[n × d_k]"]
WK --> K["Key<br/>[n × d_k]"]
WV --> V["Value<br/>[n × d_v]"]
style X fill:#e3f2fd
style Q fill:#fff9c4
style K fill:#c5e1a5
style V fill:#ce93d8
Step 2: 计算注意力分数
为什么要除以 ?
graph TD
A["问题: 维度d_k增大时<br/>点积值会急剧增长"] --> B["导致Softmax梯度消失"]
B --> C["解决: 缩放因子 √d_k<br/>保持数值稳定"]
style A fill:#ffcdd2
style C fill:#a5d6a7
Step 3: Softmax归一化
Step 4: 加权求和
完整公式:
2.4 直观示例:计算"爱"的新表示
输入句子:"我 爱 学习 AI"
graph TB
subgraph 输入[Step 1: 生成QKV]
X["输入:<br/>我 爱 学习 AI"]
X --> Q["Q矩阵<br/>(查询)"]
X --> K["K矩阵<br/>(键)"]
X --> V["V矩阵<br/>(值)"]
end
subgraph 计算[Step 2-3: 计算注意力]
Q --> Score["QK^T / √d_k"]
K --> Score
Score --> Softmax["Softmax归一化"]
end
subgraph 输出[Step 4: 加权求和]
Softmax --> Weight["权重:<br/>我:0.1 爱:0.2<br/>学习:0.4 AI:0.3"]
V --> Output["新表示:<br/>爱' = Σ(权重×值)"]
Weight --> Output
end
style Score fill:#fff9c4
style Softmax fill:#c5e1a5
style Output fill:#a5d6a7
假设计算"爱"对其他词的注意力:
| 词 | Query·Key | Score | Softmax | 最终贡献 |
|---|---|---|---|---|
| 我 | 2.1 | 2.1/√64=0.26 | 0.1 | 0.1×V(我) |
| 爱 | 3.5 | 0.44 | 0.2 | 0.2×V(爱) |
| 学习 | 7.2 | 0.90 | 0.4 | 0.4×V(学习) |
| AI | 5.8 | 0.73 | 0.3 | 0.3×V(AI) |
最终:"爱"的新表示 = 0.1×V(我) + 0.2×V(爱) + 0.4×V(学习) + 0.3×V(AI)
三、多头注意力(Multi-Head Attention)
3.1 为什么需要多头?
单头的局限:只能捕捉一种关系模式
graph LR
subgraph 单头注意力
S["我爱学习AI"] --> H1["只关注<br/>语义相关性"]
end
subgraph 多头注意力
M["我爱学习AI"] --> MH1["Head1<br/>语义关系"]
M --> MH2["Head2<br/>语法关系"]
M --> MH3["Head3<br/>位置关系"]
M --> MH4["Head4<br/>..."]
end
style H1 fill:#ffcdd2
style MH1 fill:#a5d6a7
style MH2 fill:#90caf9
style MH3 fill:#ce93d8
style MH4 fill:#fff59d
3.2 多头注意力架构
graph TB
X["输入 X"] --> Split["线性变换并分割"]
Split --> H1["Head 1<br/>Attention(Q1,K1,V1)"]
Split --> H2["Head 2<br/>Attention(Q2,K2,V2)"]
Split --> H3["Head 3<br/>Attention(Q3,K3,V3)"]
Split --> H4["Head h<br/>Attention(Qh,Kh,Vh)"]
H1 --> Concat["拼接 Concat"]
H2 --> Concat
H3 --> Concat
H4 --> Concat
Concat --> Linear["线性变换 W^O"]
Linear --> Output["输出"]
style Split fill:#e3f2fd
style Concat fill:#fff9c4
style Output fill:#a5d6a7
3.3 数学公式
对于 个注意力头:
其中:
典型配置(如BERT):
- (模型维度)
- (注意力头数)
- (每个头的维度)
3.4 多头的优势
graph TB
subgraph 例子[句子: The animal didn't cross the street because it was too tired]
Example["it 指代什么?"]
end
subgraph Heads[不同注意力头的关注点]
H1["Head 1<br/>it → animal<br/>(主语指代)"]
H2["Head 2<br/>it → street<br/>(位置关系)"]
H3["Head 3<br/>didn't → tired<br/>(因果关系)"]
end
Example --> H1
Example --> H2
Example --> H3
H1 --> Final["综合判断:<br/>it = animal"]
H2 --> Final
H3 --> Final
style H1 fill:#a5d6a7
style H2 fill:#ffcdd2
style H3 fill:#90caf9
style Final fill:#fff59d
四、PyTorch完整实现
4.1 Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
参数:
Q: [batch_size, n_heads, seq_len, d_k]
K: [batch_size, n_heads, seq_len, d_k]
V: [batch_size, n_heads, seq_len, d_v]
mask: [batch_size, 1, 1, seq_len] 可选
返回:
output: [batch_size, n_heads, seq_len, d_v]
attention_weights: [batch_size, n_heads, seq_len, seq_len]
"""
d_k = Q.size(-1)
# 1. 计算注意力分数: QK^T / √d_k
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# scores: [batch_size, n_heads, seq_len, seq_len]
# 2. 可选:应用mask(用于Decoder中的自回归)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# 3. Softmax归一化
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 测试
batch_size, n_heads, seq_len, d_k = 2, 8, 10, 64
Q = torch.randn(batch_size, n_heads, seq_len, d_k)
K = torch.randn(batch_size, n_heads, seq_len, d_k)
V = torch.randn(batch_size, n_heads, seq_len, d_k)
output, attn_weights = scaled_dot_product_attention(Q, K, V)
print(f"输出形状: {output.shape}") # torch.Size([2, 8, 10, 64])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
4.2 Multi-Head Attention完整实现
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads):
"""
参数:
d_model: 模型维度(如768)
n_heads: 注意力头数(如12)
"""
super().__init__()
assert d_model % n_heads == 0, "d_model必须能被n_heads整除"
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads # 每个头的维度
# QKV的线性变换矩阵
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
# 输出的线性变换
self.W_O = nn.Linear(d_model, d_model)
def split_heads(self, x):
"""
将输入分割成多个头
x: [batch_size, seq_len, d_model]
返回: [batch_size, n_heads, seq_len, d_k]
"""
batch_size, seq_len, d_model = x.size()
return x.view(batch_size, seq_len, self.n_heads, self.d_k).transpose(1, 2)
def forward(self, Q, K, V, mask=None):
"""
参数:
Q, K, V: [batch_size, seq_len, d_model]
mask: [batch_size, 1, 1, seq_len]
"""
batch_size = Q.size(0)
# 1. 线性变换
Q = self.W_Q(Q) # [batch_size, seq_len, d_model]
K = self.W_K(K)
V = self.W_V(V)
# 2. 分割成多个头
Q = self.split_heads(Q) # [batch_size, n_heads, seq_len, d_k]
K = self.split_heads(K)
V = self.split_heads(V)
# 3. 缩放点积注意力
attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
# attn_output: [batch_size, n_heads, seq_len, d_k]
# 4. 合并多个头
attn_output = attn_output.transpose(1, 2).contiguous()
# [batch_size, seq_len, n_heads, d_k]
attn_output = attn_output.view(batch_size, -1, self.d_model)
# [batch_size, seq_len, d_model]
# 5. 最终线性变换
output = self.W_O(attn_output)
return output, attn_weights
# 使用示例
d_model = 512
n_heads = 8
batch_size = 2
seq_len = 10
mha = MultiHeadAttention(d_model, n_heads)
# 输入
x = torch.randn(batch_size, seq_len, d_model)
# 自注意力:Q=K=V
output, attn_weights = mha(x, x, x)
print(f"输入形状: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}") # torch.Size([2, 10, 512])
print(f"注意力权重: {attn_weights.shape}") # torch.Size([2, 8, 10, 10])
4.3 可视化注意力权重
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention(attn_weights, tokens, head_idx=0):
"""
可视化某个注意力头的权重
attn_weights: [batch_size, n_heads, seq_len, seq_len]
tokens: 词列表
head_idx: 要可视化的头索引
"""
# 提取第一个样本的指定头
attn = attn_weights[0, head_idx].detach().cpu().numpy()
plt.figure(figsize=(10, 8))
sns.heatmap(attn,
xticklabels=tokens,
yticklabels=tokens,
cmap='YlOrRd',
annot=True,
fmt='.2f',
cbar_kws={'label': '注意力权重'})
plt.title(f'Attention Head {head_idx}')
plt.xlabel('Key')
plt.ylabel('Query')
plt.tight_layout()
plt.show()
# 示例:可视化英译中的注意力
tokens = ['I', 'love', 'learning', 'AI', '<EOS>']
seq_len = len(tokens)
# 模拟注意力权重
mha = MultiHeadAttention(d_model=512, n_heads=8)
x = torch.randn(1, seq_len, 512)
_, attn_weights = mha(x, x, x)
# 可视化第0个头
visualize_attention(attn_weights, tokens, head_idx=0)
输出效果:
注意力矩阵 (Head 0)
I love learning AI <EOS>
I 0.20 0.15 0.10 0.35 0.20
love 0.10 0.50 0.30 0.05 0.05
learning 0.05 0.30 0.40 0.20 0.05
AI 0.15 0.10 0.25 0.45 0.05
<EOS> 0.05 0.05 0.05 0.10 0.75
五、Transformer Block完整结构
5.1 单个Block的组成
graph TB
Input[输入 X] --> MHA[多头注意力]
Input -.残差连接.-> Add1
MHA --> Add1[Add]
Add1 --> Norm1[Layer Norm]
Norm1 --> FFN[前馈网络<br/>2层全连接]
Norm1 -.残差连接.-> Add2
FFN --> Add2[Add]
Add2 --> Norm2[Layer Norm]
Norm2 --> Output[输出 Y]
style MHA fill:#fff59d
style FFN fill:#90caf9
style Add1 fill:#a5d6a7
style Add2 fill:#a5d6a7
style Norm1 fill:#ce93d8
style Norm2 fill:#ce93d8
5.2 残差连接(Residual Connection)
为什么需要?
graph LR
subgraph 无残差[深层网络问题]
A["梯度消失"] --> B["难以训练"]
end
subgraph 残差解决[ResNet思想]
C["Y = X + F(X)"] --> D["梯度直通<br/>易于优化"]
end
B -.引入残差.-> D
style A fill:#ffcdd2
style D fill:#a5d6a7
数学表示:
5.3 Layer Normalization
与Batch Norm的区别:
| 特性 | Batch Norm | Layer Norm |
|---|---|---|
| 归一化维度 | 跨batch维度 | 跨特征维度 |
| 适用场景 | CNN(固定batch) | NLP(变长序列) |
| 依赖性 | 依赖batch大小 | 独立于batch |
graph LR
subgraph BatchNorm[Batch Normalization]
B1["样本1<br/>[d1,d2,d3]"]
B2["样本2<br/>[d1,d2,d3]"]
B3["样本3<br/>[d1,d2,d3]"]
B1 --> BN["对每个特征<br/>跨样本归一化"]
B2 --> BN
B3 --> BN
end
subgraph LayerNorm[Layer Normalization]
L1["样本1<br/>[d1,d2,d3]"]
L1 --> LN["对每个样本<br/>跨特征归一化"]
end
style BN fill:#ffccbc
style LN fill:#a5d6a7
Layer Norm公式:
其中 是当前层所有特征的均值和标准差。
5.4 前馈网络(Feed-Forward Network)
结构:两层全连接+激活函数
graph LR
Input["输入<br/>[seq_len, d_model]"] --> Linear1["全连接1<br/>d_model → d_ff"]
Linear1 --> ReLU["ReLU激活"]
ReLU --> Linear2["全连接2<br/>d_ff → d_model"]
Linear2 --> Output["输出<br/>[seq_len, d_model]"]
style Linear1 fill:#fff9c4
style ReLU fill:#90caf9
style Linear2 fill:#c5e1a5
典型配置:
- BERT: , (4倍)
- GPT-3: , (4倍)
5.5 完整Transformer Block实现
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
"""
参数:
d_model: 模型维度
n_heads: 注意力头数
d_ff: 前馈网络隐藏层维度
dropout: Dropout比率
"""
super().__init__()
# 多头注意力
self.mha = MultiHeadAttention(d_model, n_heads)
# 前馈网络
self.ffn = nn.Sequential(
nn.Linear(d_model, d_ff),
nn.ReLU(),
nn.Linear(d_ff, d_model)
)
# Layer Normalization
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
# Dropout
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
x: [batch_size, seq_len, d_model]
"""
# 1. 多头自注意力 + 残差连接 + LayerNorm
attn_output, _ = self.mha(x, x, x, mask)
attn_output = self.dropout1(attn_output)
x = self.norm1(x + attn_output) # 残差连接
# 2. 前馈网络 + 残差连接 + LayerNorm
ffn_output = self.ffn(x)
ffn_output = self.dropout2(ffn_output)
x = self.norm2(x + ffn_output) # 残差连接
return x
# 测试
d_model = 512
n_heads = 8
d_ff = 2048
block = TransformerBlock(d_model, n_heads, d_ff)
x = torch.randn(2, 10, d_model) # [batch, seq_len, d_model]
output = block(x)
print(f"输入形状: {x.shape}") # torch.Size([2, 10, 512])
print(f"输出形状: {output.shape}") # torch.Size([2, 10, 512])
六、位置编码(Positional Encoding)
6.1 为什么需要位置编码?
问题:自注意力是置换不变的(permutation-invariant)
sentence1 = "我 爱 AI"
sentence2 = "AI 爱 我"
# 如果没有位置编码,Self-Attention会给出相同的结果!
graph TB
Problem["问题: 自注意力无法区分词序"] --> Solution["解决: 添加位置信息"]
Solution --> PE1["方案1: 学习位置嵌入<br/>(BERT)"]
Solution --> PE2["方案2: 固定位置编码<br/>(原始Transformer)"]
style Problem fill:#ffcdd2
style Solution fill:#fff9c4
style PE1 fill:#c5e1a5
style PE2 fill:#90caf9
6.2 正弦位置编码
公式:
其中:
- : 词的位置(0, 1, 2, ...)
- : 维度索引(0, 1, ..., d_model/2)
优点:
- ✅ 可以处理任意长度的序列
- ✅ 不需要训练参数
- ✅ 相对位置关系可以通过线性变换表达
6.3 实现代码
class PositionalEncoding(nn.Module):
def __init__(self, d_model, max_len=5000):
super().__init__()
# 创建位置编码矩阵
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
# 计算分母
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
(-math.log(10000.0) / d_model))
# 计算正弦和余弦
pe[:, 0::2] = torch.sin(position * div_term) # 偶数维度
pe[:, 1::2] = torch.cos(position * div_term) # 奇数维度
# 添加batch维度
pe = pe.unsqueeze(0) # [1, max_len, d_model]
self.register_buffer('pe', pe)
def forward(self, x):
"""
x: [batch_size, seq_len, d_model]
"""
# 添加位置编码
x = x + self.pe[:, :x.size(1), :]
return x
# 可视化位置编码
def visualize_positional_encoding(d_model=128, max_len=100):
pe = PositionalEncoding(d_model, max_len)
encoding = pe.pe[0, :max_len, :].numpy()
plt.figure(figsize=(15, 5))
plt.imshow(encoding.T, aspect='auto', cmap='RdBu',
interpolation='nearest')
plt.colorbar(label='编码值')
plt.xlabel('位置')
plt.ylabel('维度')
plt.title('正弦位置编码可视化')
plt.tight_layout()
plt.show()
visualize_positional_encoding()
输出效果:会看到规律的波浪状图案,不同频率编码不同维度。
七、完整Encoder实现
class TransformerEncoder(nn.Module):
def __init__(self, vocab_size, d_model, n_heads, d_ff,
n_layers, max_len=5000, dropout=0.1):
"""
参数:
vocab_size: 词汇表大小
d_model: 模型维度
n_heads: 注意力头数
d_ff: 前馈网络维度
n_layers: Transformer Block层数
max_len: 最大序列长度
dropout: Dropout比率
"""
super().__init__()
# 词嵌入
self.embedding = nn.Embedding(vocab_size, d_model)
# 位置编码
self.pos_encoding = PositionalEncoding(d_model, max_len)
# 多层Transformer Block
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask=None):
"""
x: [batch_size, seq_len] (token indices)
"""
# 1. 词嵌入 + 位置编码
x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
x = self.pos_encoding(x)
x = self.dropout(x)
# 2. 通过多层Transformer Block
for layer in self.layers:
x = layer(x, mask)
return x
# 使用示例:构建一个小型BERT
vocab_size = 30000
d_model = 768
n_heads = 12
d_ff = 3072
n_layers = 12
encoder = TransformerEncoder(vocab_size, d_model, n_heads, d_ff, n_layers)
# 输入token indices
input_ids = torch.randint(0, vocab_size, (2, 20)) # [batch=2, seq_len=20]
output = encoder(input_ids)
print(f"输入形状: {input_ids.shape}") # torch.Size([2, 20])
print(f"输出形状: {output.shape}") # torch.Size([2, 20, 768])
print(f"参数量: {sum(p.numel() for p in encoder.parameters())/1e6:.1f}M")
# 输出: 参数量: 110.1M (接近BERT-Base的110M)
八、关键概念总结
8.1 公式总结
| 组件 | 公式 | 说明 |
|---|---|---|
| Self-Attention | 缩放点积注意力 | |
| Multi-Head | 多视角融合 | |
| FFN | 两层全连接 | |
| Layer Norm | 特征归一化 | |
| Positional | 位置编码 |
8.2 配置对比
| 模型 | d_model | n_heads | n_layers | d_ff | 参数量 |
|---|---|---|---|---|---|
| BERT-Base | 768 | 12 | 12 | 3072 | 110M |
| BERT-Large | 1024 | 16 | 24 | 4096 | 340M |
| GPT-2 | 768 | 12 | 12 | 3072 | 117M |
| GPT-3 | 12288 | 96 | 96 | 49152 | 175B |
8.3 架构流程图
graph TB
Start[输入Token IDs] --> Embed[词嵌入层]
Embed --> PE[位置编码]
PE --> Block1[Transformer Block 1]
subgraph Block[每个Block内部]
MHA[多头注意力] --> Add1[Add & Norm]
Add1 --> FFN[前馈网络]
FFN --> Add2[Add & Norm]
end
Block1 --> Block2[Transformer Block 2]
Block2 --> BlockN[... Block N]
BlockN --> Output[输出表示]
style Embed fill:#e3f2fd
style PE fill:#fff9c4
style MHA fill:#ffccbc
style FFN fill:#c5e1a5
style Output fill:#a5d6a7
九、实战练习
练习1:计算注意力权重
题目:给定QKV矩阵,手工计算注意力输出
# 已知(简化为2x2矩阵方便计算)
Q = torch.tensor([[1.0, 0.0],
[0.0, 1.0]])
K = torch.tensor([[1.0, 0.0],
[0.5, 0.5]])
V = torch.tensor([[2.0, 0.0],
[1.0, 1.0]])
# 步骤:
# 1. 计算 QK^T / √d_k
# 2. Softmax
# 3. 乘以V
# 你的答案:
练习2:实现Masked Self-Attention
任务:修改Self-Attention,实现Decoder中的mask机制(当前词不能看到未来词)
def masked_self_attention(Q, K, V):
"""
TODO: 实现masked attention
提示: 使用torch.tril创建下三角mask
"""
pass
练习3:分析注意力头
任务:加载预训练BERT,可视化不同注意力头关注的模式
from transformers import BertModel, BertTokenizer
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
text = "The animal didn't cross the street because it was too tired."
inputs = tokenizer(text, return_tensors='pt')
with torch.no_grad():
outputs = model(**inputs)
attentions = outputs.attentions # 12层,每层12个头