面试官:Transformer如何优化到线性级?

69 阅读3分钟

面试官:我们来聊聊Transformer吧,Transformer的复杂度为什么这么高?有什么办法能优化到线性级吗?

面对这种原理与优化问题,其实都是有模板的,下面我们来看一看该怎么回答。

所有相关源码示例、流程图、模型配置与知识库构建技巧,我也将持续更新在Github:LLMHub,欢迎关注收藏!

一、先看原始的Self-Attention在干嘛

输入序列长度是 n,每个token是一个向量,维度是 d


Self-Attention要做三步:

  1. 把输入投影成 Q, K, V,维度都是 n × d。
  2. 算相似度矩阵:image

这步是关键!QKᵀ的结果是 n × n

  1. 然后加权求和:image

所以,

  • 计算量:QKᵀ是 O(n²*d)
  • 存储量:A 是 O(n²)

这就是Transformer卡顿的根源——n²炸裂增长

二、如何优化?

思路其实很直接:别算完整的 n×n。我们要么近似计算,要么减少交互
下面是几种经典方案:

1. Sparse Attention(稀疏注意力)

不是所有token都需要看所有token,局部邻域就够了。

比如Sparse、Transformer Longformer、BigBird

  • 只计算局部块(local window)或部分全局token。
  • 计算量从 O(n²) → O(n*k),k ≪ n。

2. Low-Rank Approximation(低秩近似)

比如 Linformer:把 K 和 V 投影到低维空间。

image

其中 E 是个 n×k 投影矩阵。
本质上是假设注意力矩阵是低秩的,不需要全秩表示。复杂度变成 O(n*k)。

3. Kernelized Attention(核函数注意力)

这类方法最聪明,比如Performer、Linear Transformer

核心技巧是把 Softmax(QKᵀ) 换成一个核函数形式:image

这样就能重排计算:image

计算顺序从 O(n²d) → O(nd²)。

其中步骤就是先算出 K 部分的“全局汇总”,再乘上 Q,就不用两两相乘了,一次搞定,完美线性化

4. Recurrent / Chunk-based Attention

比如 LongNet、Transformer-XL、RetNet


它们利用递归或状态缓存,让模型记住过去的信息,而不是重新计算所有注意力。每一段计算 O(n),再接上状态,复杂度线性。就像人类的短期记忆一样:“不用每次都从头回忆,记住重点就行。”

三、小结

Self-Attention从平方到线性,靠的不是一个“奇迹算法”,而是一系列聪明的折中。有的舍弃部分交互,有的重新排列计算顺序,有的直接假设低秩结构。

至此,Transformer为什么有这么高的复杂度的原理以及优化方法的回答就可以结束了。

但是要注意在回答的过程中不要死记硬背,要理解之后转化为自己的知识再回答,面试官往往注重的不是你回答的一丝不漏,而是你回答的逻辑性与对问题的理解深度,所以一定要注重展示出自己的思考能力。

关于深度学习和大模型相关的知识和前沿技术更新,请关注公众号 coting