transformers库,把rope编码类型设为dynamic的大坑

12 阅读3分钟

用 transformers 库写模型时发现个很离谱的 bug。离谱不在于 bug 本身,而是 transformers 本身已经意识到这个 bug,但因为 “速度更重要” 而搁置不改了。

先说最终结论:如果要将 rope_type 设为 dynamic,为了让位置编码正确运作就必须设置 use_cache=False

背景

简单说一些背景知识。

为了让 transformer 模型获知输入序列的位置关系,需要想办法把位置信息嵌入到序列中。

旋转位置编码 RoPE(Rotary Position Embedding)是目前最为流行的位置编码方案,广泛应用于各种 transformer 模型,例如 Llama 和 Qwen。效果好,是位置编码中的豪杰。

一般来说,如果模型推理超过训练时的序列长度,困惑度就会猛然上升,输出质量急剧下降。于是有了一些扩展位置编码长度的方法。

本文所涉及的扩展方法叫做 Dynamic NTK,能够在不微调模型的前提下有效增加模型在更长上下文的能力。在 transformers 库中对应 rope_type = dynamic

问题

Dynamic NTK 本身没问题,本文不会探讨这个方法的具体原理。问题出在 Dynamic NTK 会相当于在推理过程中随着序列增长而动态调整 base(对应 transformers 库的配置 rope_theta),而 KV Cache 存储的历史所使用的 base 并没有随着序列增加而变化

新 KV 与旧 KV 所使用的位置编码不统一。除非在调整 base 后抛弃所有旧的 KV Cache,否则带着 KV Cache 进行推理是不合理的,脱离了 Dynamic NTK 原本意图。

将错就错

Github 上这个 Issue 就是探讨的这个问题,以及这个 Pull Request也是为了解决这个问题而提的。

Issue 里有一个实验,当前不考虑 base 发生变化的 Cache 方案会导致模型置信度在突破原定上下文长度时飙升,随后下降。这已经与预期的 Dynamic NTK 效果偏差太大了。

但维护人员认为应该保持现状,PR 也没有合并进来。至于原因,就是一旦考虑需要更换 Cache 的情况,推理就太慢了execution speed is paramount in LLMs nowadays(当下大语言模型的推理速度才是最重要的)。

HuggingFace 的文档也不说这个问题,Qwen 或者 Gemma 之类的模型也心照不宣的,要么使用 rope_type = linear 而不是 dynamic,要么就啥都不用。Dynamic NTK 像是一个摆设,一个坑,躺在 transformers 源码中。

碎碎念

其实影响不是那么大。开放的模型权重该怎么用就怎么用,官方的默认推荐配置就是最可用的。

就只是需要留个心眼。有较高自定义需求的话,要么不碰 rope_type = dynamic,要么用上 dynamic 就设置 use_cache=False

参考来源

  • github.com/huggingface…
    • 围绕该问题的 issue
  • 苏剑林 ,“Transformer升级之路:10、RoPE是一种β进制编码”,spaces.ac.cn/archives/96…
  • bloc97,“NTK-Aware Scaled RoPE allows LLaMA models to have extended (8k+) context size without any fine-tuning and minimal perplexity degradation.”,www.reddit.com/r/LocalLLaM…
    • 有直观的 NTK RoPE 困惑度对比图
  • emozilla,“Dynamically Scaled RoPE further increases performance of long context LLaMA with zero fine-tuning”,www.reddit.com/r/LocalLLaM…
    • 有直观的 Dynamic NTK RoPE 困惑度对比图