1. 什么是自注意力?
想象一下,你正在阅读一本小说,每看到一个词语时,大脑会自动关注前文中与之相关的信息。这种"聚焦重点"的能力,正是自注意力机制的核心思想。
自注意力(Self-Attention)是一种让序列中的每个元素都能关注整个序列的机制。就像班级讨论时,每个同学发言(查询)都会考虑所有人的观点(键和值)。具体来说:
给定输入序列 ,自注意力通过三个步骤生成输出:
- 生成问题纸条:每个词元创建查询向量
- 制作答案卡:每个词元生成键向量 和值向量
- 收集答案:每个查询收集所有键值对的加权和:
其中:
- (Query)是查询矩阵,大小为 ,其中 是查询的数量, 是特征维度。
- (Key)是键矩阵,大小为 ,其中 是键的数量, 是特征维度。
- (Value)是值矩阵,大小为 。
- 是一个缩放因子,用于防止大数值导致 softmax 过于极端,从而影响梯度的稳定性。
示例:考虑句子"猫吃鱼",自注意力会让"吃"同时关注"猫"和"鱼",就像我们在理解动词时会自动联系主语和宾语。
下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,输入张量 的形状为 ,经过自注意力计算后,输出张量与输入张量形状保持一致。
import torch
import d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)
"""输出:
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
"""
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)
# 输出:torch.Size([2, 4, 100])
2. 三大序列模型的巅峰对决
2.1 参赛选手介绍
| 模型类型 | 工作方式 | 可视化类比 |
|---|---|---|
| CNN | 滑动窗口扫描 | 望远镜观察局部区域 |
| RNN | 顺序传递信息 | 接力赛传递消息 |
| 自注意力 | 全局直接交互 | 电话会议全员讨论 |
2.2 性能参数对比
使用 个词元,每个维度 ,卷积核大小 :
| 指标 | CNN | RNN | 自注意力 |
|---|---|---|---|
| 计算复杂度 | |||
| 并行能力 | 高 | 低 | 极高 |
| 最大路径长度 |
示例:处理100个词的句子时,自注意力需要100×100=10,000次交互计算,而CNN(假设k=3)只需3×100=300次局部计算。
3. 位置编码:给词语发"座位号"
3.1 为什么需要位置信息?
自注意力虽然强大,但有个致命缺陷——所有词语同时处理,就像把句子里的词全部平铺在桌面上,模型无法知道它们的原始顺序。这时就需要位置编码来标记每个词的位置。
3.2 神奇的三角函数编码
使用正弦和余弦函数的组合生成位置编码矩阵 ,其中第 行对应位置,第 和 列使用:
示例:当 时,位置1的编码可能是: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/4)), cos(1/10000^(2/4))]
让我们在下面的PositionalEncoding类中实现它这种编码方式:
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
在位置嵌入矩阵 中,行代表词元在序列中的位置,列代表位置编码的不同维度。
encoding_dim, num_steps = 32, 60
pos_encoding = d2l.PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6.18, 3.82), legend=["Col %d" % d for d in torch.arange(6, 10)])
从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。
3.3 编码特性揭秘
绝对位置感知
不同列对应不同频率的波形,就像钢琴键盘上从左到右音调逐渐降低。高频(左侧列)帮助捕捉相邻词语的位置关系,低频(右侧列)负责编码词语在序列中的整体位置。
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.82, 6.18), cmap='Blues')
相对位置推理
关键公式:位置 的编码可以表示为位置 编码的线性变换:
这就像通过三角函数公式,模型可以推导出词语之间的相对距离。
4. 关键知识点总结
- 自注意力的本质:让每个词元都能与序列中所有词元直接交互
- 三大模型对比:
- CNN:局部感知,适合处理图像
- RNN:顺序处理,适合流式数据
- 自注意力:全局交互,适合长程依赖
- 位置编码的妙用:
- 绝对位置:通过不同频率的正余弦函数编码
- 相对位置:利用三角恒等式实现位置偏移的线性表示
通过这个魔法般的组合,现代Transformer模型才能在机器翻译、文本生成等任务中展现出惊人的性能。理解这些基础原理,就是打开深度学习宝库的第一把钥匙!