不再费脑, 手算 Attention 公式, 理解 Transformer 注意力的数学本质

178 阅读5分钟

大家好, 我是印刻君. 今天我们来聊注意力 (Attention) , 这个 Transfromer 架构的核心.

不少科普文章, 在介绍到注意力时, 都会刻意绕过数学公式, 单纯靠比喻来解释注意力.

但靠比喻感知注意力, 跟完全理解它始终会有一层隔阂, 毕竟看地图永远不等于走一遍路.

本篇文章, 我们会通过一个例子, 亲自过一遍计算注意力的核心公式, 一起探究注意力机制的数学本质.

本文涉及 Transformer 基础架构, 矩阵运算, 若你对这两部分不熟悉, 建议先补全前置知识, 阅读体验会更好:

不熟悉 Transformer 基础架构, 可以看我的文章: 数学不好也能懂: 解读 AI 经典论文《Attention is All You Need》与大模型生成原理

不熟悉矩阵运算, 可以看我的文章: 不再费脑, 写给 AI 爱好者的矩阵 (Matrix) 入门指南

先建立直觉:注意力到底是什么?

在啃公式之前, 我们把注意力这个概念放到生活场景里. 毕竟所有复杂的技术, 本质都是对现实世界的抽象.

微信底部菜单截图

看这张微信菜单截图, 你是不是瞬间被消息图标上的红色提示吸引? 这就是的 "注意力分配": 大脑会自动筛选信息, 优先聚焦到有价值, 有异常的部分.

语言理解中的注意力同理. 比如这句话: "印刻君没有吃晚餐, 因为他不饿".

当你读到 "他" 这个代词时, 大脑会自动把注意力集中在 "印刻君" 上, 而不是 "晚餐" 上. 这是因为你的大脑已经通过上下文, 判断出 "他" 这个代词和 "印刻君" 的关联性更强.

Transformer 里的注意力, 也是在做同样的事情: 针对句子里的每个词, 计算它和其他所有词的关联性得分, 再根据得分重新整合信息. 高分分配的注意力多, 低分分配的注意力少.

而我们今天要过的公式, 就是把这个先计算关联性得分, 后整合信息的过程, 用数学语言精确描述.

手算 Attention 公式的每一步

0 Attention 公式核心解读

在动手计算前,我们先看注意力机制的核心公式:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V

其中, Attention(Q,K,V)Attention(Q, K, V) 是整合了上下文信息的全新矩阵. 公式的意思是, 这个矩阵先由 Q 与 K 计算注意力权重,再在注意力权重的基础上乘以 V.

  1. 公式中的 Q, K, V 是三个核心输入, 后续我们会详细说明它们的由来和作用;
  2. QKTQK^T 表示查询矩阵与键矩阵的转置相乘;
  3. dk\sqrt{d_k} 是 K 向量维度的平方根, 作用是对得分进行缩放;
  4. Softmax 对缩放后的得分矩阵做归一化处理, 将每一行的得分转换为 0 到 1 之间的概率值;
  5. 乘以 V, 是用归一化后的注意力权重矩阵, 提取出所有词的价值信息

1 准备输入向量 A, B, C

为简化计算, 我们做一个假设: 一句话只有 A, B, C 三个词, 且每个词都用 4 维向量表示 (真实模型中向量维度有几百维, 这里选 4 维只是为了好算).

以下是这三个词的初始词嵌入向量:

A=[0,1,2,3]A = [0, 1, 2, 3]

B=[4,5,6,7]B = [4, 5, 6, 7]

C=[8,9,10,11]C = [8, 9, 10, 11]

这 3 个向量会被组合成一个矩阵 X

X=[01234567891011]X = \begin{bmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \end{bmatrix}

这些向量是从哪来的?

这些向量最初值是随机分配的, 模型通过大量训练, "反向传播" 后得到最后的值. 这里用简单整数代替真实向量的值, 只是为了演示计算过程.

不熟悉词向量, 可以看我的文章: 大模型如何分辨 “狼” 与 “狗” —— 词向量的训练过程

2 准备 3 个核心矩阵 WQW^Q, WKW^K, WVW^V

Attention 中的 Q、K、V 矩阵, 是由 3 个矩阵 WQW^Q, WKW^K, WVW^V, 和输入矩阵 X 相乘得到

WQW^Q, WKW^K, WVW^V 是用来做向量转换的矩阵, 专业叫法是 "投影矩阵".

我们先给出这三个矩阵的具体数值(同样为了好算,用简单矩阵代替真实模型的随机初始化矩阵):

WQ=[10100101]WK=[01011010]W^Q = \begin{bmatrix} 1 & 0 \\ 1 & 0 \\ 0 & 1 \\ 0 & 1 \end{bmatrix} \quad W^K = \begin{bmatrix} 0 & 1 \\ 0 & 1 \\ 1 & 0 \\ 1 & 0 \end{bmatrix}
WV=[10011001]W^V = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}

这三个矩阵是从哪来的?

和词嵌入向量类似, Transformer 架构搭建好后, 这三个矩阵就存在了, 但里面的数值都是随机的.模型训练的核心过程, 就是通过 "反向传播" 不断调整这三个矩阵(以及其他所有参数)的数值.直到它们能精准捕捉词与词的关联性.

简单说: 训练 Transformer, 本质就是在调优这三个矩阵的数字, 让 Q 和 K 的匹配更合理, V 的信息整合更有效.

3 计算 Q, K, V 矩阵

Q, K, V 的计算逻辑很简单: 把所有输入向量组成一个输入矩阵, 记为 X, 分别和 WQW^QWKW^KWVW^V 做矩阵乘法, 结果就是 Q, K, V 矩阵.

X=[01234567891011]X = \begin{bmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \end{bmatrix}
  • 第一行 [0,1,2,3][0, 1, 2, 3] 是向量 A;
  • 第二行 [4,5,6,7][4, 5, 6, 7] 是向量 B;
  • 第三行 [8,9,10,11][8, 9, 10, 11] 是向量 C.

3.1 计算 Q 矩阵

Q=XWQQ = X \cdot W^Q
=[01234567891011][10100101]= \begin{bmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \end{bmatrix} \cdot \begin{bmatrix} 1 & 0 \\ 1 & 0 \\ 0 & 1 \\ 0 & 1 \end{bmatrix}
=[159131721]= \begin{bmatrix} 1 & 5 \\ 9 & 13 \\ 17 & 21 \end{bmatrix}
  • Q 矩阵的第一行是 QAQ^A , 代表向量 A 的查询向量;
  • 第二行是 QBQ^B, 代表向量 B 的查询向量;
  • 第三行是 QCQ^C, 代表向量 C 的查询向量.

3.2 计算 K 矩阵

K=XWKK = X \cdot W^K
=[01234567891011][01011010]= \begin{bmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \end{bmatrix} \cdot \begin{bmatrix} 0 & 1 \\ 0 & 1 \\ 1 & 0 \\ 1 & 0 \end{bmatrix}
=[511392117]= \begin{bmatrix} 5 & 1 \\ 13 & 9 \\ 21 & 17 \end{bmatrix}
  • K 矩阵的第一行是 kAk^A, 代表向量 A 的键向量;
  • 第二行是 KBK^B, 代表向量 B 的键向量;
  • 第三行是 KCK^C, 代表向量 C 的键向量.

3.3 计算 V 矩阵

V=XWVV = X \cdot W^V
=[01234567891011][10011001]= \begin{bmatrix} 0 & 1 & 2 & 3 \\ 4 & 5 & 6 & 7 \\ 8 & 9 & 10 & 11 \end{bmatrix} \cdot \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}
=[2410121820]= \begin{bmatrix} 2 & 4 \\ 10 & 12 \\ 18 & 20 \end{bmatrix}

3.4 Q、K、V 的作用

算到这, 我们已经得到了 Q, K, V 三个矩阵, 它们是 Q, K, V 向量的组合, 我们用字典来打个比方:

  • Q (query, 查询向量): 好比我们查字典时输的关键词, 放到模型里, 就是当前这个词想找谁, 想和谁搭关系;
  • K (key, 键向量), 好比字典里所有词条的标题. 字典里每个词都有个标题, 模型里每个词也都有个键向量, 专门用来和其他词的查询向量做配对;
  • V (value, 价值向量), 好比字典标题对应的解释内容. 你用关键词 Q 找到对应的标题 K 后, 真正要的是标题底下的解释; 而 V 就是每个词身上最核心的信息.

4 计算 QKTQK^T

公式里的 KTK^T 是 K 矩阵的 "转置", 也就是把 K 的行和列互换

K=[511392117]K = \begin{bmatrix} 5 & 1 \\ 13 & 9 \\ 21 & 17 \end{bmatrix}
KT=[513211917]K^T = \begin{bmatrix} 5 & 13 & 21 \\ 1 & 9 & 17 \end{bmatrix}

QKTQK^T 的结果, 就是 "每个词与其他所有词的关联性得分", 也叫 "注意力得分".

QKT=[159131721][513211917]QK^T = \begin{bmatrix} 1 & 5 \\ 9 & 13 \\ 17 & 21 \end{bmatrix} \cdot \begin{bmatrix} 5 & 13 & 21 \\ 1 & 9 & 17 \\ \end{bmatrix}
=[105810658234400106400714]= \begin{bmatrix} 10 & 58 & 106 \\ 58 & 234 & 400 \\ 106 & 400 & 714 \end{bmatrix}

QKTQK^T 得到的矩阵, 每一行对应一个词的关联得分表:

  • 第一行是 A 对 A, B, C 的得分;
  • 第二行是 B 对 A, B, C 的得分;
  • 第三行是 C 对 A, B, C 的得分.

得分越高, 说明两个词的关联性越强.

5 除以 dk\sqrt{d_k}

公式里的 dk\sqrt{d_k} 是键向量的维度 dkd_k 的平方根.

这里 K 向量是 2 维, 所以

dk=2,dk1.414d_k = 2, \sqrt{d_k} \approx 1.414

为什么要缩放?因为当 dkd_k 很大时, QKTQK^T 的得分会变得很大, 代入 Softmax 函数后, 指数运算会导致数值要么趋近于无穷大, 要么趋近于 0, 缩放后能让得分更稳定.

计算缩放后的结果:

QKTdk\frac{QK^T}{\sqrt{d_k}}
=[10/1.41458/1.414106/1.41458/1.414234/1.414400/1.414106/1.414400/1.414714/1.414]= \begin{bmatrix} 10/1.414 & 58/1.414 & 106/1.414 \\ 58/1.414 & 234/1.414 & 400/1.414 \\ 106/1.414 & 400/1.414 & 714/1.414 \end{bmatrix}
[7.0741.0274.9641.02165.48282.8474.96282.84504.90]\approx \begin{bmatrix} 7.07 & 41.02 & 74.96 \\ 41.02 & 165.48 & 282.84 \\ 74.96 & 282.84 & 504.90 \end{bmatrix}

6 softmax 归一化

softmax 函数的作用是归一化, 把每一行的得分转换成 0 到 1 之间的概率, 且每一行的概率和为 1.

假设有一个输入向量 z=[z1,z2,...,zk]z = [z_1, z_2, ..., z_k], 那么对于向量中的第 i 个元素,Softmax 函数的定义如下:

Softmax(zi)=ezij=1KezjSoftmax(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}

这里有个问题, 缩放后的得分最大已经到 504.90, 直接计算 e504.90e^504.90 会得到一个天文数字, 计算机无法精确计算.

这时候可以用 Softmax 的一个关键性质: 给同一行的所有元素同时减去该行的最大值, 结果不变.

我们以第一行 [7.07,41.02,74.96][7.07, 41.02, 74.96] 为例, 每个元素都减 74.96 , 得到

[67.89,33.94,0][-67.89, -33.94, 0]

此时计算

e67.890e^{-67.89} \approx 0
e33.940e^{-33.94} \approx 0
e0=1e^0 = 1

所以第一行 softmax 的结果是 [0,0,1][0, 0, 1].

第二行和第三行同理, 计算后得到完整的 Softmax 矩阵:

softmax(QKTdk)[001001001]softmax(\frac{QK^T}{\sqrt{d_k}}) \approx \begin{bmatrix} 0 & 0 & 1\\ 0 & 0 & 1\\ 0 & 0 & 1 \end{bmatrix}

每一行的概率代表当前词对其他词的关注权重:

  • 第一行 [0,0,1][0, 0, 1], 说明 A 几乎把所有注意力都放在了 C 上;
  • 第二行 [0,0,1][0, 0, 1], 说明 B 也把所有注意力放在了 C 上;
  • 第三行 [0,0,1][0, 0, 1], 说明 C 只关注自己.

7 乘以 V

最后一步是把 Softmax 得到的关注权重, 和 V 矩阵 (词的价值信息) 相乘. 本质是按权重整合其他词的信息, 得到每个词的最终向量 (融合了上下文信息).

Attention(Q,K,V)Attention(Q, K, V)
=softmax(QKTdk)V= softmax(\frac{QK^T}{\sqrt{d_k}})V
=[001001001][2410121820]= \begin{bmatrix} 0 & 0 & 1\\ 0 & 0 & 1\\ 0 & 0 & 1 \end{bmatrix} \cdot \begin{bmatrix} 2 & 4 \\ 10 & 12 \\ 18 & 20 \end{bmatrix}
=[182018201820]= \begin{bmatrix} 18 & 20 \\ 18 & 20 \\ 18 & 20 \end{bmatrix}

总结:注意力的数学本质到底是什么?

我们看最终结果, 原本 A, B, C 的向量各不相同, 但经过注意力计算后, A 和 B 的向量融合了 C 的向量信息, 不再是原来孤立的自己. 换句话说, A, B, C 通过阅读上下文, 不断调整自身的语义, 从而实现精确的理解.

比如翻译 "我喜欢吃苹果" 这句话时, "苹果" 的查询向量 Q 会和 "吃" 的键向量 K 计算出高关联权重, 随后这个高权重会让模型在整合价值信息 V 时, 更侧重把 "吃" 的语义融入到 "苹果" 自身的价值向量中, 最终精准判断 "苹果" 在这里指水果, 而非手机.

放到真实 transformer 里, 这个过程会更复杂 (比如多注意力头, 更高维度的向量), 但核心逻辑没变. 就是 "用数学方法计算词与词的关联性, 再按关联性权重整合上下文信息".

我是印刻君, 一位探索 AI 的前端程序员, 关注我, 让 AI 知识有温度, 技术落地有深度.