如何从公式和矩阵操作中理解自注意力机制?(上)

345 阅读3分钟

自注意力机制有多重要?

当前AI技术已经是一个社会热点现象,新闻中充斥着各种相关报道。总体来说,这一波AI的热点来自于2023年OpenAI公司推出的ChatGPT产品,一款聊天机器人。相比于传统的聊天机器人,ChatGPT更加智能,能理解人文字表达,听懂笑话,甚至写工作总结、写邮件、写文章、写代码。

ChatGPT依赖的底层的技术是大语言模型(LLM),而LLM起源于一篇名为《Attention Is All You Need》的论文。论文中提出了一个新的Transformer架构,此架构由编码器和解码器构成。它们的核心是由自注意力机制构建出的自注意力层。

自注意力机制(Self Attention)是这一波AI技术(大语言模型)的核心,相比之前的卷积神经网络和循环神经网络它大大提升了AI的能力,谈到AI技术就绕不开它。

如何简单理解自注意力机制?

人类看到一个画面时,对画面上不同部分关注度是不一样的。例如如果是画面中是人物,那么下意识会关注人脸。人类对世界中事物不同关注度,可以称之为注意力。

人类的注意力如果用数学来表达,那么可以认为人类总的注意力AA是100%,不同事物TiT_i占有一个0到100%的注意力权重AiA_i。我们可以用AATT的加权和来表达人类观察的最终结果。

Transformer中的自注意力机制是在对原始数据进行反复的变换。这种变换可以类比成人类的注意力机制,其运算中会包含所有数据,所以相比循环神经网络不容易忘记之前的状态。在反向传播的过程中,参数矩阵会被反复调整,相当于是在改上面说到的AiA_i,以此让最终的变换满足需要。

类似于让一个人在充满水果的图片中寻找苹果,那么对苹果的注意力权重自然增加了。如果模型希望解决的问题就是找苹果,那么训练过程中就会增加苹果的权重,让结果偏向苹果。

自注意力机制中的“自”的意思是从数据本身来得到,而不是外界赋予权重。例如在大语言模型中,训练的数据是文本语句,输入句子的前几个单词后希望模型可以预测下个单词是什么。此种状况下,正确答案是包含在数据中的,训练过程就是使用反向传播来调整参数,让其尽可能靠近答案。

以上是对自注意力机制的表面理解,若希望理解实质在做什么则需要去到公式和矩阵操作中,熟悉数据的变换流程。以上描述只是对矩阵操作流程的一个解读,不同人有不同的理解。

如何将一句话变成输入数据?

大语言模型的输入是文本,具体来说可以认为是一个句子。处理输入数据时,目的是将句子变为数字的矩阵。

这里我们用”无可奈何花落去,似曾相识燕归来“为例,假设我们为所有汉字和标点符号都赋予一个编号且以此句作为编号的开始,那么此句就可以表达为一组有15个数字的行向向量。例如如下:

x=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]\vec x = [0, 1, 2, 3, 4,5,6,7,8,9,10,11,12,13,14]

在大语言模型中需要对数据进行嵌入和位置编码,限于篇幅,此处不再展开。简单理解就是对x\vec x中的元素进行变换,最终使得其中的数字变成一个维度为dmodeld_{\text{model}}的行向量。dmodeld_{\text{model}}是一个重要的维度数值,后续的许多矩阵变换都涉及到

于是x\vec x变成了一个形状为15×dmodel15\times d_{\text{model}}的矩阵,假设我们用 n 来替代15, 那么最终的输入矩阵可以表示为:

X=[x11x12x1dmodelx21x22x2dmodelxn1xn2xndmodel]\mathbf{X} = \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1d_{\text{model}}} \\ x_{21} & x_{22} & \cdots & x_{2d_{\text{model}}} \\ \vdots & \vdots & \ddots & \vdots \\ x_{n1} & x_{n2} & \cdots & x_{nd_{\text{model}}} \end{bmatrix}

X\mathbf{X}的行向量可以是句子中的一个字或符号被变化后的结果,最终文本数据使用矩阵表达了出来,为后续的矩阵变变换做好了准备。

为了简化表达,以下dmodeld_{\text{model}}记为dmd_m

如何从X\mathbf{X}得到Q\mathbf{Q}K\mathbf{K}V\mathbf{V}矩阵?

首先自注意力机制的核心公式如下:

Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(\mathbf{Q},\mathbf{K},\mathbf{V})=\text{softmax}(\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{d_k}})\mathbf{V}

其中Q\mathbf{Q}代表代表查询(Query)矩阵; K\mathbf{K}代表键(Key)矩阵;V\mathbf{V}代表值(Value)矩阵;dkd_k代表键矩阵的列数。公式的最终结果是一个矩阵Z\mathbf{Z}

想要计算最终的注意力矩阵,需要首先得到查询矩阵、键矩阵和值矩阵,计算他们的公式如下:

Q=XWQ\mathbf{Q}=\mathbf{X}\mathbf{W}_{Q}
K=XWK\mathbf{K}=\mathbf{X}\mathbf{W}_{K}
V=XWV\mathbf{V}=\mathbf{X}\mathbf{W}_{V}

三者的计算使用相同的输入X\mathbf{X}和不同的参数矩阵,所以可以认为三者是对原始输入矩阵的不同变换结果。

由于是矩阵的乘法,所以输入矩阵和参数矩阵的维度需要匹配。即输入矩阵的形状为n×dmn\times d_m,参数矩阵的形状应该为dm×dq/k/vd_m \times d_{q/k/v}。其中dqd_qdkd_k必须相同,原因在后面的矩阵操作中可以看到。dvd_v在数学上可以不同,但是为了将结果最终结果作为下一个注意力层的输入,也会使其和dkd_kdqd_q相同。

确定维度后,以Q\mathbf{Q}为例可以得到以下的矩阵操作:

Q=[x11x12x1,dmx21x22x2,dmxn1xn2xn,dm][w11w12w1,dkw21w22w2,dkwdm,1wdm,2wdm,dk]=[q11q12q1,dkq21q22q2,dkqn1qn2qn,dk]\mathbf{Q}= \begin{bmatrix} x_{11} & x_{12} & \cdots & x_{1, d_{m}} \\ x_{21} & x_{22} & \cdots & x_{2,d_{m}} \\ \vdots & \vdots & \ddots & \vdots \\ x_{n1} & x_{n2} & \cdots & x_{n,d_{m}} \end{bmatrix} \begin{bmatrix} w_{11} & w_{12} & \cdots & w_{1,d_k} \\ w_{21} & w_{22} & \cdots & w_{2,d_k} \\ \vdots & \vdots & \ddots & \vdots \\ w_{d_{m},1} & w_{d_{m},2} & \cdots & w_{d_m,d_k} \end{bmatrix}=\begin{bmatrix} q_{11} & q_{12} & \cdots & q_{1,d_k} \\ q_{21} & q_{22} & \cdots & q_{2,d_{k}} \\ \vdots & \vdots & \ddots & \vdots \\ q_{n1} & q_{n2} & \cdots & q_{n,d_{\text{k}}} \end{bmatrix}

如上矩阵可以得到Q\mathbf{Q}的形状是n×dkn\times d_k,一个字的向量表达在计算中被变换成了一个维度不同的行向量。

获得K\mathbf{K}的矩阵操作过程与Q\mathbf{Q}完全相同,只是参数矩阵的数值不同,所以他们的形状相同。

QKT\mathbf{Q}\mathbf{K}^T的意义是什么?

核心公式中KT\mathbf{K}^T代表对键矩阵的转置,所以KT\mathbf{K}^T的形状变为了dk×nd_k \times n

公式下一步是查询矩阵与键矩阵的转置矩阵相乘,从矩阵的形状变化看,即n×dkn\times d_k乘上dk×nd_k \times n得到n×nn\times n。矩阵表达为:

QKT=[q11q12q1dqq21q22q2dqqn1qn2qndq][k11k21kn1k12k22kn2k1dkk2dkkndk]=[(QKT)11(QKT)12(QKT)1n(QKT)21(QKT)22(QKT)2n(QKT)n1(QKT)n2(QKT)nn]\mathbf{Q}\mathbf{K}^T = \begin{bmatrix} q_{11} & q_{12} & \cdots & q_{1d_q} \\ q_{21} & q_{22} & \cdots & q_{2d_q} \\ \vdots & \vdots & \ddots & \vdots \\ q_{n1} & q_{n2} & \cdots & q_{nd_q} \end{bmatrix} \begin{bmatrix} k_{11} & k_{21} & \cdots & k_{n1} \\ k_{12} & k_{22} & \cdots & k_{n2} \\ \vdots & \vdots & \ddots & \vdots \\ k_{1d_k} & k_{2d_k} & \cdots & k_{nd_k} \end{bmatrix} = \begin{bmatrix} (QK^T)_{11} & (QK^T)_{12} & \cdots & (QK^T)_{1n} \\ (QK^T)_{21} & (QK^T)_{22} & \cdots & (QK^T)_{2n} \\ \vdots & \vdots & \ddots & \vdots \\ (QK^T)_{n1} & (QK^T)_{n2} & \cdots & (QK^T)_{nn} \end{bmatrix}

了解了矩阵操作的流程后,再来看看这样的操作有什么意义?

Q\mathbf{Q}的行向量中代表一个字的向量表达,KT\mathbf{K}^T的列向量代表一个字的向量表达。做矩阵乘法时,对于Q\mathbf{Q}的一个行向量,是在和后者的列向量做向量的点积。向量点积数学上是代表两个向量的相似性,越是相似则结果越大。

有趣的地方就来了,结果矩阵的一行代表的意义从高层面来理解就是某个字相对句子中所有字的某种关系的得分。之所以是某种关系,是因为这种关系得分是由WQ\mathbf{W}_QWK\mathbf{W}_K来决定的。

在训练过程中,根据某种规则来修改参数矩阵则可以使得某个字和某个字具有最强的某种关系。这也被称之为注意力得分,训练的过程就是想要让模型能够得出合适的注意力得分。


总结一下,本文首先讨论了自注意力机制在AI中的重要性,接着从非公式和非矩阵操作简单理解自注意力机制。最后引入公式和矩阵,说明自注意力机制的计算过程。

为了让阅读更加轻松,对自注意力机制的理解分为了上下两篇,下篇将讨论如何将注意力得分转换为注意力权重?如何获得多头注意力?多个自注意力层之间是如何连接的?等问题。

欢迎关注,一起思考AI问题。