在不久之前我们已经搞懂了什么是Attention以及Attention机制的运作机理,那么按照《Attention is All You Need》浅读(简介+代码) 的行文逻辑,接下去我们该看看那Attention机制的进阶版,也就是Multi-Head Attention机制了
回忆一下我们前面是怎么样得到修复后的结果矩阵Z Z Z 的?其实是通过这个公式:Z = A t t e n t i o n ( Q , K , V ) Z=\mathrm{Attention}(Q, K, V) Z = Attention ( Q , K , V ) 那么我们又是怎么样得到Q , K , V Q,K,V Q , K , V 的?答案是将输入X X X 点乘对应的W W W 矩阵,但是我们仔细思考会发现,如何找到一个optimal的W矩阵,才能使得我们得到的Q K V QKV Q K V 足够的好?
其实这个问题非常之重要,如果W矩阵找的不好(比如说很局限,包含的信息很少),那么我们得到的QKV必然质量不高,输入质量不高又会导致输出Z的质量不高。
Google解决这个问题的办法并不是想办法得到更好的W(感觉这会把任务变成一个串行的过程,速度很慢,并且只有一个W的话再好也是有局限的),而是采用分工合作的方式,这也是Multi一词的来源。具体来说,就是:
使用不同的W i W_i W i (每个W i W_i W i 专注于观察数据的某一个特征),去乘输入矩阵X,得到多组Q i , K i , V i {Q_i},{K_i},{V_i} Q i , K i , V i ,然后一组一组地把他们输入到A t t e n t i o n ( Q i , K i , V i ) \mathrm{Attention}({Q_i}, {K_i}, {V_i}) Attention ( Q i , K i , V i ) ,得到Z i {Z_i} Z i ,再把所有的Z i {Z_i} Z i concat起来得到一个Z Z Z ,最后再通过L i n e a r Linear L in e a r 层修复一下 。
~~但是实际上还是有出入,实际上的M u l t i − H e a d A t t e n t i o n Multi-Head \space Attention M u lt i − He a d A tt e n t i o n 并不是像这样把整个X多次映射的,而是将一个X分成多个部分分别映射,比如将X的前两部分记为X 1 X_1 X 1 ,再下面两部分记为X 2 X_2 X 2 ,……X n X_n X n ,分别得到Q i K i V i {Q_i}{K_i}{V_i} Q i K i V i 之后得到了一个个维度更小的Z i Z_i Z i ,这时候终于可以把它们拼接起来了,而且我们一般会发现拼接起来的Z的维度和X应该是一样的,这样才能保证Z能作为下一层的输入,
上面这地方搞错了啊卧槽,不是把X X X 分成多个更小的X i X_i X i ,而是直接用X X X 乘以更小的转换矩阵W Q , W K , W V W_Q,W_K,W_V W Q , W K , W V ,这样一来就能得到更小的Q , K , V Q,K,V Q , K , V 了,然后就和上面一样了,“分别得到Q i K i V i {Q_i}{K_i}{V_i} Q i K i V i 之后得到了一个个维度更小的Z i Z_i Z i ,这时候终于可以把它们拼接起来了,而且我们一般会发现拼接起来的Z的维度和X应该是一样的,这样才能保证Z能作为下一层的输入”
下面还是结合具体的例子讲解一下吧:
首先我们假设输入矩阵X X X :
X = [ 1 0 1 2 0 2 1 0 1 1 0 1 ] X= \begin{bmatrix} 1 & 0 & 1 & 2 \\ 0 & 2 & 1 & 0 \\ 1 & 1 & 0 & 1 \end{bmatrix} X = 1 0 1 0 2 1 1 1 0 2 0 1
也就是 3 个 token,每个 token 4 维
token1: [ 1 , 0 , 1 , 2 ] [1,0,1,2] [ 1 , 0 , 1 , 2 ]
token2:[ 0 , 2 , 1 , 0 ] [0,2,1,0] [ 0 , 2 , 1 , 0 ]
token3: [ 1 , 1 , 0 , 1 ] [1,1,0,1] [ 1 , 1 , 0 , 1 ]
接下来,我们分成两个h e a d head h e a d 来计算(h e a d 1 , h e a d 2 head_1,head_2 h e a d 1 , h e a d 2 ):
head1 的投影矩阵
W Q ( 1 ) = W K ( 1 ) = W V ( 1 ) = [ 1 0 0 1 0 0 0 0 ] W_Q^{(1)}=W_K^{(1)}=W_V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0\\ 0 & 0 \end{bmatrix} W Q ( 1 ) = W K ( 1 ) = W V ( 1 ) = 1 0 0 0 0 1 0 0
head2 的投影矩阵
W Q ( 2 ) = W K ( 2 ) = W V ( 2 ) = [ 0 0 0 0 1 0 0 1 ] W_Q^{(2)}=W_K^{(2)}=W_V^{(2)}= \begin{bmatrix} 0 & 0\\ 0 & 0\\ 1 & 0\\ 0 & 1 \end{bmatrix} W Q ( 2 ) = W K ( 2 ) = W V ( 2 ) = 0 0 1 0 0 0 0 1
通过前一篇文章的理解我们知道,接下去要做的一步实际是计算Q K ⊤ {QK^\top} Q K ⊤ ,这个乘法实际上算的是每一个词该怎么样,或者说多大程度的参考其他的词。我们再回头看看我们的这两个W Q W_Q W Q ,我们发现实际上是非常粗暴的,X乘W Q ( 1 ) W_Q^{(1)} W Q ( 1 ) ,得到的就是X的前两维;X乘W Q ( 1 ) W_Q^{(1)} W Q ( 1 ) ,得到的则是X的后两维。这仿佛和我前面划掉部分的解释:直接把X分成X 1 , X 2 X_1,X_2 X 1 , X 2 是一样的?其实不然,此处只是为了展示的方便才这样子定W,在现实中,这个W的取值是经过训练得出的,它一般情况下不会这样等比例的参考局部的输入,而是会有权重的参考,比如它可能长这样:W Q ( 1 ) = [ 0.3 − 0.8 1.1 0.2 − 0.5 0.7 0.9 − 0.4 ] W_Q^{(1)}= \begin{bmatrix} 0.3 & -0.8\\ 1.1 & 0.2\\ -0.5 & 0.7\\ 0.9 & -0.4 \end{bmatrix} W Q ( 1 ) = 0.3 1.1 − 0.5 0.9 − 0.8 0.2 0.7 − 0.4 这时它就不是抽维度 ,而是在做线性组合 。
H e a d 1 Head1 He a d 1
计算 Q ( 1 ) = X W Q ( 1 ) Q^{(1)}=XW_Q^{(1)} Q ( 1 ) = X W Q ( 1 )
X = [ 1 0 1 2 0 2 1 0 1 1 0 1 ] , W Q ( 1 ) = [ 1 0 0 1 0 0 0 0 ] X= \begin{bmatrix} 1 & 0 & 1 & 2 \\ 0 & 2 & 1 & 0 \\ 1 & 1 & 0 & 1 \end{bmatrix} ,\quad W_Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0\\ 0 & 0 \end{bmatrix} X = 1 0 1 0 2 1 1 1 0 2 0 1 , W Q ( 1 ) = 1 0 0 0 0 1 0 0
点乘后得到:
Q ( 1 ) = X W Q ( 1 ) = [ 1 0 0 2 1 1 ] Q^{(1)}=XW_Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 2\\ 1 & 1 \end{bmatrix} Q ( 1 ) = X W Q ( 1 ) = 1 0 1 0 2 1
同理:
K ( 1 ) = V ( 1 ) = [ 1 0 0 2 1 1 ] K^{(1)}= V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 2\\ 1 & 1 \end{bmatrix} K ( 1 ) = V ( 1 ) = 1 0 1 0 2 1
然后再计算注意力得分:
scores ( 1 ) = Q ( 1 ) ( K ( 1 ) ) T d k = ≈ [ 0.707 0 0.707 0 2.828 1.414 0.707 1.414 1.414 ] \text{scores}^{(1)}=\frac{Q^{(1)}(K^{(1)})^T}{\sqrt{d_k}}=\approx \begin{bmatrix} 0.707 & 0 & 0.707\\ 0 & 2.828 & 1.414\\ 0.707 & 1.414 & 1.414 \end{bmatrix} scores ( 1 ) = d k Q ( 1 ) ( K ( 1 ) ) T =≈ 0.707 0 0.707 0 2.828 1.414 0.707 1.414 1.414
再softmax得到:A ( 1 ) ≈ [ 0.401 0.198 0.401 0.045 0.768 0.187 0.198 0.401 0.401 ] A^{(1)}\approx \begin{bmatrix} 0.401 & 0.198 & 0.401\\ 0.045 & 0.768 & 0.187\\ 0.198 & 0.401 & 0.401 \end{bmatrix} A ( 1 ) ≈ 0.401 0.045 0.198 0.198 0.768 0.401 0.401 0.187 0.401 最后用这个得到的比例乘上V就好:h e a d 1 = A ( 1 ) V ( 1 ) = [ 0.802 0.797 0.232 1.723 0.599 1.203 ] head_1=A^{(1)}V^{(1)}=\begin{bmatrix} 0.802 & 0.797\\ 0.232 & 1.723\\ 0.599 & 1.203 \end{bmatrix} h e a d 1 = A ( 1 ) V ( 1 ) = 0.802 0.232 0.599 0.797 1.723 1.203
H e a d 2 Head2 He a d 2
同理我们可以得到:
h e a d 2 ≈ [ 0.898 1.798 0.802 1.000 0.716 1.436 ] {head}_2\approx \begin{bmatrix} 0.898 & 1.798\\ 0.802 & 1.000\\ 0.716 & 1.436 \end{bmatrix} h e a d 2 ≈ 0.898 0.802 0.716 1.798 1.000 1.436
C o n c a t Concat C o n c a t
OK啊,算完所有Head之后(此处只有两个),终于来到激动人心的C o n c a t Concat C o n c a t 环节:
现在两个 head 都算完了:
h e a d 1 ≈ [ 0.802 0.797 0.232 1.723 0.599 1.203 ] {head}_1\approx \begin{bmatrix} 0.802 & 0.797\\ 0.232 & 1.723\\ 0.599 & 1.203 \end{bmatrix} h e a d 1 ≈ 0.802 0.232 0.599 0.797 1.723 1.203
h e a d 2 ≈ [ 0.898 1.798 0.802 1.000 0.716 1.436 ] {head}_2\approx \begin{bmatrix} 0.898 & 1.798\\ 0.802 & 1.000\\ 0.716 & 1.436 \end{bmatrix} h e a d 2 ≈ 0.898 0.802 0.716 1.798 1.000 1.436
concat 是按列拼接 :
C o n c a t ( head 1 , head 2 ) = [ 0.802 0.797 0.898 1.798 0.232 1.723 0.802 1.000 0.599 1.203 0.716 1.436 ] {Concat}(\text{head}_1,\text{head}_2) = \begin{bmatrix} 0.802 & 0.797 & 0.898 & 1.798\\ 0.232 & 1.723 & 0.802 & 1.000\\ 0.599 & 1.203 & 0.716 & 1.436 \end{bmatrix} C o n c a t ( head 1 , head 2 ) = 0.802 0.232 0.599 0.797 1.723 1.203 0.898 0.802 0.716 1.798 1.000 1.436
L i n e a r Linear L in e a r
最后,终于来到我们的L i n e a r Linear L in e a r 环节啦,所谓L i n e a r Linear L in e a r ,就是一个:
让Concat矩阵从各个head孤立的状态变成各个head沟通过了的状态的过程
那么怎么让这几个head彼此沟通呢,答案是乘输出矩阵W O W_O W O
设输出矩阵W o W_o W o 为:
W O = [ 1 0 1 0 0 1 0 1 1 0 0 1 0 1 1 0 ] W_O= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 0 & 0 & 1\\ 0 & 1 & 1 & 0 \end{bmatrix} W O = 1 0 1 0 0 1 0 1 1 0 0 1 0 1 1 0
其实我们可以好好理解一下这个矩阵,虽然很简陋,但是意义非常简洁明了,比如,在当前这种矩阵所表示的沟通策略下,结果矩阵Y Y Y 的左上角的元素1.700来自0.802 ⋅ 1 + 0.797 ⋅ 0 + 0.898 ⋅ 1 + 1.798 ⋅ 0 = 1.700 0.802⋅1+0.797⋅0+0.898⋅1+1.798⋅0=1.700 0.802 ⋅ 1 + 0.797 ⋅ 0 + 0.898 ⋅ 1 + 1.798 ⋅ 0 = 1.700 ,它参考了head1的第一个元素和head2的第二个元素。
那么最终输出:
Y = Concat ( head 1 , head 2 ) W O ≈ [ 1.700 2.595 2.600 1.695 1.034 2.723 1.232 2.525 1.315 2.639 2.035 1.919 ] Y=\text{Concat}(\text{head}_1,\text{head}_2)W_O\approx \begin{bmatrix} 1.700 & 2.595 & 2.600 & 1.695\\ 1.034 & 2.723 & 1.232 & 2.525\\ 1.315 & 2.639 & 2.035 & 1.919 \end{bmatrix} Y = Concat ( head 1 , head 2 ) W O ≈ 1.700 1.034 1.315 2.595 2.723 2.639 2.600 1.232 2.035 1.695 2.525 1.919
太好了,到此为止我们终于把多头注意力机制讲完了!那么最后让我们一起看看相比单头而言,多头的优势在哪里吧!
以下内容全部来自chatGPT
1. 能同时看不同类型的关系
单头只有一套 Q,K,VQ,K,VQ,K,V 投影,所以本质上只有一种注意力模式 。
多头会有多套不同的投影矩阵:
WQ(1),WQ(2),…W_Q^{(1)}, W_Q^{(2)}, \dotsWQ(1),WQ(2),…
所以不同 head 可以各看各的:
一个 head 看近距离依赖
一个 head 看长距离依赖
一个 head 看语法对应
一个 head 看语义关联
也就是说:
同一个 token,可以同时从多个“视角”理解上下文。
这才是多头最本质的优势。
2. 把一个大空间拆成多个子空间来学,更容易分工
单头是在一个完整的 dmodeld_{model}dmodel 空间里一次性做 attention。
多头则是把它拆成多个小空间:
dmodel→h 个 dkd_{model} \rightarrow h \text{ 个 } d_kdmodel→h 个 dk
比如:
单头:一次在 512 维里做
多头:8 个 head,每个在 64 维里做
这样做的好处是:
每个 head 不用什么都学,它只学自己那部分模式就行。
这有点像:
单头:一个人同时负责翻译、语法、逻辑、指代
多头:4 个人分工合作
通常分工更稳,也更强。
3. 表达能力更强
单头最后只会产出一套加权平均结果 。
多头则是:
每个 head 各自产出一套结果
concat 拼起来
再通过 WOW_OWO 融合
所以它不是“只做一次加权平均”,而是:
做很多次不同的加权平均,再融合。
这会让模型能表示更复杂的关系。
你可以把它理解成:
4. 更不容易把不同信息混在一起
单头有个天然问题:
这些全要挤在同一个 attention 里面。
容易互相干扰。
多头的好处就是可以“隔离”一些模式:
某些 head 专门抓代词指代
某些 head 专门抓邻近词
某些 head 专门抓结构边界
虽然这不是人工规定的,但训练后经常会出现这种现象。