自 2017 年问世以来,Transformer 架构已经彻底改变了自然语言处理(NLP)领域,推动行业范式转向具备自然语言理解(NLU)能力的模型。之所以能实现这一转变,是因为 Transformer 可以并行处理序列数据,从而比以往的顺序式模型——比如长短期记忆网络(LSTM)——更深入、更具上下文感知地理解语言。
近年来,Transformer 已经突破了 NLP 的最初边界,开始影响计算机视觉、语音识别、强化学习,甚至数学运算等广泛领域。它的适应性带来了许多重大进展,比如能够理解上下文的机器翻译,以及在科学研究中以惊人精度预测蛋白质结构。
其中最令人兴奋的发展之一,是推理模型(reasoning models)。这类模型属于更高级的大语言模型(LLM),通过强化学习训练,以执行复杂的多步推理任务。它们会在作答之前生成内部思维链,这种机制受到了人类思考过程的启发:先解决中间步骤,再得到最终答案。
在本书中,我默认你至少对 Transformer 架构已经有一定了解。也许你读过 O’Reilly 的《Natural Language Processing with Transformers》,或者类似的书。而且,我也默认你不只是对 Transformer 感到好奇。你来到这里,是因为你想用 Transformer 构建真正的应用,而且你想把这件事做对。
本章会对 Transformer 架构做一次聚焦式回顾,为后续章节中我将介绍的那些超越 NLP 的、更高级也更复杂的模型打下基础。
我会先从最基础的 Transformer 架构讲起,然后解释更长上下文是如何成为可能的,最后带你系统浏览各种注意力机制。在本章以及后续章节中,我都会穿插来自真实部署场景的实践经验,让你能够从我的经验中获益,学到当理论真正撞上生产环境时,哪些模式、陷阱和原则才是最关键的。
Transformer 基础
这一节将介绍原始 Transformer 模型中的主要架构组件,例如编码器与解码器、位置嵌入,以及注意力机制。
Transformer 架构最初是为机器翻译而设计的。机器翻译是一类具有挑战性的序列到序列任务,而在这类任务中,分词(tokenization)的概念至关重要。分词会把句子这样的序列拆分成模型能够有效处理的基本单元,也就是 token。比如,对于下面这句话:
The Transformer has revolutionized NLP.
其中单词 the 就是一个以单词为粒度的 token。
在深入架构组件之前,先理解分词非常关键,因为它决定了 Transformer 如何解释文本,也为它后来处理其他类型的序列奠定了基础。
Tokenizer:Transformer 中的文本表示
Tokenizer 的作用是把文本切分成 token。这是让自然语言能够被模型“消化”的第一步,之后才会应用 token embedding,最后再加入 positional embedding。常见的分词方式包括:
字符级分词
字符级分词会把底层字母表中的每一个字符都拆出来。如果你对下面这句话使用字符级分词:
"The Transformer has revolutionized NLP."
你会得到:
[T, h, e, ' ', …, 'N, L, P, .]
这种方式会导致序列变得非常长,从而增加计算复杂度。与此同时,模型学习长程依赖也会变得更加困难。不过,如果你的任务需要对细粒度信息有更强的理解能力,这种方式仍然可能有帮助。
词级分词
词级分词会把上面的句子拆成如下形式:
[The, Transformer, …, NLP, ., ]
也就是说,序列会按单词加标点符号进行拆分。它的缺点在于需要一个很大的词表,而且一旦语言发生变化,这种分词方式往往就无法理解新词。
子词分词
大多数现代 LLM 使用的是子词分词(subword tokenization),也就是把一个词继续拆成更小的片段。例如,一个子词 tokenizer 会把 hiking 拆成:
[h, ik, ing]
并把 cooking 拆成:
[cook, ing]
因此,子词分词会把一个单词(或序列)拆成更小、出现频率更高的片段,比如:
[ing]
其中也会包含单字符词。
现在你已经理解了分词的基础,接下来我们继续看 token embedding 和 positional embedding。
Token Embedding 与 Positional Embedding
Transformer 架构中包含可学习参数的一部分,就是 token embedding 和 positional embedding(PE)。Token embedding 的任务,是把词表中的每个元素编码成一个 维向量,也就是位于 ( ) 空间中的表示。
从数学上看,可以这样定义 token embedding:
设词表为 ,其大小为 。词表中的每个词 都会被分配一个唯一的 token ID,记作
。
Token embedding 是一个函数 , 它把每个 token ID 映射为一个 ( d_e ) 维向量。
这一映射是通过 token embedding 矩阵 实现的,其中 表示 embedding 的维度。下面以双向编码器表示模型 BERT 为例说明:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
sentence = "The Transformer has revolutionized NLP."
inputs = tokenizer(sentence, return_tensors='pt')
input_ids = inputs['input_ids']
print(input_ids)
outputs = model(input_ids)
embeddings = outputs.last_hidden_state
print(embeddings)
从 Hugging Face 加载 tokenizer 和 model。
对句子进行分词。
获取 input IDs,并送入模型得到 embeddings。
取出 last hidden state,以访问各个 token 的向量表示。
这段代码会得到如下 input IDs 输出:
tensor([[ 101, 1996, 10938, 2121, 2038, 4329, 3550, 17953, 2361, 1012, 102]])
对应的 embeddings 输出为:
tensor([[[–0.5249, –0.2210, 0.2696, ..., –0.4204, 0.2605, 0.6457],
[–0.6665, –0.4994, 0.4651, ..., –0.2517, 0.2334, 0.0176],
[ 0.8416, –2.0561, 0.8323, ..., –0.2709, –0.1999, –0.1918],
...,
[–0.4018, –0.6402, 0.7791, ..., –0.0290, –0.4070, 0.2974],
[–0.3327, –0.8091, –0.0304, ..., 0.4745, 0.3230, –0.5991],
[ 0.4928, –0.0878, –0.0971, ..., 0.1629, –0.7012, –0.3848]]],
grad_fn=<NativeLayerNormBackward0>)
不过,这种表示并不包含单词在序列中的位置信息。而由于 Transformer 没有 recurrence,也就是说它不需要按照原始顺序逐步处理数据,因此你必须引入一个函数来表示位置。这就是为什么需要 positional embedding:如果没有它,模型就会把序列当成一个无序的词集合。
位置嵌入函数会学习如何把 token 在序列中的位置编码成 空间中的一个向量。原始 Transformer 对于位置 使用如下形式:
这里, 表示 维向量 中第 个元素。这意味着,第一个 token 的位置会由一个向量 表示,第二个 token 的位置则由另一个不同的向量 表示,以此类推。
这种技术使 Transformer 模型能够理解词语的顺序。接下来你将看到,Transformer 是如何利用这种向量表示来理解和学习文本的。
注意力机制
注意力机制是 Transformer 理解和解释文本能力的核心。它让模型能够以 token 对 token 的方式分析序列中某个词的重要性。
在这个语境下,你经常会听到 attribution matrix 这个术语,它是由输入 embeddings 计算出来的。这里的 attribution,指的是输入不同部分之间的关联重要性。这个 attribution matrix 是通过 矩阵和 矩阵计算得到的。其结果分数构成 和 之间的交互,用来决定注意力权重,而这些权重随后会应用在 矩阵上,从而得到注意力机制的输出:
这个 attribution matrix 对于理解模型如何解释和处理输入序列非常重要。比如,通过分析这些分数,你可以洞察模型在生成输出 token 时,认为哪些 token 比其他 token 更相关。像 Captum 这样的库,就可以帮助你把这种“决策过程”可视化出来。
不过,尽管 各自承担不同角色,它们最初的计算方式其实是类似的:都是对输入 embedding 做线性投影。也就是说,对于这三个矩阵,输入 embedding 都会分别与各自的权重矩阵相乘。其数学形式如下:
查询矩阵 Q:
键矩阵 K:
值矩阵 V:
这里,表示输入 embeddings, 分别表示 query、key 和 value 的投影权重矩阵。
接下来,对 query 矩阵和 key 矩阵做点积,再经过 Softmax 和缩放因子(用于 scaled dot-product attention),就会得到一个分数矩阵,也就是 self-attention 分数。这个矩阵表示:每个 token 在考虑与序列中其他所有元素关系之后,应当对其他每个 token 给予多少关注。然后,这些分数会用来对 矩阵中的值加权,最终生成注意力机制的加权求和输出:
这种动态机制使模型能够针对每一个输入 token,聚焦输入序列中的不同部分,从而理解每个 token 的上下文相关性与信息含义。
多头注意力
到目前为止你看到的注意力机制,描述的是单个 attention head 的计算过程。Attention head 是 Transformer 中负责执行注意力计算的基本组件。不过,原始 Transformer 以及当今的 SOTA 模型,都会同时使用多个 attention head。每个 attention head 都有自己独立的可学习参数,最后再把这些 head 的输出合并成一个统一的输出。
这样做的好处在于,模型能够从同一条序列中整合不同角度的信息,捕捉其中多种不同的关系。这会增强模型理解和表示数据中复杂依赖关系的能力。
更技术化地说,给定输入序列 等,多头注意力机制会在考虑 等序列信息的基础上,为 中的元素计算新的表示。这个过程包括几个步骤:每个 head 分别基于输入计算自己的 attention score 和输出向量,然后再把这些输出拼接起来,并通过线性变换得到最终输出向量 。
这个过程把多个 attention head 捕捉到的上下文信息整合成一个统一的输出,其中包含了整个输入序列中的关键信息。由于不同的 attention head 往往会关注输入序列中不同类型的关系,这对于模型形成更强的语言理解能力至关重要。
双向注意力与单向注意力
前面提到过,最早的 Transformer 模型是为机器翻译设计的。因此,它在架构中使用了两种不同的注意力机制:一种用于编码器,另一种用于解码器。
首先,编码器使用的是双向 self-attention,而不是传统序列处理方法那种单纯的从左到右处理方式。这意味着它会把整个序列中的所有 token 都当成上下文,对每一个 token 都施加注意力。这样,当模型为每个 token 生成表示时,就能够获得对整条输入序列的完整理解。
而解码器使用的是 masked attention,也叫 causal attention,其目的是防止模型看到未来 token(也就是后续位置上的 token)。在实践中,这意味着当模型预测第 个位置时,它只能关注位置小于 的内容。借助这种机制,模型会从左到右地、仅基于之前已经生成的 token 来生成下一个 token,从而避免提前利用未来信息。这对于所有必须逐 token 生成输出的任务都非常关键,比如翻译。
现在你已经理解了最早 Transformer 中使用的两种注意力变体,接下来我们来看编码器和解码器本身。
编码器与解码器结构
最早的 Transformer 模型采用的是编码器—解码器结构(见图 1-1)。而之后的一些模型则采用了纯解码器框架,例如 GPT、LLaMA、Mistral 和 Falcon。
编码器本身由六个完全相同的层组成,每一层包含两个核心组件:多头 self-attention 机制,以及一个逐位置的全连接前馈网络(pointwise fully connected feed-forward network)。这里所说的 pointwise,指的是对序列中的每个元素都应用同样的线性变换。这两个组件还会进一步结合残差连接(residual connections)和层归一化(layer normalization)。
解码器负责解释编码后的信息。它在层级结构上与编码器相似,但额外引入了一个关键特性:masked multi-head self-attention。这个特性可以防止模型访问序列中当前位置之后的内容。
图 1-1. Transformer 架构中的编码器与解码器部分。
该模型在所有子层中都保持一致的输出维度 512,包括 embedding 层在内。这意味着它的最大序列长度为 512 个 token。这个限制主要来自最初 Transformer 的具体架构设定:在当时的硬件条件下,高效处理更长的序列仍然非常困难。
Transformer 设计的增强:更长上下文与注意力变体
现在,我们来看看现代 Transformer 模型——比如 GPT-4.5 和 Qwen3——是如何实现更高性能与更高灵活性的,尤其是它们如何通过更长的上下文窗口,一次处理更多信息。与此同时,像 multi-query attention 和 flash attention 这样的注意力机制变体,也进一步提升了 SOTA Transformer 模型的效率和精度。
更长的上下文窗口与更好的性能
模型的 context window,指的是它在做预测或生成文本时能够处理的那一段文本范围。更长的上下文窗口,意味着模型可以理解更复杂的叙事结构,也能比小上下文窗口下被切块处理的文本更好地捕捉细微差别。
然而,单纯延长上下文长度,会使时间复杂度和内存使用量都呈二次增长,这会限制性能提升。因此,近年来出现的一些改进方法——比如 rotary positional embedding(RoPE)、position interpolation(PI)以及 Yet another RoPE extensioN method(YaRN)——就是为了在推理阶段更有效地管理长上下文。
RoPE 结合了绝对位置嵌入和相对位置嵌入。不过在深入解释 RoPE 的工作机制之前,我们先来看绝对位置嵌入和相对位置嵌入之间的关键区别:
使用绝对位置嵌入时,模型会为每个 token embedding 加上该 token 的绝对位置信息。绝对位置嵌入更简单,也更快。
相对位置嵌入关注的是序列元素之间的距离,而且可以跨序列共享,这有助于模型理解和解释一个序列中不同 token 之间的关系和距离。相对位置嵌入通常能带来更好的性能,但计算复杂度也更高。
RoPE 把绝对位置嵌入与相对位置嵌入结合起来,这是 Transformer 设计中的一个重要进步。采用这种机制的模型,能够更自然、更准确地处理更长的文本序列,同时保持较好的效率。
更具体地说,RoPE 引入了一个旋转矩阵 ,用来编码 token 的绝对位置,同时把相对位置信息显式注入 self-attention 机制。为了更直观地说明 RoPE 的实现方式,假设模型维度为 ,那么它可以写成如下形式:
更高维度的情形会被划分为 个子空间,因此维度数必须是偶数。下面我们用代码把这个理论概念具体化:
def simple_rotary_matrix(d, m, max_len):
assert d % 2 == 0, "Embedding dimension must be even."
theta = 10000 ** (-2 * torch.arange(d // 2).float() / d)
theta *= m
cos_theta = torch.cos(theta)
sin_theta = torch.sin(theta)
R = torch.zeros((d, d))
R[torch.arange(0, d, 2), torch.arange(0, d, 2)] = cos_theta
R[torch.arange(0, d, 2), torch.arange(1, d, 2)] = -sin_theta
R[torch.arange(1, d, 2), torch.arange(0, d, 2)] = sin_theta
R[torch.arange(1, d, 2), torch.arange(1, d, 2)] = cos_theta
return R
-
确保维度 为偶数(这是公式成立的必要条件)。
-
计算 theta。
-
计算旋转所需的正弦值与余弦值。
-
初始化旋转矩阵。
-
构造旋转矩阵。
调用该函数也很简单:
d = 6
max_len = 10
R_matrix = simple_rotary_matrix(d, m=1, max_len=max_len)
print(R_matrix)
-
定义 embedding 维度 ( d )。
-
定义序列长度。
-
生成旋转矩阵。
输出结果如下:
tensor([[ 0.5403, –0.8415, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.8415, 0.5403, 0.0000, 0.0000, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.9989, –0.0464, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0464, 0.9989, 0.0000, 0.0000],
[ 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, –0.0022],
[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0022, 1.0000]])
图 1-2 展示了 RoPE 的处理过程。
图 1-2. 旋转位置嵌入示意图。图片改编自 Jianlin Su 等人(2021)。
在 self-attention 的语境中使用 RoPE 时,位置 上的 query ( ) 与位置 上的 key ( ) 之间的关系可定义为:
这里, 表示一个适配相对位置的旋转矩阵。
RoPE 能同时提升效率和精度,因此被用在 Qwen3 这样的 SOTA 模型中。但即便是最先进的 LLM,一次能够处理的 token 数量仍然是有限的。例如,Qwen3 系列模型单次输入最多可以处理 32,768 个 token。
这在长提示词或超长文档摘要等场景中会成为问题,因为这类场景通常希望 LLM 具备更大的上下文处理能力。然而,从零开始重新训练一个具备更大上下文窗口的新 LLM,需要非常高的计算成本。这就引出了一个重要问题:能否在一个已经预训练好的 LLM 基础上扩展其上下文窗口?好消息是:可以。PI 和 YaRN 都可以通过极少量微调来扩展预训练 LLM 的上下文能力。图 1-3 展示了 PI 技术如何应用在一个上下文窗口为 2048 的 LLaMA 模型上。
图 1-3. Position Interpolation(PI)方法在上下文窗口为 2048 的 LLaMA 模型上的工作方式。图中的圆点表示 LLM 训练时的位置范围,方块表示模型对新位置的适配;圆点和三角形表示 PI 如何把 的位置缩放回 ,从而保持在训练范围内。图片改编自 Shouyuan Chen 等人(2023)。
通常情况下,LLM 使用的是训练范围内的位置索引(图中的圆点)。而在长度外推时,模型需要处理新的位置索引(图中的方块),最大可达 4096。Position interpolation 会把这些索引(圆点和三角形)从 缩放到 ,确保它们仍然落在预训练范围内。
为了扩展上下文窗口,PI 会把位置索引插值回预训练限制范围内,并辅以少量微调。
也就是说,PI 会把 RoPE 的函数 扩展为 :
其中 ,表示一个超出预训练窗口的新上下文长度。
这里我想稍微退一步,先解释一个评估模型 性能的重要指标——困惑度(perplexity,PPL)。它衡量的是模型对上下文有多“惊讶”或多“困惑”。也就是说,困惑度衡量一个概率模型对样本的预测能力,数值越低,说明预测越准确。下面用一个具体的代码例子来说明:
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")
tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
wiki_text = tokenizer("Apple Inc. is an American multinational " +
"corporation and technology company headquartered " +
"in Cupertino, California, in Silicon Valley. ",
return_tensors = "pt")
loss = model(input_ids = wiki_text["input_ids"],
labels = wiki_text["input_ids"]).loss
ppl = torch.exp(loss)
print(ppl)
input_text = tokenizer("A Falcon is a generative transformer "+
"model and it can't fly.", return_tensors = "pt")
loss = model(input_ids = input_text["input_ids"],
labels = input_text["input_ids"]).loss
ppl = torch.exp(loss)
print(ppl)
计算 loss。
对于 wiki_text 这段输入,模型得到的分数是 5.08;而对于 input_text,得到的分数是 121.19。如此显著更高的 perplexity,说明模型觉得这句话非常“意外”或“不太可能”。原因是,模型训练数据里大概率学到的是 falcon 是一种以飞行能力著称的鸟,而不是一种生成式 Transformer 模型。
在评估长上下文窗口下的 LLM 性能时,通常会使用滑动窗口困惑度(sliding window perplexity)。这一指标会在固定大小的 token 窗口上计算 perplexity,并沿着文本不断滑动,从而更适合处理和评估大文本与大数据集。
RoPE 的一个缺点是,它会把 token 的位置信息展开到一个多维复向量中。由于其输入本质上是一维的,所以它在编码高频成分时存在困难,也就难以区分那些非常接近、非常相似的 token。
Softmax 与“大海捞针”问题
Transformer 中的注意力分布是通过 Softmax 函数计算出来的。随着上下文窗口不断变长,Softmax 往往会产生更“平”的分布。原因在于:分母,也就是所有 token 指数项之和,会随着上下文规模增长而变大;而每个分子,也就是某个 token 分数的指数值,本身却保持不变。结果就是,每个输出概率都会变小,模型会越来越难以聚焦真正重要的 token。
这通常被称作 haystack problem,也就是“大海捞针问题”:相关信号会被大量无关信息稀释掉。即使使用了像 RoPE 这样的高级技术,模型在超长上下文中优先识别关键元素的能力仍然会下降。为了解决这一问题,像 LLaMA 4 这样的 SOTA 模型会在长上下文上做后训练优化,在推理时对注意力做温度缩放,并引入一些架构级变化,比如不带位置嵌入的交错注意力层(iRoPE)。这些方法结合起来,可以把支持的上下文长度提升到最多 1000 万个 token,同时在“从草堆里找针”这类任务上依然保持良好表现。
为了解决这一问题,神经切线核(NTK)理论的研究者提出了 NTK-aware interpolation,它会在不同维度上以不同方式调整频率缩放,从而保留高频信息。NTK 理论的一个应用方向,是识别并缓解神经网络训练中的一些问题,例如学习高频成分的困难,或在内在维度较低的数据上学习模式时遇到的问题——而 RoPE 恰好就面临这类情况。所谓内在维度(intrinsic dimensionality),指的是在不丢失关键信息的前提下,准确描述一个数据集所需的最小参数数量,它反映了数据集本身的固有复杂度。
不过,NTK-aware interpolation 也可能把某些维度拉伸到边界之外,从而导致模型性能下降。此后又有人提出了 NTK-by-parts interpolation 和 dynamic NTK interpolation 作为更细化的策略:前者强调保留局部相对距离,后者则针对不同序列长度动态调整缩放因子。
在这些 NTK 技术基础之上,YaRN 在 attention Softmax 之前,为注意力分数引入了一个温度 ,并且会在不同数据样本与不同 token 位置上均匀影响困惑度。这种方法修改了注意力权重的计算方式,并使用一种长度缩放技术,同时按固定比例调整 和 ,从而在不改动底层注意力代码的情况下增强了注意力机制。RoPE embedding 可以预生成并重复使用,因此在推理和训练阶段都不会引入额外计算开销。当它与 NTK-by-parts interpolation 结合时,YaRN 在 LLaMA 和 LLaMA 2 等模型中表现非常有效。
图 1-4. 上下文窗口如何影响困惑度。图片改编自 Bowen Peng 等人(2023)。
正如你已经看到的,困惑度越低,模型表现就越好。例如,使用 YaRN 并做 128k 外推的 LLaMA 7b,明显优于未使用 YaRN 的 LLaMA 7b。
你现在大概很想知道,如何真正把 RoPE 或 YaRN 这类技术应用起来,从而增强上下文长度,并确保模型在长文本上的性能保持最佳。好消息是,大多数框架都允许你很方便地开启更长的上下文窗口。例如,vLLM 就支持 YaRN,其配置方式如下:
vllm serve Qwen3/Qwen3-8B --rope-scaling {"rope_type":"yarn","factor":4.0,"original_max_position_embeddings":32768} --max-model-len 131072
接下来,我们继续看不同的注意力变体,以及它们是如何进一步提升性能的。
注意力机制的变体
今天的 Transformer 比之前的模型——例如 LSTM——效率高得多。最初的 Transformer 模型只训练了 3.5 天,就取得了与那些需要训练数月的 LSTM 相近的高 BLEU 分数。不过,Transformer 依然可以说是“吃内存”的,因为 self-attention 的时间复杂度和空间复杂度都会随着序列长度呈二次增长。本节将介绍一些在高性能 SOTA LLM 中广泛使用的注意力机制改进方法,包括:
- Cross-attention
- Multi-query attention(MQA)
- Grouped-query attention(GQA)
- FlashAttention
- FlashAttention-2
- FlashAttention-3
模型同时组合多种注意力变体是很常见的。例如,Falcon 就同时使用了 multi-query attention 和 FlashAttention。
交叉注意力
在交叉注意力(cross-attention)中,会把来自两条序列的输入结合起来。通常情况下,query 来自解码器,而 key 和 value 来自编码器。从本质上讲,交叉注意力使一组 embeddings 能够与另一组 embeddings 发生交互。这对于那些在生成目标序列时需要关注源序列的任务非常重要,比如翻译或问答。下面用一段代码进一步说明这一概念:
def CrossAttention(x_1, x_2, W_query, W_key, W_value):
scaling_factor = W_query.shape[1]**0.5
Q = torch.einsum('bd,dk->bk', x_1, W_query)
K = torch.einsum('bd,dk->bk', x_2, W_key)
V = torch.einsum('bd,dv->bv', x_2, W_value)
attn_scores = torch.einsum('bk,mk->bm', Q, K)
attn_weights = F.softmax(attn_scores / scaling_factor, dim=-1)
Y = torch.einsum('bm,mv->bv', attn_weights, V)
return Y
从这段代码中你可以看出, 的输入来自 ,而 和 的输入来自 ,这就体现了两个序列之间的信息流动。借助多源信息,LLM 能够获得更丰富的理解,并生成更好的结果。
Multi-Query Attention
Multi-query attention(MQA)只使用一个 key-value head,而 multi-head attention(MHA)则为 query、key 和 value 分别使用 个 head。因此,MQA 能显著加快解码器的推理速度。图 1-5 对比了这两者。
图 1-5. Multi-head attention(左)与 multi-query attention(右)的比较。Multi-head attention 为 query、key、value 都使用 h 个 head;而 multi-query attention 则在所有 query head 之间共享同一组 key 和 value head。
为了让这种差异更直观,先看下面这段实现 MHA 的代码。注意, 上都带有字母 ,表示 head 维度:
def MultiheadAttention(x, M, W_query, W_key, W_value, P_o):
scaling_factor = W_key.shape[1]**0.5
Q = torch.einsum('d,hdk->hk', x, W_query)
K = torch.einsum('md,hdk->hmk', M, W_key)
V = torch.einsum('md,hdv->hmv', M, W_value)
attn_scores = torch.einsum('hk,hmk->hm', Q, K) / scaling_factor
attn_weights = F.softmax(attn_scores, dim=-1)
o = torch.einsum('hm,hmv->hv', attn_weights, V)
y = torch.einsum('hv,hdv->d', o, P_o)
return y
-
权重矩阵。
-
结合缩放因子,计算 scaled dot-product attention 的归因矩阵。
-
对注意力分数应用 Softmax。
-
得到最终注意力权重(上下文向量)。
而在 MQA 中, 和 矩阵里的被省略掉了:
def MultiqueryAttention(X, M, mask, W_query, W_key, W_value, P_o):
scaling_factor = W_key.shape[1]**0.5
Q = torch.einsum('bnd,hdk->bhnk', X, W_query)
K = torch.einsum('bmd,dk->bmk', M, W_key)
V = torch.einsum('bmd,dv->bmv', M, W_value)
attn_scores = torch.einsum('bhnk,bmk->bhnm', Q, K) / scaling_factor
attn_weights = F.softmax(attn_scores + mask, dim=-1)
O = torch.einsum('bhnm,bmv->bhnv', attn_weights, V)
Y = torch.einsum('bhnv,hdv->bnd', O, P_o)
return Y
这两段代码清楚地说明,MQA 与 MHA 的核心计算过程几乎相同,区别只在于:在 MQA 中,不同的 head 会共享同一组 keys 和 values。这个改动虽然可能会带来一定质量损失,但能显著加快解码器计算,而且总体性能仍优于传统 MHA。为了解决质量上的折中问题,后来又提出了 GQA。
Grouped-Query Attention
Grouped-query attention(GQA)会把 query heads 组织成 个组,每个组共享一组 key head 和一组 value head。图 1-6 对比了 MHA(左)与 GQA(右)。
图 1-6. Multi-head attention(左)与 grouped-query attention(右)的比较。Multi-head attention 为 query、key、value 都使用 h 个 head;而 grouped-query attention 则为每一组 query head 共享一组 key 和 value head,在 multi-head 和 multi-query 之间做折中。
把 MHA 与 GQA 对比起来看,你会发现 GQA 把多个 key/value head 合并成了单个 key/value head,从而有效减小了 key-value(KV)大小。
KV 缓存
KV caching 会在自回归解码过程中,把已经生成 token 的 key tensor 和 value tensor 缓存下来,以优化推理延迟。也就是说,模型不再需要在每一步都重新计算完整的 attention 上下文,而只需要把新生成 token 对应的 key 和 value 追加进去。这能显著降低注意力机制的计算成本。
不过,虽然 KV caching 在推理速度上有明显收益,但它的内存开销会随着序列长度和层数成比例增加。在内存成为瓶颈的场景下,你可能不得不缩小模型规模或限制上下文窗口,而这又可能导致模型精度下降。在大规模生产系统中部署 KV caching,还会引入缓存生命周期管理的复杂性,包括缓存淘汰策略、动态内存分配,以及跨请求或跨会话缓存复用策略的评估。
这也意味着,在计算过程中需要加载到内存中的数据会显著减少,使带宽和容量需求按 的因子下降。下面这段代码演示了 GQA 的基本形式:
def GroupedQueryAttention(Q, K, V, num_heads, group_size):
batch_size, seq_len, embed_dim = Q.shape
scaling_factor = (embed_dim // num_heads) ** 0.5
Q = rearrange(Q, 'b s (h d) -> (b h) s d', h=num_heads)
K = rearrange(K, 'b s (h d) -> (b h) s d', h=num_heads)
V = rearrange(V, 'b s (h d) -> (b h) s d', h=num_heads)
attn_scores = torch.einsum('bid,bjd->bij', Q, K) / scaling_factor
attn_weights = F.softmax(attn_scores, dim=-1)
attn_output = torch.einsum('bij,bjd->bid', attn_weights, V)
Y = rearrange(attn_output, '(b h) s d -> b s (h d)',
b=batch_size, h=num_heads)
return Y
GQA 对更大的模型尤其有益,因为大模型通常会扩展注意力 head 的数量。采用 GQA 后,在模型规模继续增大的同时,内存带宽和容量占用都能大幅降低,而性能仍然可以保持。
因此,在大模型中,attention 带来的内存带宽开销影响会相对更小。原因在于:KV cache 的大小会随着模型维度线性增长,而模型的 FLOPs 和参数量则会随着模型维度二次增长。
即使有了这些改进,attention 对 GPU 内存的利用方式仍然还有优化空间。这也就是 FlashAttention 和 FlashAttention-2 出现的原因。
FlashAttention
FlashAttention 通过 tiling(分块)的方式重新组织 attention 计算流程,从而避免显式构建一个 的注意力矩阵。所谓 tiling,就是把输入数据的块在 GPU 的高带宽显存(HBM)和 GPU 片上 SRAM(高速缓存)之间来回搬运。FlashAttention 会在外层循环中遍历 和 矩阵的各个分块,并将它们加载到高速缓存中;在每个分块内部,再遍历 矩阵的若干小块,把这些小块也放进 SRAM,之后执行 attention 计算,并把结果写回 HBM(见图 1-7)。
图 1-7. FlashAttention 通过 tiling 消除了庞大的 注意力矩阵。它在外层循环中遍历 和 的分块(红色箭头),把这些分块加载到快速片上 SRAM 中;对于每一个分块,再处理 Q 的分块(浅灰色箭头),把它们加载到 SRAM 中,然后把 attention 输出写回 HBM。图片改编自 Tri Dao 等人。
这种做法既提升了计算速度,又把内存消耗从相对于序列长度的二次增长,降到了线性增长。FlashAttention 不再需要把巨大的中间 attention 矩阵保存到 HBM 中,从而显著减少内存操作,处理速度甚至可以提升 2 倍到 4 倍。此外,FlashAttention 还使 Transformer 能够支持更长的上下文窗口,从而带来更好的困惑度表现,也就是更高质量的模型。
这已经很出色了,但仍然还有进一步优化空间。例如,还可以继续减少非矩阵乘法 FLOPs 的数量。接下来我们就看这个方向。
FlashAttention-2
前面提到过,扩展 Transformer 的上下文窗口并不容易。核心 attention 层的运行时间和内存开销都会随着输入序列长度呈二次增长。RoPE、PI 和 YaRN 在提升效率和降低困惑度方面都有帮助,前面你已经看过。
FlashAttention-2 在不改变输出结果的前提下,进一步减少了非 matmul FLOPs。虽然这些非矩阵乘法 FLOPs 在总 FLOPs 中所占比例不高,但它们执行起来更慢。GPU 拥有专门的硬件单元,可以让矩阵乘法的执行速度比非矩阵乘法快最多 16 倍。因此,减少非矩阵乘法 FLOPs,并尽可能让更多时间花在矩阵乘法 FLOPs 上,是提升计算速度的关键。
FlashAttention-2 的实现方式,是优化 GPU 资源的利用率。它通过跨不同 thread block 的并行计算,以及在单个 thread block 内不同 warp 之间更合理的任务划分,来减少共享内存访问。这里的 warp,指的是一组同时执行计算的线程。这些调整可以带来 2 到 3 倍的加速。
这种方法还会反转循环层级:把原本的外层循环改为优先处理 attention 矩阵的行分块,而内层循环处理列分块。也就是说,它把 FlashAttention 原来的处理顺序反了过来,并引入了沿序列长度维度的并行处理。图 1-8 展示了这一点。
图 1-8. 在前向传播中(左),任务(thread blocks)以并行方式分发,每个任务负责 attention 矩阵的一段行;在反向传播中(右),每个任务负责 attention 矩阵的一段列。图片改编自 Tri Dao 等人(2022,2023)。
图 1-9 对比了 FlashAttention 和 FlashAttention-2 在前向传播过程中,不同 warp 之间的任务划分方式。对于包括 Transformer 在内的深度学习模型来说,如何高效地在 warp 之间分配工作,会显著影响并行计算性能。
图 1-9. FlashAttention(左)与 FlashAttention-2(右)在前向传播中不同 warp 的任务划分对比。图片改编自 Tri Dao 等人(2022,2023)。
FlashAttention-3
FlashAttention-3 引入了一系列新的编程技术,能够充分利用 Hopper GPU 架构——尤其是 NVIDIA H100——把 attention 计算加速到此前方法难以达到的水平。虽然 FlashAttention-2 在大多数 GPU 上表现已经很好,但在 H100 这样的新架构上,它的 GPU 利用率仍然只有 35%。
FlashAttention 和 FlashAttention-2 的重点,在于减少内存带宽占用并优化计算调度;而 FlashAttention-3 则进一步利用硬件异步能力以及 FP8 这样的低精度格式,提升性能。其中一个关键创新,是 producer-consumer asynchrony:把不同的 GPU warp 分配成不同角色,一部分 warp 作为 producer,通过 Tensor Memory Accelerator(TMA)加载 ;另一部分 warp 作为 consumer,在 Tensor Cores 上执行矩阵乘法。这种策略通常被称作 pingpong scheduling,它允许数据传输和计算并发执行,从而有效隐藏延迟并最大化吞吐量。
PagedAttention:更高吞吐量
虽然 FlashAttention-3 通过充分利用 Hopper 架构和 FP8 等低精度格式实现了巨大提升,但它基本只针对 H100 GPU 做了优化。而 H100 在云上运行通常非常昂贵。因此,对于大多数团队和生产环境来说,PagedAttention 是一种更容易落地、成本也更可控的方案,能够在不依赖专用硬件的情况下提高推理吞吐量。
PagedAttention 是一种内存高效的 attention 变体,专门用于提升 LLM 推理阶段的吞吐量。前面你已经看到过我对 KV caching 的说明,而如果你正在评估如何优化 KV caching,PagedAttention 正是对应的解决方案。它会把 KV cache 存储在非连续的内存块中,这和操作系统里的虚拟内存分页机制很像。这些内存块可以动态分配、在多个序列之间共享,并利用 copy-on-write 语义进行复用。
PagedAttention 内置于 vLLM 服务系统中,通过减少 KV cache 的浪费,并允许更多请求一起做 batching,吞吐量最高可以提升 4 倍。对于那些包含长序列、解码长度变化大,以及 beam search 或 parallel sampling 这类复杂推理算法的工作负载,PagedAttention 特别有价值。需要注意的是,PagedAttention 只在 vLLM 中可用。而且,如果并发请求很多,vLLM 也可能出现瓶颈,此时整体吞吐量未必一定优于 Hugging Face 的 Text Generation Inference(TGI);后者在高并发场景下一向比较稳定。我建议你使用 TGI 的基准测试工具,针对自己的应用场景做验证。
FlashAttention-3 的另一项创新,是 GEMM-Softmax pipelining。一般矩阵乘法(GEMM)是深度学习中的基础操作,它会把两个矩阵相乘得到第三个矩阵,并且在 GPU 上通过 Tensor Cores 等专用硬件得到高度优化。在 Transformer 中,Softmax 操作依赖于 GEMM 的输出,因此天然存在顺序依赖。FlashAttention-3 通过在不同迭代之间把 GEMM 和 Softmax 做流水线化,打破了这一瓶颈:当一个 block 正在执行 Softmax 时,下一个 GEMM 已经可以启动。这种重叠执行对于发挥 Hopper 的异步计算能力至关重要。
FlashAttention-3 还引入了基于 FP8 的低精度 attention,其吞吐量几乎是 FP16 的两倍。为了在不牺牲精度的前提下实现这一点,它会调整 的内存布局,以满足 Hopper 对 FP8 GEMM 的约束,同时使用两种技术来降低量化误差:block quantization 和 incoherent processing。后者的做法,是在量化之前先用一个由 Hadamard transform 构造出来的随机正交矩阵去乘 和 。Hadamard transform 是一种数学变换,它通过只包含加法和减法的方式,把一个向量映射到新的空间中。它依赖于 Hadamard 矩阵,而这个矩阵的元素只会是 和 。这种变换计算效率很高,并且能够把信息打散到不同维度上,因此对抑制低精度量化中的离群值影响很有帮助。
总结
本章带你从最初的 Transformer 基础思想,一路走到了定义今天 SOTA 模型的一些最强架构与推理优化技术。从分词与多头注意力,到旋转位置嵌入、更长的上下文窗口,以及像 PagedAttention 和 FlashAttention 这样的高级内存优化手段,你已经看到了这个架构是如何一步步演化,以满足真实世界应用中不断增长的需求的。
这一演进过程本身已经说明,Transformer 不再是一个局限于语言任务的静态蓝图。它已经成为一种动态的、可扩展的框架,并且还在持续提升自己的准确性与效率。
在接下来的章节中,我们会走出纯语言场景,进一步探索这些模型——以及本章介绍的架构改进——如何在视觉、时间序列、强化学习和结构化推理等领域带来突破。你将学会如何在实践中应用这些工具,也会学会如何依据具体问题空间的需求,做出合适的架构选择。