Multi-Head Attention

5 阅读3分钟

在不久之前我们已经搞懂了什么是Attention以及Attention机制的运作机理,那么按照《Attention is All You Need》浅读(简介+代码)的行文逻辑,接下去我们该看看那Attention机制的进阶版,也就是Multi-Head Attention机制了 回忆一下我们前面是怎么样得到修复后的结果矩阵ZZ的?其实是通过这个公式:Z=Attention(Q,K,V)Z=\mathrm{Attention}(Q, K, V) 那么我们又是怎么样得到Q,K,VQ,K,V的?答案是将输入XX点乘对应的WW矩阵,但是我们仔细思考会发现,如何找到一个optimal的W矩阵,才能使得我们得到的QKVQKV足够的好? 其实这个问题非常之重要,如果W矩阵找的不好(比如说很局限,包含的信息很少),那么我们得到的QKV必然质量不高,输入质量不高又会导致输出Z的质量不高。 Google解决这个问题的办法并不是想办法得到更好的W(感觉这会把任务变成一个串行的过程,速度很慢,并且只有一个W的话再好也是有局限的),而是采用分工合作的方式,这也是Multi一词的来源。具体来说,就是: 使用不同的WiW_i(每个WiW_i专注于观察数据的某一个特征),去乘输入矩阵X,得到多组Qi,Ki,Vi{Q_i},{K_i},{V_i},然后一组一组地把他们输入到Attention(Qi,Ki,Vi)\mathrm{Attention}({Q_i}, {K_i}, {V_i}),得到Zi{Z_i},再把所有的Zi{Z_i}concat起来得到一个ZZ,最后再通过LinearLinear层修复一下

~~但是实际上还是有出入,实际上的MultiHead AttentionMulti-Head \space Attention并不是像这样把整个X多次映射的,而是将一个X分成多个部分分别映射,比如将X的前两部分记为X1X_1,再下面两部分记为X2X_2,……XnX_n,分别得到QiKiVi{Q_i}{K_i}{V_i}之后得到了一个个维度更小的ZiZ_i,这时候终于可以把它们拼接起来了,而且我们一般会发现拼接起来的Z的维度和X应该是一样的,这样才能保证Z能作为下一层的输入,

上面这地方搞错了啊卧槽,不是把XX分成多个更小的XiX_i,而是直接用XX乘以更小的转换矩阵WQ,WK,WVW_Q,W_K,W_V,这样一来就能得到更小的Q,K,VQ,K,V了,然后就和上面一样了,“分别得到QiKiVi{Q_i}{K_i}{V_i}之后得到了一个个维度更小的ZiZ_i,这时候终于可以把它们拼接起来了,而且我们一般会发现拼接起来的Z的维度和X应该是一样的,这样才能保证Z能作为下一层的输入”


下面还是结合具体的例子讲解一下吧:

首先我们假设输入矩阵XX

X=[101202101101]X= \begin{bmatrix} 1 & 0 & 1 & 2 \\ 0 & 2 & 1 & 0 \\ 1 & 1 & 0 & 1 \end{bmatrix}

也就是 3 个 token,每个 token 4 维

  • token1: [1,0,1,2][1,0,1,2]
  • token2:[0,2,1,0][0,2,1,0]
  • token3: [1,1,0,1][1,1,0,1]

接下来,我们分成两个headhead来计算(head1,head2head_1,head_2):

head1 的投影矩阵

WQ(1)=WK(1)=WV(1)=[10010000]W_Q^{(1)}=W_K^{(1)}=W_V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0\\ 0 & 0 \end{bmatrix}​

head2 的投影矩阵

WQ(2)=WK(2)=WV(2)=[00001001]W_Q^{(2)}=W_K^{(2)}=W_V^{(2)}= \begin{bmatrix} 0 & 0\\ 0 & 0\\ 1 & 0\\ 0 & 1 \end{bmatrix}

通过前一篇文章的理解我们知道,接下去要做的一步实际是计算QK{QK^\top},这个乘法实际上算的是每一个词该怎么样,或者说多大程度的参考其他的词。我们再回头看看我们的这两个WQW_Q,我们发现实际上是非常粗暴的,X乘WQ(1)W_Q^{(1)},得到的就是X的前两维;X乘WQ(1)W_Q^{(1)},得到的则是X的后两维。这仿佛和我前面划掉部分的解释:直接把X分成X1,X2X_1,X_2是一样的?其实不然,此处只是为了展示的方便才这样子定W,在现实中,这个W的取值是经过训练得出的,它一般情况下不会这样等比例的参考局部的输入,而是会有权重的参考,比如它可能长这样:WQ(1)=[0.30.81.10.20.50.70.90.4]W_Q^{(1)}= \begin{bmatrix} 0.3 & -0.8\\ 1.1 & 0.2\\ -0.5 & 0.7\\ 0.9 & -0.4 \end{bmatrix}​这时它就不是抽维度,而是在做线性组合


Head1Head1
计算 Q(1)=XWQ(1)Q^{(1)}=XW_Q^{(1)}

X=[101202101101],WQ(1)=[10010000]X= \begin{bmatrix} 1 & 0 & 1 & 2 \\ 0 & 2 & 1 & 0 \\ 1 & 1 & 0 & 1 \end{bmatrix} ,\quad W_Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 1\\ 0 & 0\\ 0 & 0 \end{bmatrix}

点乘后得到: Q(1)=XWQ(1)=[100211]Q^{(1)}=XW_Q^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 2\\ 1 & 1 \end{bmatrix}

同理:

K(1)=V(1)=[100211]K^{(1)}= V^{(1)}= \begin{bmatrix} 1 & 0\\ 0 & 2\\ 1 & 1 \end{bmatrix}

然后再计算注意力得分: scores(1)=Q(1)(K(1))Tdk=[0.70700.70702.8281.4140.7071.4141.414]\text{scores}^{(1)}=\frac{Q^{(1)}(K^{(1)})^T}{\sqrt{d_k}}=\approx \begin{bmatrix} 0.707 & 0 & 0.707\\ 0 & 2.828 & 1.414\\ 0.707 & 1.414 & 1.414 \end{bmatrix}

再softmax得到:A(1)[0.4010.1980.4010.0450.7680.1870.1980.4010.401]A^{(1)}\approx \begin{bmatrix} 0.401 & 0.198 & 0.401\\ 0.045 & 0.768 & 0.187\\ 0.198 & 0.401 & 0.401 \end{bmatrix}​​最后用这个得到的比例乘上V就好:head1=A(1)V(1)=[0.8020.7970.2321.7230.5991.203]head_1​=A^{(1)}V^{(1)}=\begin{bmatrix} 0.802 & 0.797\\ 0.232 & 1.723\\ 0.599 & 1.203 \end{bmatrix}


Head2Head2

同理我们可以得到: head2[0.8981.7980.8021.0000.7161.436]{head}_2\approx \begin{bmatrix} 0.898 & 1.798\\ 0.802 & 1.000\\ 0.716 & 1.436 \end{bmatrix}

ConcatConcat

OK啊,算完所有Head之后(此处只有两个),终于来到激动人心的ConcatConcat环节: 现在两个 head 都算完了:

head1[0.8020.7970.2321.7230.5991.203]{head}_1\approx \begin{bmatrix} 0.802 & 0.797\\ 0.232 & 1.723\\ 0.599 & 1.203 \end{bmatrix} head2[0.8981.7980.8021.0000.7161.436]{head}_2\approx \begin{bmatrix} 0.898 & 1.798\\ 0.802 & 1.000\\ 0.716 & 1.436 \end{bmatrix}

concat 是按列拼接

Concat(head1,head2)=[0.8020.7970.8981.7980.2321.7230.8021.0000.5991.2030.7161.436]{Concat}(\text{head}_1,\text{head}_2) = \begin{bmatrix} 0.802 & 0.797 & 0.898 & 1.798\\ 0.232 & 1.723 & 0.802 & 1.000\\ 0.599 & 1.203 & 0.716 & 1.436 \end{bmatrix}


LinearLinear

最后,终于来到我们的LinearLinear环节啦,所谓LinearLinear,就是一个: 让Concat矩阵从各个head孤立的状态变成各个head沟通过了的状态的过程 那么怎么让这几个head彼此沟通呢,答案是乘输出矩阵WOW_O 设输出矩阵WoW_o​ 为:

WO=[1010010110010110]W_O= \begin{bmatrix} 1 & 0 & 1 & 0\\ 0 & 1 & 0 & 1\\ 1 & 0 & 0 & 1\\ 0 & 1 & 1 & 0 \end{bmatrix}

其实我们可以好好理解一下这个矩阵,虽然很简陋,但是意义非常简洁明了,比如,在当前这种矩阵所表示的沟通策略下,结果矩阵YY的左上角的元素1.700来自0.8021+0.7970+0.8981+1.7980=1.7000.802⋅1+0.797⋅0+0.898⋅1+1.798⋅0=1.700 ,它参考了head1的第一个元素和head2的第二个元素。

那么最终输出:

Y=Concat(head1,head2)WO[1.7002.5952.6001.6951.0342.7231.2322.5251.3152.6392.0351.919]Y=\text{Concat}(\text{head}_1,\text{head}_2)W_O\approx \begin{bmatrix} 1.700 & 2.595 & 2.600 & 1.695\\ 1.034 & 2.723 & 1.232 & 2.525\\ 1.315 & 2.639 & 2.035 & 1.919 \end{bmatrix}


太好了,到此为止我们终于把多头注意力机制讲完了!那么最后让我们一起看看相比单头而言,多头的优势在哪里吧!

以下内容全部来自chatGPT

1. 能同时看不同类型的关系

单头只有一套 Q,K,VQ,K,VQ,K,V 投影,所以本质上只有一种注意力模式

多头会有多套不同的投影矩阵:

WQ(1),WQ(2),…W_Q^{(1)}, W_Q^{(2)}, \dotsWQ(1)​,WQ(2)​,…

所以不同 head 可以各看各的:

  • 一个 head 看近距离依赖
  • 一个 head 看长距离依赖
  • 一个 head 看语法对应
  • 一个 head 看语义关联

也就是说:

同一个 token,可以同时从多个“视角”理解上下文。

这才是多头最本质的优势。


2. 把一个大空间拆成多个子空间来学,更容易分工

单头是在一个完整的 dmodeld_{model}dmodel​ 空间里一次性做 attention。
多头则是把它拆成多个小空间:

dmodel→h 个 dkd_{model} \rightarrow h \text{ 个 } d_kdmodel​→h 个 dk​

比如:

  • 单头:一次在 512 维里做
  • 多头:8 个 head,每个在 64 维里做

这样做的好处是:

每个 head 不用什么都学,它只学自己那部分模式就行。

这有点像:

  • 单头:一个人同时负责翻译、语法、逻辑、指代
  • 多头:4 个人分工合作

通常分工更稳,也更强。


3. 表达能力更强

单头最后只会产出一套加权平均结果

多头则是:

  1. 每个 head 各自产出一套结果
  2. concat 拼起来
  3. 再通过 WOW_OWO​ 融合

所以它不是“只做一次加权平均”,而是:

做很多次不同的加权平均,再融合。

这会让模型能表示更复杂的关系。

你可以把它理解成:

  • 单头:一个结论
  • 多头:多个角度的结论,再综合起来

4. 更不容易把不同信息混在一起

单头有个天然问题:

  • 语法关系
  • 语义关系
  • 指代关系
  • 位置关系

这些全要挤在同一个 attention 里面。

容易互相干扰。

多头的好处就是可以“隔离”一些模式:

  • 某些 head 专门抓代词指代
  • 某些 head 专门抓邻近词
  • 某些 head 专门抓结构边界

虽然这不是人工规定的,但训练后经常会出现这种现象。