Attention is all you need

0 阅读1分钟

今天看的是《Attention is All You Need》浅读(简介+代码),分享自本人的师兄。 感觉这篇文章很碎片化的讲解了对论文的理解,所以我也分模块记录一下自己的收获吧(先放一张可爱的流程图hh)

Pasted image 20260328195901.png

Attention

论文里官方给的Attention定义是:Attention(Q,K,V)=softmax(QKdk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V 那么在深入理解这个公式之前呢,我们有必要先从宏观上理解一下到底什么是Attention,我自己对它的理解是,他是一种机制,或者说一个流程,它能够更新输入的向量InputInput中的每一个元素,使得更新后的元素包含了对其他元素的理解。 比方说我现在有一个输入:猫 抓 老鼠,然后我们把每个词嵌入到一个四维的向量空间中去(意思就是说每一个词用四个数字来表示,比如[0.23,1.2,0.98,0.17][0.23,1.2,0.98,-0.17]),也就是 X=[101001011111]X = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1\\ 1 & 1 & 1 & 1 \end{bmatrix} 或者也可以抽象化的表示:X=[X1,X2,X3]X=[{X_1},{X_2},{X_3}],只不过这里的每一Xi{X_i}代表一个个四维的向量。

在这个阶段,“猫”仅仅是 [1,0,1,0][1, 0, 1, 0],它不知道自己后面跟着一个动词,也不知道它抓的是老鼠。注意力机制的目的,就是让这些孤立的向量通过互相“交流”,融合上下文信息。 所以当我们把这个XX输入到Attention之后,就会得到一个融合了上下文信息的新的向量ZZ: Z=[1.4851.941.031.4851.4851.031.941.4851.901.901.901.90]Z = \begin{bmatrix} 1.485 & 1.94 & 1.03 & 1.485 \\ 1.485 & 1.03 & 1.94 & 1.485 \\ 1.90 & 1.90 & 1.90 & 1.90 \end{bmatrix} 所以,如果以“猫”为例,实际上我们完成了这样一个变化: 初始 Xcats=[1,0,1,0]Attention Mechanism最终 Zcats=[1.485,1.94,1.03,1.485]\text{初始 } X_{cats} = [1, 0, 1, 0] \quad \xrightarrow{\text{Attention Mechanism}} \quad \text{最终 } Z_{cats} = [1.485, 1.94,1.03,1.485]

输出的张量 ZZ 就是大语言模型进入下一层神经网络的输入。 最初的 XX 是一个静态的词典向量;而现在的 ZZ,是一个“上下文感知”(Contextualized)**的向量。 在 ZZ 张量中,第一行虽然依然对应“猫”这个位置,但它里面已经融入了“抓”和“老鼠”的 VV 值特征。现在的它不再是字典里那个抽象的“猫”,而是“那只正在抓老鼠的猫”。这就是 Attention 能够理解复杂语境的底层逻辑。(是的没错这段话是Gemini那里抄来的hh)


那么到这里,我们就搞定了AttentionAttention这个概念。。。 了吗? 孩子你先看看开头的那个公式呢?Attention(Q,K,V)=softmax(QKdk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) VAttention(Q,K,V)\mathrm{Attention}(Q, K, V)不是Attention(X)\mathrm{Attention}(X)吧?所以说,我们先要对XX进行一层处理,使XX通过某种方式映射到Q,K,VQ,K,V才行(其实我这样讲有点马后炮了,没有触及到思维层面的本质,不过暂时先这样写着吧hh),那么具体是怎么样的呢,答案是通过权重矩阵 WQ,WK,WVW^Q, W^K, W^V查询权重矩阵 WQW^Q (决定词去寻找什么):

WQ=[1010010110010110]W^Q = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 0 & 0 & 1 \\ 0 & 1 & 1 & 0 \end{bmatrix}

键权重矩阵 WKW^K (决定词对外展示什么):

WK=[0101101001101001]W^K = \begin{bmatrix} 0 & 1 & 0 & 1 \\ 1 & 0 & 1 & 0 \\ 0 & 1 & 1 & 0 \\ 1 & 0 & 0 & 1 \end{bmatrix}

值权重矩阵 WVW^V (决定词被关注时提供什么内容):

WV=[1010010100111100]W^V = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 0 & 0 & 1 & 1 \\ 1 & 1 & 0 & 0 \end{bmatrix}

看到这里,你可能会问:这三个矩阵是怎么得到的?好问题,但是这个我们后面再讲

通过这三个矩阵,我们这样操作: Q=X×WQQ = X \times W^QK=X×WKK = X \times W^KV=X×WVV = X \times W^V,最终得到: Q=[201102112222]Q = \begin{bmatrix} 2 & 0 & 1 & 1 \\ 0 & 2 & 1 & 1 \\ 2 & 2 & 2 & 2 \end{bmatrix} K=[021120112222]K = \begin{bmatrix} 0 & 2 & 1 & 1 \\ 2 & 0 & 1 & 1 \\ 2 & 2 & 2 & 2 \end{bmatrix} V=[102112012222]V = \begin{bmatrix} 1 & 0 & 2 & 1 \\ 1 & 2 & 0 & 1 \\ 2 & 2 & 2 & 2 \end{bmatrix} 所以我们接下来就是讲这个三个东西输入进去就行了


讲完了输入的处理,接下去我们正式进入公式的内部,为了防止你看了前面就忘了后面(好吧其实是我自己忘了),我再一次把我们的明星公式请出来:Attention(Q,K,V)=softmax(QKdk)V\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V 首先我们把注意力(不是Attention)放到QK{QK^\top}上,这个乘法实际上算的是每一个词该怎么样,或者说多大程度的参考其他的词(听不懂没关系,看下面的例子就明白了),用论文里的说法,就是Raw Attention Scores(注意力得分),带入数据得到:Scores=[2686288816]Scores = \begin{bmatrix} 2 & 6 & 8 \\ 6 & 2 & 8 \\ 8 & 8 & 16 \end{bmatrix} 然后我们再把外面那个奇怪的东西softmax(QKTdk)\text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)也算进去,得到我们想要的东西: A=softmax(QKTdk)=[0.0350.2600.7050.2600.0350.7050.0180.0180.964]A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)=\begin{bmatrix} 0.035 & 0.260 & 0.705 \\ 0.260 & 0.035 & 0.705 \\ 0.018 & 0.018 & 0.964 \end{bmatrix} 再继续往下讲之前,我们有必要看看这个A到底有什么含义,其实正如前文所说,它表示每一个词该怎么样,或者说多大程度的参考其他的词,那第一行举例[0.035,0.260,0.705][0.035,0.260,0.705]这个向量代表:“猫”应该分配0.035的比例给“猫”,分配0.260的比例给”抓“,分配0.705的比例给”老鼠“

好了,现在我们已经知道了比例了,接下来就该按照这个比例去具体的分配,然后得到结果了,这个时候VV终于要用上了! Z=A×V=[0.0350.2600.7050.2600.0350.7050.0180.0180.964]×[102112012222](猫 V)(抓 V)(鼠 V)=[1.7051.9301.4801.7051.7051.4801.9301.7051.9641.9641.9641.964]Z = A \times V=\begin{bmatrix} 0.035 & 0.260 & 0.705 \\ 0.260 & 0.035 & 0.705 \\ 0.018 & 0.018 & 0.964 \end{bmatrix} \times \begin{bmatrix} 1 & 0 & 2 & 1 \\ 1 & 2 & 0 & 1 \\ 2 & 2 & 2 & 2 \end{bmatrix} \begin{matrix} \text{(猫 V)} \\ \text{(抓 V)} \\ \text{(鼠 V)} \end{matrix}= \begin{bmatrix} 1.705 & 1.930 & 1.480 & 1.705 \\ 1.705 & 1.480 & 1.930 & 1.705 \\ 1.964 & 1.964 & 1.964 & 1.964 \end{bmatrix}

太好了,我们终于完成了XAttention MechanismZX \quad \xrightarrow{\text{Attention Mechanism}} \quad Z的转变!

attention 结构图 学术风.png