自注意力与位置编码:让模型理解序列的魔法

251 阅读2分钟

1. 什么是自注意力?

想象一下,你正在阅读一本小说,每看到一个词语时,大脑会自动关注前文中与之相关的信息。这种"聚焦重点"的能力,正是自注意力机制的核心思想。

自注意力(Self-Attention)是一种让序列中的每个元素都能关注整个序列的机制。就像班级讨论时,每个同学发言(查询)都会考虑所有人的观点(键和值)。具体来说:

给定输入序列 X=[x1,x2,...,xn]\mathbf{X} = [x_1, x_2, ..., x_n],自注意力通过三个步骤生成输出:

  1. 生成问题纸条:每个词元创建查询向量 qi=Wqxi\mathbf{q}_i = W_q x_i
  2. 制作答案卡:每个词元生成键向量 kj=Wkxj\mathbf{k}_j = W_k x_j 和值向量 vj=Wvxj\mathbf{v}_j = W_v x_j
  3. 收集答案:每个查询收集所有键值对的加权和:
yi=j=1nsoftmax(qikjd)vjy_i = \sum_{j=1}^n \text{softmax}\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}}\right) \mathbf{v}_j

其中:

  • q\mathbf{q}(Query)是查询矩阵,大小为 (n×d)(n \times d),其中 nn 是查询的数量,dd 是特征维度。
  • k\mathbf{k}(Key)是键矩阵,大小为 (m×d)(m \times d),其中 mm 是键的数量,dd 是特征维度。
  • v\mathbf{v}(Value)是值矩阵,大小为 (m×dv)(m \times d_v)
  • 1d\frac{1}{\sqrt{d}} 是一个缩放因子,用于防止大数值导致 softmax 过于极端,从而影响梯度的稳定性。

示例:考虑句子"猫吃鱼",自注意力会让"吃"同时关注"猫"和"鱼",就像我们在理解动词时会自动联系主语和宾语。

下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,输入张量 XX 的形状为 (批量大小,序列长度,特征维度)(\text{批量大小}, \text{序列长度}, \text{特征维度}),经过自注意力计算后,输出张量与输入张量形状保持一致。

import torch
import d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

print(attention)
"""输出:

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
"""

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))

print(attention(X, X, X, valid_lens).shape)
# 输出:torch.Size([2, 4, 100])

2. 三大序列模型的巅峰对决

2.1 参赛选手介绍

模型类型工作方式可视化类比
CNN滑动窗口扫描望远镜观察局部区域
RNN顺序传递信息接力赛传递消息
自注意力全局直接交互电话会议全员讨论

2.2 性能参数对比

使用 nn 个词元,每个维度 dd,卷积核大小 kk

指标CNNRNN自注意力
计算复杂度O(knd2)\mathcal{O}(knd^2)O(nd2)\mathcal{O}(nd^2)O(n2d)\mathcal{O}(n^2d)
并行能力极高
最大路径长度O(n/k)\mathcal{O}(n/k)O(n)\mathcal{O}(n)O(1)\mathcal{O}(1)

cnn-rnn-self-attention.svg

图1 比较卷积神经网络(填充词元被忽略)、循环神经网络和自注意力三种架构

示例:处理100个词的句子时,自注意力需要100×100=10,000次交互计算,而CNN(假设k=3)只需3×100=300次局部计算。


3. 位置编码:给词语发"座位号"

3.1 为什么需要位置信息?

自注意力虽然强大,但有个致命缺陷——所有词语同时处理,就像把句子里的词全部平铺在桌面上,模型无法知道它们的原始顺序。这时就需要位置编码来标记每个词的位置。

3.2 神奇的三角函数编码

使用正弦和余弦函数的组合生成位置编码矩阵 PP,其中第 ii 行对应位置,第 2j2j2j+12j+1 列使用:

Pi,2j=sin(i100002j/d)Pi,2j+1=cos(i100002j/d)\begin{aligned} P_{i,2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right) \\ P_{i,2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right) \end{aligned}

示例:当 d=4d=4 时,位置1的编码可能是: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/4)), cos(1/10000^(2/4))]

让我们在下面的PositionalEncoding类中实现它这种编码方式:

class PositionalEncoding(nn.Module):
    """位置编码"""

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

在位置嵌入矩阵 PP 中,行代表词元在序列中的位置,列代表位置编码的不同维度。

encoding_dim, num_steps = 32, 60
pos_encoding = d2l.PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6.18, 3.82), legend=["Col %d" % d for d in torch.arange(6, 10)])

从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。

10.6.位置编码.png

3.3 编码特性揭秘

绝对位置感知

不同列对应不同频率的波形,就像钢琴键盘上从左到右音调逐渐降低。高频(左侧列)帮助捕捉相邻词语的位置关系,低频(右侧列)负责编码词语在序列中的整体位置。

P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.82, 6.18), cmap='Blues')

10.6.位置编码热图.png

相对位置推理

关键公式:位置 i+ki+k 的编码可以表示为位置 ii 编码的线性变换:

sin(ωj(i+k))=sin(ωji)cos(ωjk)+cos(ωji)sin(ωjk)cos(ωj(i+k))=cos(ωji)cos(ωjk)sin(ωji)sin(ωjk)\begin{aligned} \sin(\omega_j(i+k)) &= \sin(\omega_j i)\cos(\omega_j k) + \cos(\omega_j i)\sin(\omega_j k) \\ \cos(\omega_j(i+k)) &= \cos(\omega_j i)\cos(\omega_j k) - \sin(\omega_j i)\sin(\omega_j k) \end{aligned}

这就像通过三角函数公式,模型可以推导出词语之间的相对距离。


4. 关键知识点总结

  1. 自注意力的本质:让每个词元都能与序列中所有词元直接交互
  2. 三大模型对比
    • CNN:局部感知,适合处理图像
    • RNN:顺序处理,适合流式数据
    • 自注意力:全局交互,适合长程依赖
  3. 位置编码的妙用
    • 绝对位置:通过不同频率的正余弦函数编码
    • 相对位置:利用三角恒等式实现位置偏移的线性表示

通过这个魔法般的组合,现代Transformer模型才能在机器翻译、文本生成等任务中展现出惊人的性能。理解这些基础原理,就是打开深度学习宝库的第一把钥匙!