论文笔记:TransMLA,将 kv 映射到低秩空间

142 阅读1分钟

留下阅读 (2025) TransMLA: Multi-Head Latent Attention Is All You Need 的痕迹。

一言蔽之,就是让 Multi-Head Attention 的 kv 映射到低秩空间上,就像 LoRA 一样。

论文其实就只给了个思路,实验结果不是很全面。实验部分用了 Qwen2.5-7B 与其 MLA 改版,使用数据集 SmolTalk 进行微调,然后看到 MLA 版本模型的测试准确度更高。

实现方法

其实很直接。

Multi-Head Attention 的 qkv 是这样从输入 X\bold{X} 获得的:

Q=XWQRT×(nhdh)\bold{Q}=\bold{X}\bold{W}_Q\in \mathbb{R}^{T\times (n_h\cdot d_h)}
K=XWKRT×(nhdh)\bold{K}=\bold{X}\bold{W}_K\in \mathbb{R}^{T\times (n_h\cdot d_h)}
V=XWVRT×(nhdh)\bold{V}=\bold{X}\bold{W}_V\in \mathbb{R}^{T\times (n_h\cdot d_h)}
  • WQ,WK,WVRD×(nhdh)\bold{W}_Q, \bold{W}_K, \bold{W}_V\in \mathbb{R}^{D\times (n_h\cdot d_h)},权重矩阵
  • nkn_k,一个头的维度
  • dhd_h,头数

对于 MLA,是这样的形式:

Q=XWQRT×(nhdh)\bold{Q}=\bold{X}\bold{W}_Q\in \mathbb{R}^{T\times (n_h\cdot d_h)}
K=XWKaWKbRT×(nhdh)\bold{K}=\bold{X}\bold{W}_K^a\bold{W}_K^b\in \mathbb{R}^{T\times (n_h\cdot d_h)}
V=XWVaWVbRT×(nhdh)\bold{V}=\bold{X}\bold{W}_V^a\bold{W}_V^b\in \mathbb{R}^{T\times (n_h\cdot d_h)}
  • WKa,WVaRD×r\bold{W}_K^a, \bold{W}_V^a\in \mathbb{R}^{D\times r}
  • WKb,WVbRD×r\bold{W}_K^b, \bold{W}_V^b\in \mathbb{R}^{D\times r}
  • rr 是压缩维度。r<nhdhr<n_h\cdot d_h

可见和 LoRA 思路一样,只是 LoRA 变成了 kv 计算的主干。