Transformer的并行计算与长序列处理瓶颈

7 阅读2分钟

Transformer相比RNN(循环神经网络)的核心优势之一是天然支持 并行计算,这源于其自注意力机制和网络结构的设计.并行计算能力长序列处理瓶颈是其架构特性的两个关键表现:

  • 并行计算:指 Transformer 在训练 / 推理时通过矩阵运算并行化、模块独立性实现高效计算的能力;

  • 长序列处理瓶颈:指当输入序列长度(n)增加时,自注意力机制的计算 / 内存复杂度呈O(n²)增长,导致效率骤降的问题。

  1. 并行计算

1.自注意力机制的并行性

自注意力的计算公式为:

Attention(Q,K,V)=softmax(QKTdk)\text{Attention}(Q,K,V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})

对于序列长度为nn的输入,自注意力中每个位置的计算不依赖其他位置的中间结果

  • 计算QKVQ、K、V的线性变换时,所有token的qikiviq_i、k_i、v_i可同时生成(并行);
  • 计算QKTQK^Tn×nn×n的分数矩阵)时,每个元素score(i,j)score(i,j)的计算独立于其他元素(可并行);
  • 即使是softmax和加权求和步骤,也可对整个序列的所有位置同时执行(并行)。

而RNN需要按序列顺序计算(hih_i依赖hi1h_{i-1}),完全串行,无法并行。

2. 网络结构的并行性

  • 编码器 / 解码器 层的 并行:编码器的每一层(多头注意力+前馈网络)对整个序列的处理是“批量”的,所有token共享层参数,可同时更新;
  • 训练时的 并行 优化:结合数据并行(同一模型在不同样本上并行训练)、模型并行(将网络层拆分到不同设备),可充分利用GPU/TPU的并行计算能力,大幅加速训练。

核心观点:Transformer的并行能力源于模块独立性和矩阵运算的可并行性。

  1. 底层:矩阵运算天然支持并行(GPU的SIMD架构可并行处理矩阵元素);

  2. 中层:模块独立(前馈网络对每个位置的计算独立;多头注意力的“头”之间无依赖);

  3. 顶层:训练时可通过批处理(batch维度)、序列分片进一步提升并行效率。

根本原理:并行能力源于“计算单元的独立性”和“矩阵运算的可拆分性”。

  • 前馈网络:对序列中每个位置的计算是独立函数(FFN(x_i) = W2·ReLU(W1·x_i + b1) + b2),无跨位置依赖,可完全并行;

  • 多头注意力:每个“头”的计算独立(head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)),头之间可并行;

  • 矩阵运算:QK^T的每个元素(QK^T)[i][j] = Q[i]·K[j],元素间无依赖,可由GPU并行计算。

  1. 长序列瓶颈

长序列处理的核心瓶颈

当序列长度nn增大(如文档级文本、长视频帧、基因组序列,nn可达10410^4甚至10510^5),Transformer的性能会急剧下降,核心瓶颈来自自注意力的O(n2)O(n²)复杂度

1. 计算复杂度瓶颈

自注意力的核心步骤(QKTQK^T矩阵乘法)的计算量为O(n2d)O(n²·d)dd为隐藏层维度):

  • n=1000n=1000时,计算量约为106d10^6·d
  • n=10000n=10000时,计算量增至108d10^8·d(是前者的100倍)。

这种平方级增长会导致:

  • 单次前向/反向传播时间大幅增加(训练/推理变慢);
  • 难以利用并行计算优势(过多计算量超出硬件算力上限)。

2. 内存瓶颈

自注意力过程中需要存储多个n×nn×nn×dn×d的中间张量:

  • QKQ、K、的形状为(n,d)(n,d),总内存为O(3nd)O(3nd)
  • QK^的分数矩阵形状为(n,n)(n,n),内存为O(n2)O(n²)
  • 注意力权重矩阵(softmax结果)同样为(n,n)(n,n),内存O(n2)O(n²)

n=10000n=10000时,n2=108n²=10^8,若每个元素为4字节(float32),仅分数矩阵就需要400MB内存,加上其他张量,单头注意力就可能占用数GB内存,远超普通GPU的显存上限(如16GB GPU难以处理n=20000n=20000的序列)。

3. 优化器的额外负担

训练时,优化器(如Adam)需要存储所有参数的梯度和动量信息,长序列会导致中间变量(如注意力权重的梯度)的内存占用也随n2增长,进一步加剧内存压力。

三、长序列处理的解决方案

为突破O(n2)O(n²)瓶颈,研究者提出了多种优化思路,核心是用“稀疏注意力”或“线性复杂度注意力”替代全局注意力

1. 稀疏注意力(Sparse Attention)

仅计算部分位置的注意力,将复杂度降至O(nw)O(n·w)ww为局部窗口大小):

  • 滑动窗口注意力(如Longformer):每个位置仅关注左右ww个相邻位置(总窗口2w+12w+1),适合时序相关的长序列;
  • 固定稀疏模式(如BigBird):每个位置关注“局部窗口+随机采样+全局标记”,兼顾局部相关性和全局信息;
  • 轴向注意力(如Axial Transformer):将长序列拆分为多个维度(如文本拆分为“句-词”),在每个维度单独计算注意力,复杂度降至O(nn)O(n·\sqrt{n})

2. 线性注意力(Linear Attention)

用“核函数”替换QKTQK^T的矩阵乘法,将复杂度降至O(nd)O(n·d)

  • 核心思路:将softmax(QKT/d)V\text{softmax}(QK^T/\sqrt{d})V改写为KT(softmax(QKT/d)TV)Z\frac{K^T(\text{softmax}(QK^T/\sqrt{d})^T V)}{Z}ZZ为归一化项),通过核函数(如exp(qk)\exp(q·k))的性质,将矩阵乘法转化为逐元素操作;
  • 代表模型:Performer(用随机特征映射近似核函数)、Linformer(用低秩矩阵近似KVK、V)。

3. 分层/压缩注意力

通过“序列压缩”减少有效长度:

  • ** hierarchical Attention**:先对长序列分块,计算块内注意力得到“块表示”,再计算块间注意力(如文档先分句子,再对句子表示计算注意力);

  • Downsampling:用池化(如平均池化)或卷积将长序列压缩为短序列(如ViT中的Patch Embedding将图像压缩为n=14×14n=14×14的patch序列)。

核心观点:长序列处理瓶颈源于自注意力的全连接关联特性,导致复杂度随长度平方增长。分层展开:

  1. 底层:自注意力需计算“每个位置与所有位置”的关联(QK^T矩阵为n×n);

  2. 中层:计算复杂度O(n²d)d为隐藏维度)、内存占用O(n²)(存储注意力权重);

  3. 顶层:当n过大(如n>10k),计算耗时、内存溢出,效率骤降。

根本原理:自注意力的“全关联定义”导致复杂度随长度平方增长,是机制固有属性。

自注意力的核心公式为:

Attention(Q,K,V) = softmax((QK^T)/√d_k)·V

其中QK^Tn×n矩阵(n为序列长度),其计算/存储复杂度必然是O(n²);即使优化实现(如稀疏化),也只能降低系数,无法改变O(n²)的本质(因“注意力”定义本身要求衡量位置间的关联)。