社区项目ROSA-Tuning:验证RWKV-8 ROSA效果

36 阅读6分钟

本项目来自 RWKV 社区开发者 zyaaa-ux ,项目链接:github.com/zyaaa-ux/RO…

本项目为社区提出的一种 ROSA 实现,不代表 RWKV-8 ROSA 的实际性能,效果供参考。

本项目提出 ROSA-Tuning,一种通过检索回忆机制增强预训练模型长上下文建模能力的方法。该方法在传统注意力机制之外并行引入基于 CPU 的 ROSA(RWKV Online Suffix Automaton)检索模块,从长上下文中高效定位与当前查询相关的历史位置,并以可训练的方式将检索到的信息注入模型状态,故随后的加权融合可由状态受限的高效注意力完成。

为实现端到端训练,本文设计了二值离散化策略与反事实梯度算法,并通过 CPU–GPU 异步流水线进一步优化整体执行效率。

在 Qwen3-Base-1.7B 上的系统性评估结果表明,ROSA-Tuning 能够显著恢复窗口注意力模型的长上下文建模能力,在 LongBench 等基准上取得接近甚至匹配全局注意力的性能,同时保持与窗口注意力方法几乎相当的计算效率与显存占用,为高效长上下文处理提供了一条新的技术路径。

性能测试

困惑度 (PPL) 对比

在 PG-19 数据集上的测试显示,ROSA 适配器成功修复了窗口注意力的 PPL 劣化,甚至优于全局注意力。

ModelPPL (越低越好)
Global Attention18.96
Windowed Attention74.50
Windowed + ROSA17.63

实验设置:

  • 基础模型:具有全局注意力或窗口注意力的 Qwen3-Base-0.6B
  • 训练:28,000 个样本,在 PG-19 训练集上训练 1 个轮次,原始模型冻结,仅训练 ROSA 适配器
  • 评估:PG-19 测试集,序列长度 16k,窗口大小 1024

长文本能力 (LongBench)

在大海捞针 (NIAH) 测试,ROSA 实现了 100% 的召回率。综合评分恢复至全局注意力的 96.5%。

Task / MetricGlobal AttentionWindowed (2048)Windowed + ROSA
NIAH (大海捞针)100.006.20100.00
TriviaQA86.2061.5684.34
Multi_news23.2310.4323.76
Samsum42.0432.5140.53
TREC72.6752.6768.00
Gov_report31.1113.0826.19
LongBench 平均分59.2129.4157.14

实验设置:

  • 基础模型:Qwen3-1.7B-Base,具有全局注意力或窗口化注意力(窗口大小2048)
  • 训练数据:约 37B tokens,其中约 30B 来自 prolong,约 7B 来自其他上下文推理数据集,且不与测试集重叠

使用方法

项目作者进行了非常多的实验,本次介绍 2025 年 12 月 29 日更新的 2025.12.29 qkv_update.py 的使用方法。

注意在本地准备 Hugging Face datasets 库保存到磁盘的格式(Arrow 格式)的数据。

环境准备和代码获取

首先运行下列代码安装需要的库:

pip install torch transformers datasets deepspeed numba numpy

可选安装 flash-attn 库,该库能够提升代码运行速度,但首次安装时需要编译。

然后运行下列命令,获取项目代码:

git clone https://github.com/zyaaa-ux/ROSA-Tuning

准备 DeepSpeed 配置文件

项目使用了 DeepSpeed 进行加速,因此需要在本地创建一个 deepspeed_config.json 文件,示例如下:

{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 2,
    "allgather_partitions": true,
    "allgather_bucket_size": 200000000,
    "overlap_comm": true,
    "reduce_scatter": true,
    "reduce_bucket_size": 200000000,
    "contiguous_gradients": true,
    "offload_optimizer": {
        "device": "cpu", 
        "pin_memory": true
    },
    "offload_param": {
        "device": "none"
    }
  },
  "gradient_accumulation_steps": "auto",
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 20,
  "wall_clock_breakdown": false
}

如果显存足够,可以删除 offload_optimizer 中的 pin_memory 参数,并将 device 的值修改为 none,来获得更快的运行速度。

修改配置

2025.12.29 qkv_update.py 的 68~73 行定义了路径参数,需要修改路径参数为你本地的路径。

MODEL_LOCAL_DIR = "/path/to/base/model/" # 本地基础模型路径
MODEL_DIR = "/path/to/checkpoint/" # 模型检查点保存路径
DATASET_DIR     = "/path/to/processed/dataset/" # 数据集路径
OUTPUT_DIR      = "/path/to/output/" # 输出路径
DEEPSPEED_CONFIG_PATH = "/path/to/deepspeed/config.json" # DeepSpeed 配置文件路径

如果需要更加节省显存,可以修改代码第 119 行为 True,打开梯度累计:

GRADIENT_CHECKPOINTING = True # 源代码是 False

如果本地无 flash-attn 库,可以修改第 78 行的代码,关闭 flash-attn 的使用:

USE_FLASH_ATTN = False # 原来是 True

运行启动命令

由于使用了 DeepSpeed 和分布式训练逻辑(is_main_process 等检查),推荐 deepspeed 命令启动。

deepspeed --num_gpus=1 2025.12.29 qkv_update.py

启动成功后,会输出以下内容:

该图为使用 200 条长度为 128 的数据在单卡 4090 上进行流程测试的示例,实际训练 16k 长度数据时需要很大的显存。

🧠 原理概述

1. 状态更新公式

考虑第 ll 层,其隐藏状态记为 h(l)RT×Ch^{(l)} \in \mathbb{R}^{T \times C},其中 C=RMC = R \cdot M。该层的更新公式如下:

h(l+1)=h(l)+Attnwin(l)(LN(h(l)))+ROSA(l)(h(l))+MLP(l)(LN())h^{(l+1)} = h^{(l)} + \text{Attn}^{(l)}_{\text{win}}(\text{LN}(h^{(l)})) + \text{ROSA}^{(l)}(h^{(l)}) + \text{MLP}^{(l)}(\text{LN}(\cdot))

2. 特征二值化与符号化

定义参数 R=CMR = \frac{C}{M}K=2MK = 2^M。对于 X{Q,K,V}X \in \{Q, K, V\},令 bb,t,cXb^{X}_{b,t,c} 为输入特征的二值化表示(指示函数 1[xb,t,cX>0]\mathbf{1}[x^{X}_{b,t,c} > 0]),ab,t,rXa^{X}_{b,t,r} 为对应的整数符号表示:

bb,t,cX=1[xb,t,cX>0],ab,t,rX=m=0M1bb,t,(r,m)X2mb^{X}_{b,t,c} = \mathbf{1}[x^{X}_{b,t,c} > 0], \quad a^{X}_{b,t,r} = \sum_{m=0}^{M-1} b^{X}_{b,t,(r,m)} \cdot 2^m

3. 构建 Key 序列与 Run 索引

初始化 s0=0,  sym0=ab,0,rKs_0 = 0, \; \text{sym}_0 = a^{K}_{b,0,r}。当检测到符号跳变(即 ab,t,rKab,t1,rKa^{K}_{b,t,r} \neq a^{K}_{b,t-1,r})时,记录新的 run 起始点:

sl+1=t,syml+1=ab,t,rKs_{l+1} = t, \quad \text{sym}_{l+1} = a^{K}_{b,t,r}

定义时刻 tt 的有效 run 边界 rcap(t)\text{rcap}(t) 为:

rcap(t)=max{lslt}\text{rcap}(t) = \max \{ l \mid s_l \le t \}

4. 历史匹配与检索

执行匹配操作以确定下一状态。设 ns=match_next(s,x)ns = \mathrm{match\_next}(s, x)rpos=e[ns]rpos = e[ns],以及 nxt=rpos+1nxt = rpos + 1。索引 τb,r,t\tau_{b,r,t} 定义如下:

τb,r,t={snxt,若匹配成功且 nxtrcap(t),1,其他情况\tau_{b,r,t} = \begin{cases} s_{nxt}, & \text{若匹配成功且 } nxt \le \text{rcap}(t), \\ -1, & \text{其他情况} \end{cases}

5. 输出嵌入向量计算

对于 Q-run 的首个符号 aa 及其第 jj 个比特位,定义位操作后的符号状态:

a(j,0)=a¬(1j),a(j,1)=a(1j)a^{(j,0)} = a \wedge \neg(1 \ll j), \qquad a^{(j,1)} = a \vee (1 \ll j)

对应的扰动时间索引 τ(j,b)\tau^{(j,b)} 为:

τ(j,b)={snxt(j,b),若匹配成功,1,其他情况\tau^{(j,b)} = \begin{cases} s_{nxt^{(j,b)}}, & \text{若匹配成功}, \\ -1, & \text{其他情况} \end{cases}

计算输出嵌入向量。设 Δ\DeltaEmb1\text{Emb}_1Emb0\text{Emb}_0 之差,则输出 yb,t,cy_{b,t,c} 表示为:

Δ=Emb1Emb0,yb,t,c=1[τ0](Emb0[c]+Δ[c]1[vb,τ,r,m>0])\Delta = \text{Emb}_1 - \text{Emb}_0, \qquad y_{b,t,c} = \mathbf{1}[\tau \ge 0] \cdot \left(\text{Emb}_0[c] + \Delta[c] \cdot \mathbf{1}[v_{b,\tau,r,m} > 0]\right)

ROSA 模块的输出定义为 yy 的线性变换:

ROSA(l)(h)=Linear(y)\text{ROSA}^{(l)}(h) = \text{Linear}(y)

6. 反向传播与梯度计算

应用 Sigmoid 激活函数计算各分量的概率值:

pQ=σ(TQq),pK=σ(TKk),pV=σ(TVv)p^{Q} = \sigma(T_Q q), \qquad p^{K} = \sigma(T_K k), \qquad p^{V} = \sigma(T_V v)

计算反向传播梯度。定义损失函数关于 yy 的加权梯度 θb,t,c\theta_{b,t,c},其中 θb,t,r,m\theta_{b,t,r,m}θb,t,c\theta_{b,t,c} 的 reshape 形式:

θb,t,c=Lyb,t,cΔ[c]\theta_{b,t,c} = \frac{\partial \mathcal{L}}{\partial y_{b,t,c}} \cdot \Delta[c]
  • V 的梯度:定义辅助项 Sb,r,τ,mVS^{V}_{b,r,\tau,m},则 vv 的梯度计算如下:

    Sb,r,τ,mV=t:τb,r,t=τθb,t,r,m,Lvb,t,r,m=pb,t,r,mV(1pb,t,r,mV)Sb,r,t,mVS^{V}_{b,r,\tau,m} = \sum_{t: \tau_{b,r,t} = \tau} \theta_{b,t,r,m}, \qquad \frac{\partial \mathcal{L}}{\partial v_{b,t,r,m}} = p^{V}_{b,t,r,m}(1 - p^{V}_{b,t,r,m}) S^{V}_{b,r,t,m}
  • Q 的梯度:令 Vb,r,τ,mQ{1[v>0],pV}\mathcal{V}^{Q}_{b,r,\tau,m} \in \{\mathbf{1}[v > 0], p^{V}\}。定义差分项 db,t,r(j)d^{(j)}_{b,t,r},则 qq 的梯度为:

    db,t,r(j)=mθb,t,r,m(Vb,r,τ(j,1),mQVb,r,τ(j,0),mQ)d^{(j)}_{b,t,r} = \sum_{m} \theta_{b,t,r,m} \left(\mathcal{V}^{Q}_{b,r,\tau^{(j,1)},m} - \mathcal{V}^{Q}_{b,r,\tau^{(j,0)},m}\right)
    Lqb,t,r,j=pb,t,r,jQ(1pb,t,r,jQ)db,t,r(j)\frac{\partial \mathcal{L}}{\partial q_{b,t,r,j}} = p^{Q}_{b,t,r,j}(1 - p^{Q}_{b,t,r,j}) d^{(j)}_{b,t,r}
  • K 的梯度:定义累积项 Ub,r,l,j(b)U^{(b)}_{b,r,l,j} 及其差分 ΔUb,r,l,j\Delta U_{b,r,l,j}kk 的梯度仅在 run 起始点 sls_l 处非零:

    Ub,r,l,j(b)=tmθb,t,r,mVb,r,sl,mK,(b),ΔUb,r,l,j=Ub,r,l,j(1)Ub,r,l,j(0)U^{(b)}_{b,r,l,j} = \sum_{t} \sum_{m} \theta_{b,t,r,m} \mathcal{V}^{K,(b)}_{b,r,s_l,m}, \quad \Delta U_{b,r,l,j} = U^{(1)}_{b,r,l,j} - U^{(0)}_{b,r,l,j}
    Lkb,sl,r,j=pb,sl,r,jK(1pb,sl,r,jK)ΔUb,r,l,j,Lkb,tsl,r,j=0\frac{\partial \mathcal{L}}{\partial k_{b,s_l,r,j}} = p^{K}_{b,s_l,r,j}(1 - p^{K}_{b,s_l,r,j}) \Delta U_{b,r,l,j}, \qquad \frac{\partial \mathcal{L}}{\partial k_{b,t \neq s_l,r,j}} = 0

加入 RWKV 社区

欢迎大家加入 RWKV 社区,可以从 RWKV 中文官网了解 RWKV 模型,也可以加入 RWKV 论坛、QQ 频道和 QQ 群聊,一起探讨 RWKV 模型。

欢迎大家基于 RWKV-7 进行创业、科研,我们也会为基于 RWKV 的项目提供技术支持。

如果您的团队正在基于 RWKV 创业或开展研究,请联系我们!(在“RWKV元始智能”微信公众号留言您的联系方式,或发送邮件到“contact@rwkvos.com”。)