上一篇解释完点积和矩阵乘法,矩阵乘法是一种转换,这一篇看 Transformer 中如何运用的。
LLM 的本质是预测下一个 token,阶段二中,使用大量的互联网内容,给模型做训练,使用自监督学习,不断调整 1750 亿个参数。直到模型能够正确的补全文本内容。
阶段二的产物,只有文本补全功能,不具备问答、对话能力。
现在假设所有参数已经调节完毕,以一个输入是 "The cat sat" 展示模型怎么预测下一个 token 是 "on" 的。
查询向量坐标
输入 "The cat sat" 经过 Tokenizer 会把这句话分为 the、cat、sat 三个 token,再从 Vocab Map<String, Integer> 中找到对应的id,再去 Embedding Table 中找到这三个向量坐标。
得到 3 个形状为 [1,4096] 的向量,Vector_The: [0.1, -0.5, ...]、Vector_cat: [0.8, 0.2, ...]、Vector_sat: [-0.1, 0.9, ...]。把这三个向量合并到一个矩阵里面,得到一个形状是 [3,4096] 的向量。
Xin=xThexcatxsat
进入 Layer 加工
接下来正式进入 Layer 加工。之前说过一共有 96 层 Layers,每一层 Layer 有 MHA (Multi-Head Attention) 多头注意力机制 和 FFN (Feed-Forward Network) 前馈神经网络处理。形状为 [3,4096] 的矩阵会完整经历所有 Layers,最后得到加工后的 [3,4096]。
x_in ➔ [Norm] ➔ [MHA] ➔ (+ 残差连接) ➔ x_mid ➔ [Norm] ➔ [FFN] ➔ (+ 残差连接) ➔ x_out
以上是一层 Layer 完整过程。 x_out 会是下一层 Layer 的 x_in。
Norm 层归一化
其中 Norm 是 Layer Normalization(层归一化),矩阵乘法的结果范围非常大,有的值是 50000 而有的是 0.000003,为了防止计算溢出或者梯度乱跳,需要把这些值统一处理为均值为 0 方差为 1。
Norm 计算包含四个步骤,假设以 (xcat): [10, 2, 12, 0] 为例。
-
求均值 (μ)
(10+2+12+0)÷4=6
-
求方差 (σ2)
10→(10−6)2=16
2→(2−6)2=16
12→(12−6)2=36
0→(0−6)2=36
方差=(16+16+36+36)÷4=26
标准差 (σ): 26≈5.1
-
归一化 (Normalize)
公式:标准差x−均值目的是把数据强行拉回到 “均值为 0,方差为 1” 的标准形态。
10→(10−6)/5.1≈0.78
2→(2−6)/5.1≈−0.78
12→(12−6)/5.1≈1.17
0→(0−6)/5.1≈−1.17
结果向量: [0.78, -0.78, 1.17, -1.17]
-
缩放与平移 (Scale & Shift)
如果每次都强行变成 0 均值,可能会破坏数据的含义。所以模型有两个可变参数:γ (缩放) 和 β (平移)。这里可以理解成一个一元二次方程的线性函数。
假设模型学到这一层需要数值稍微大一点:
γ=[2,2,2,2]
β=[1,1,1,1]
最终输出 = 归一化结果×γ+β
0.78×2+1=2.56
...
残差连接
经过 MHA 和 FFN 后,为了防止原始数据丢失,会加上处理前的原始值。
Output=New_Process(x)+x
MHA
回到公式,x_in 是形状为 [3,4096] 的矩阵,经过 Norm 后还是 [3,4096]。接着进入 MHA。
MHA 中有 WQ、WK、WV、WO 四个形状为 [d_modle,d_modle]([4096,4096]) 的矩阵。这是在阶段二训练好的。
Q 是 question,K 是 key,V 是 value,这是三个非常抽象的矩阵,他们的作用是把 "The cat sat" 中三个 token 的向量坐标互相融合,比如 the 要更关注 cat,经过融合后 the 这个 token 的向量值里就包含了大量 cat 的向量值。
MHA 叫多头注意力机制,比如有 32 个头,4096/32=128,就是把 WQ、WK、WV 分为 32 个形状为 [4096,128] 的小矩阵,并行计算得到 32 个结果,再合并起来乘以 WO 得到最终的产物。
大量的并行矩阵计算,这也是算力以 GPU 为核心的原因。
头也是抽象的概念,可以用角度类比,以 "The cat sat on the mat because it was tired." 为例。
Head 1 (语法眼): 专门盯着主谓关系。它发现 "it" 指代 "cat"。
Head 2 (逻辑眼): 专门盯着因果关系。它发现 "because" 导致了 "tired"。
Head 3 (位置眼): 专门盯着方位关系。它关注 "on the mat"。
真个 MHA 的计算过程用一个公式表示。
Z=softmax(dmodel(XWQ)(XWK)T)(XWV)
其中,
Scores=(XWQ)(XWK)T
XWQ : 算出 Q 矩阵。
XWK : 算出 K 矩阵。
(…)T : 转置 K 矩阵,为了能进行矩阵乘法(前一个矩阵的横行必须等于后一个矩阵的纵列)。
转置是沿对角线翻转。比如
K=[142536]
转置后是
KT=123456
d_modle 计算模型设定的维度,除以 dmodel 也是为了防止向量值之间差距太大。
A=softmax(4Scores)
softmax 是一种把一堆数字变为总和为 1 的小数算法,假设有一个输入向量(Logits) z=[z1,z2,...,zn]。 Softmax 函数 σ(z)i 的公式是:
σ(z)i=∑j=1Nezjezi
分子 (ezi) :算出当前元素的指数值。
分母 (∑ezj) :算出所有元素指数值的总和。
用到了 e 自然指数,把负数也转成正数(e−2≈0.135),由于指数函数特性,拉大的原始的差距(e2.0≈7.4 e1.0≈2.7)。
LLM 中常见的配置 temperature 就是在这个公式的分子和分母同时除以 T。
XWV : 算出 V 矩阵。
再乘以 softmax 后的概率 A,就是最终的结果。
这里 X 参数是形状为 [3,4096] 的矩阵,而不是单一的 token,在这个过程中,每个 token 之间都相互融合,由于前一个 token 不能看到后一个 token,所以后一个 token 的向量是包含了前面所有信息的。
后一个 token 看不到前一个 token 通过 Mask 矩阵实现。
Attention=softmax(dQKT+Mask)V
上面的公式简化后是这样的,Mask 矩阵是一个上三角矩阵(右上角全是负无穷 −∞),例如(为了简化这里 d_modle 是 3):
Mask=000−∞00−∞−∞0
这样任何矩阵加上这个矩阵右上角都是负无穷,而 e−∞=0,这保证了在原始句子 "The cat sat" 中 cat 的概率永远是 1。
如果说后一个 token 向量包含了所有信息,那为什么还要算所有向量的 Q K V?这是因为后一个向量计算的过程中,用到了这些数据。
现在用一个 dmodle = 3 的例子,完整演示 MHA 中的计算过程。
入参 X
X=100020002←The←cat←sat
阶段二训练后的参数(注意这里是 WQ,Q 是计算后的矩阵)
WQ=000001000
WK=000010000
WV=100010001
矩阵乘法后
Q=000002000←The (无需求)←cat (无需求)←sat (有需求)
K=000020000→KT=000020000
V=100020002←The 的货←cat 的货←sat 的货
根据公式计算 scores
Scores=000004000
Masked=000004000+000−∞00−∞−∞0=000−∞04−∞−∞0
softmax 后结果
A=1.00.50.0800.50.84000.08←The←cat←sat (聚焦 cat)
解释下第三行怎么算的
公式:softmax(3Scores)。
3≈1.73。
- cat 得分: 4/1.73≈2.3
- The/sat 得分: 0
- e2.3≈10, e0=1。
- 概率: 10/(1+10+1)≈0.84
所以是 [0.08,0.84,0.08]
最后计算 Z 矩阵(Z=A×V)
Z=1.00.50.0801.01.68000.16←输出给下一层的 The←输出给下一层的 cat←输出给下一层的 sat
至此初始值 X 向量经过 MHA 后变为了 Z 向量,可以看到 sat 的原始向量向 cat 倾斜了。
32 个头同时这么进行,每个头都会得到一个 Z 向量结果,取最后一行的 Zcat 把他们全部拼接起来,又变成了一个形状为 [1,4096] 的向量。(推理是只用最后一个 token,训练都用了)
但向量的 0128 位是头 1 的表示的是语法,头 2 的 128256 位表示位置等等,一直到头 32,最终需要再乘以 Wo 矩阵,把这些独立的信息融为一体,得到最终的结果。
这就是 MHA 的全部过程。
FFN
Z 经过残差连接和层归一化后,进入 FFN。
FFN 里面有两个在阶段二训练好的矩阵,W1 和 W2。
W1 是升维矩阵,形状是 [d_modle, 4d_modle],W2 是降维矩阵,形状是 [4d_modle,d_modle]。
假设 Z 已经完成残差连接和层归一化,乘以 W1 矩阵
Hup=Zsat×W1
升维后有更多更细节的信息,比如 Apple 这个 token 升维前的向量是 x = [0.8, 0.1, 0.5] 分别代码是水果,是公司,是红色的。
W1 就像包含 6 个问题的问卷:
- 是电子产品吗?
- 是红色的吗?
- 能吃吗?
- 是交通工具吗?
- 有毛吗?
- 是液体吗?
计算 H=x×W1:
结果向量(6维)可能是这样的:
[10, 8, 9, -5, -10, -2]
- 10 (是电子产品): 命中!
- 8 (是红色的): 命中!
- 9 (能吃): 命中!
- -5 (是交通工具): 完全不是,负分!
- -10 (有毛): 负分滚粗!
- -2 (是液体): 不太像。
然后经过 ReLU 激活函数处理,f(x)=max(0,x),把小于 0 的置位 0,结果就变成了 [10, 8, 9, 0, 0, 0]。
再通过 W2 把结果压缩回去,去掉无用信息,比如变成了 [5.0, 2.0, 8.0]。
再把这个值做一道残差连接,整个 Layer 层执行结束,把结果作为输入传给下一个 Layer 层。
所有 Layers 处理完后
输入的是 "The cat sat" 三个 token 组成的形状为 [3,4096] 的矩阵,96 层 Layers 处理完后输出的也是这个,但是我们只需要 sat 的 H = [1,4096] 这个值,因为他包含了 the cat 和混合。
接着 H 做完 Norm(处理为均值为0方差为1)后,形状还是 [1,4096]。
还记得最开始有个 embedding table 的形状是 [128k,4096] 吗,这里有个 unembedding 的矩阵形状是 [4096,128k],为了节省空间,这两个 table 其实是一样的。
用 H 和 unembedding 做乘法,得到 [1,128k],这正好对应整个词汇表的每个词,这个结果称为 Logits。
再用 Logits 做 softmax 转成概率,同时还有 temperature 的配置,并且降低之前出现过词语的 Logits。这时候的概率,正好对应词汇表中下一个词应该是什么的概率。
再结合 Top-K(保留概率最高的前K个)、Top-P(保留概率累计超过P的前n个),求交集,在保留下的 token 里面随机选一个,拼接到整个句子后面,再把新句子从头开始新一轮的循环。
直到达到 output token 的限制,或者标记为 EOS (End of Sequence) 的 token,整个推理过程完毕。