BOND论文解读

55 阅读24分钟

背景

现阶段,一个简单、有效的推理策略是Best-of-N采样,会从N个候选回复中选择reward最高的回复,但是Best-of-N在推理时成本较高,本文提出了Best-of-N Distillation(BOND)算法,此算法模仿了BON策略,同时避免推理时的巨大开销。BOND算法可以认为是一种分布匹配算法,它使得生成的分布接近BON采样的分布,这里使用Jeffreys divergence来平衡模式覆盖和模式探索行为,并推导出一种利用移动锚点提高效率的迭代公式。

知识速递

Jeffreys divergence:是一种对称的散度度量,用于衡量两个概率分布之间的差异,它由 Kullback-Leibler (KL) 散度对称化得到,即 Jeffreys divergence 为两个分布 P 和 Q 之间的 KL 散度与反向 KL 散度的平均值。

公式:J(P,Q)=DKL(PQ)+DKL(QP)2J(P,Q) = \frac{D_{KL}(P \parallel Q) + D_{KL}(Q \parallel P)}{2},其中 DKL(PQ)D_{KL}(P \parallel Q)表示KL散度。

差异性:JD与KL的区别是具备对称性、非负性和无界性,分别对应的好处是不需要区分分布的真实和预测、分布相同其值为0、分布不相交值无限大

应用:

  1. 用于生成模型和聚类分析:Jeffreys divergence 可用于衡量模型生成的分布与真实数据分布之间的差异,从而指导模型的训练和优化
  2. 统计学:Jeffreys divergence 可用于假设检验和模型选择等问题。通过比较不同模型对应的分布与观测数据分布之间的 Jeffreys divergence,可以选择与数据拟合更好的模型

强化学习阶段

Gemini、GPT-4等大模型通常会经过三个阶段的训练。

  • 阶段一

预训练,在大量的知识语料上面训练next-token prediction 能力。

  • 阶段二

监督微调(SFT),在预训练模型的基础上进行有监督的微调,主要是得到指令遵循能力。

  • 阶段三

人类反馈强化学习(RLHF)进一步提高模型生成的质量。RLHF主要包括基于人类偏好学习RM模型,然后使用强化学习算法优化 LLM,使其生成内容能够最大化预期奖励。

强化学习挑战

RL对LLM进行微调是具备挑战性的,显著的两个问题就是预训练知识的forgettingreward hacking,标准策略是使用策略梯度对SFT进行正则化,这些 RL 算法旨在寻找在低 KL 散度下具有高奖励的Pareto-optimal policies,以保持原始模型的通用能力,并解决对齐问题。

Best-of-N sampling

实际应用中,一种简单的推理时方法是从N次推理中产生N个候选集,并使用RM模型选择奖励最高的,这种方法优于奖励 - KL 散度权衡,但是推理成本会增加N倍。

Best-of-N Distillation

BOND是一种新颖的选择蒸馏策略,该策略推理时仅单次采样就实现N采样的效果,核心思想是:把策略对齐问题转换为分布匹配问题,通过微调策略使其模仿最佳 N 选分布。首先,需要推导出BON的解析表达式,这需要我们考虑优化不同的差异度量指标。

首先,需要最小化BON前向KL散度,这是一种模式覆盖行为的标准模仿学习设置(鼓励多样性,不出现集中某行为),然后需要最小化BON的后向KL散度,这是一种基于分位数的优势(可以认为是排名),它不依赖于奖励尺度,对应的是模式探索行为(鼓励高质量、少多样性)。

接着,提出了Jeffreys divergence,最小前向和后向KL散度的线性组合,可以融合两种方法的优势。同时为了在保持较低样本复杂度的同时优化性能,通过迭代地对移动锚点策略的最佳 N 选进行蒸馏。

最后,提出了J-BOND算法,这是新型、稳定、高效且实用的 RLHF 算法,用于对齐大型语言模型。

实验设置

任务:Xsum(抽象摘要)

对齐模型:Gemma

基础设置

定义了一个函数策略: π(x,)\pi(x, \cdot),表示在自回归生成模型中,基于prompt得到token序列y的过程。

给定一个预先训练且通常经过监督微调的参考策略 πref\pi_{ref},为了进一步反应人类偏好,本文后续假设可使用反映人类偏好的奖励模型为 r()r(\cdot)

标准RLHF

大多数强化学习(RL)算法会优化期望奖励与当前策略和参考策略之间KL散度的线性组合。

公式:

πRL=argmaxπEπ[r(y)]βRLKL(ππref)\pi_{RL} = \arg\max_{\pi} \mathbb{E}_{\pi}[r(y)] - \beta_{RL} \cdot KL(\pi || \pi_{ref})

βRL0\beta_{RL}\ge0:表示正则化强度,这种 KL 正则化可促使策略接近其初始化状态 πref\pi_{ref},从而减少知识遗忘和奖励欺骗的情况。这个公式采用的在线优化算法,强于离线优化,并且简单的方法容易产生更好的效果,例如REINFORCE算法强于PPO算法。

Best Of N

推理时策略,主要是从πref\pi_{ref}从多次采样,且并根据奖励模型 rr 选择奖励最高的生成结果,与RLHF不同,它不会微调大型语言模型(LLM)的权重,而是修改推理过程,该策略在经验和理论上均被证明正确,其缺点是N次推理的高成本,这里提出的BOND新型对齐方法,BOND 的目标是将BON策略蒸馏到策略中,这样,策略就能在仅需单次采样的情况下,达到BON采样的强大性能。

BOND方法

BOND分为两个步骤,首先推导BON的解析式,其次把这个表达式表述为一个分布匹配问题,希望将策略引导得更接近BON分布。

BON分布

首先,为了方便,后面表达式中会全部省略上下文 xx,假设奖励 r(y)r(y)能对所有生成结果 yy诱导出一个严格排序。

  • 定理1: p<(y)=Pyπref[r(y)<r(y)]p_<(y) = \mathbb{P}_{y' \sim \pi_{\text{ref}}} [r(y') < r(y)],表示: πref\pi_{ref}中随机生成的结果 yy'严格劣于 yy的概率

    • p(y)=Pyπref[r(y)r(y)]p_{\le}(y) = \mathbb{P}_{y' \sim \pi_{\text{ref}}} [r(y') \le r(y)],表示:表示 yy'不比 yy优的概率,那么,yy作为BON输出的概率可表示为: πBoN(y)=πref(y)×p(y)N1(A)×i=1N(p<(y)p(y))i1(B)\pi_{\text{BoN}}(y) = \pi_{\text{ref}}(y) \times \underbrace{p_{\leq}(y)^{N-1}}_{(A)} \times \sum_{i=1}^{N} \underbrace{\left( \frac{p_{<}(y)}{p_{\leq}(y)} \right)^{i-1}}_{(B)}

    公式解释: > > πref\pi_{ref}:表示结果被生成的概率,这个概率如果被生成都不可能,后面的概率计算也无意义,所以这是概率计算的基础和起点 > > (A):每次采样均是独立,假设当前已经有一个y,剩余的N-1次采样都不可以比y好,所以联合概率就是N-1次方,即放大好样本的概率,指数级抑制差样本概率 > > (B):这个公式主要是对于平局的修正,假设 p<(y)p_{<}(y)p(y)p_{\le}(y)相等,就不会出现平局,此时(B)=N,这个时候,只要y出现,其他都比y小,这也是希望看到的,鼓励此行为,但是如果它们不相等,说明会出现平局,此时(B)<N,概率会变小 > > 为什么三部分相乘:三个事件认为是独立性发生,互不干扰,最终同时发生概率为联合概率。

BOND目标

公式: πBOND=argminπΠD(ππBoN)\pi_{\text{BOND}} = \arg\min _{\pi \in \Pi} D(\pi \parallel \pi^{\text{BoN}}),其中 D(⋅∥⋅) 是一种散度度量,引导训练策略 π\piπBoN\pi^{BoN} 靠拢,这就是前面讲到的JD散度或者前向后向JD。

RLHF与BOND关系

公式: πRL(y)πref(y)exp(1βRLr(y))\pi_{RL}(y) \propto \pi_{ref}(y) \exp\left( \frac{1}{\beta_{RL}} r(y) \right)πRL(y)\pi_{RL}(y)表示要学到的策略,得到高质量,符合人类偏好的序列y, πref(y)\pi_{ref}(y)表示SFT后RLHF的起点或锚点。exp()表示的是奖励加权的部分,前后两部分公式表示的是正相关关系。

公式: rBOND(y)=logp(y)(A)+1N1logi=1N(p<(y)p(y))i1(B)r_{\text{BOND}}(y) = \underbrace{\log p^{\leq}(y)}_{(A)} + \underbrace{\frac{1}{N-1} \log \sum_{i=1}^{N} \left( \frac{p^{<}(y)}{p^{\leq}(y)} \right)^{i-1}}_{(B)},首先p(y)p_{\le}(y)表示抽取一个样本,不比y好的概率,象征的是y在所有可能生成中的分位数或排名,因为p(y)p_{\le}(y)的值在0-1之间,所以(A)的值小于等于0,在这种情况下,绝大多数生成基本上都是“惩罚”,只有p(y)p_{\le}(y)等于1时,"惩罚"为0,后面的(B)前面讲过,只是前面加了个系数用来缩放值。

总结:

BON采样等同于一个标准的、带KL正则化的RLHF问题,N表示正则化层次。

此处可以认为1/(N-1)等同于RHLF中的正则化,N越小,惩罚越大,N越大,惩罚越小

BON采样等同于优化的内容是期望对数奖励分位数,也就是让生成的奖励比从Ref中随机采样结果更大的对数似然,BOND模型鼓励模型尽量避免生成差的结果,而不是鼓励生成好的结果,它在对奖励函数r中单调变换是不变的,因为它依赖于生成结果的排名。

大概意思是:BOND在优化中注重的是奖励模型的"排名(分位数)"而非绝对数值,可以看到前面的logP,在0.0-0.4时非常陡峭,在0.4-1变得平滑,这种情况会使得模型花更大精力把不及格变成及格而非把95分提到99分,这种排名机制会保证不容易被奖励模型的数值漏洞所欺骗,BOND比起传统RLHF更不容易钻空子,健壮性更强。

BOND挑战及算法

BOND方法涉及的挑战有三个:(1)如何估算奖励分位数(2)应选择何种合适的散度度量(3)如何选择超参数N,即BON在生成过程中的数字。

蒙特卡洛分位数估计

第一个关键的难点在于估计分位数,给定生成结果yy的分位数,p(y)=Pyπref[r(y)r(y)]p_{\leq}(y) = \mathbb{P}_{y' \sim \pi_{\text{ref}}} [r(y') \leq r(y)],分位数 p(y)p_{\leq}(y) 衡量了在相同的prompt下yy相对于πref\pi_{ref}生成的质量,一种简单有效的方法是蒙特卡洛采样,即从πref\pi_{ref}中采样生成 kk个结果,并得到以下的经验估计,实验中,发现即使样本数量有限,这种方法也非常有效。

   p^(y)=1ki=1kI{r(yi)r(y)}\hat{p}_{\leq}(y) = \frac{1}{k} \sum_{i=1}^{k} \mathbb{I}\{r(y_i) \leq r(y)\}

使用JD散度作为的目标

在BOND方法中,散度度量方法很重要,不同的散度度量会走向不同的解决方案,提议采用JD散度作为稳健的分布匹配目标。两个JD的分布表达式定义如下。

JJeffreysβ(pq):=(1β)KL(qp)Forward KL+βKL(pq)Backward KLJ_{\text{Jeffreys}}^{\beta}(p \parallel q) := (1 - \beta) \cdot \underbrace{\text{KL}(q \parallel p)}_{\text{Forward KL}} + \beta \cdot \underbrace{\text{KL}(p \parallel q)}_{\text{Backward KL}}

广义的JD是一种前向KL和后向KL的加权平均,权重 β[0,1]\beta\in[0,1],在微调策略pp时,前向KL使在 q 下高概率的生成结果在 p 下也高概率,从而鼓励模式覆盖行为;后向KL引导策略 p 生成在 q 下高概率的结果,具有模式探索效果;前向 KL 易致分布过度扩散,后向 KL 可能引发策略和熵坍塌,两种结合更具优势。

在BOND框架下,上面公式转化为最小化 JJeffreysβ(ππBoN)J_{\text{Jeffreys}}^{\beta}(\pi \parallel \pi^{BoN}),可以通过从训练策略π\pi和参考策略πref\pi_{ref}采样来估计该散度。

  • 前向KL估计

KL(πBoNπ)=EyπBoN[logπBoN(y)logπ(y)]KL(\pi_{BoN} \parallel \pi) = \mathbb{E}_{y \sim \pi_{BoN}} \left[ \log \pi_{BoN}(y) - \log \pi(y) \right]

我们可以通过从 πBoN\pi^{BoN}采样(即从πref\pi_{ref}采样 N 次并选择最佳结果)直接估计前向 KL,并且可以将其视为在BON的监督微调损失。

πKL(πBoNπ)=EyπBoN[logπ(y)]\nabla_{\pi} KL(\pi_{BoN} \parallel \pi) = -\mathbb{E}_{y \sim \pi_{BoN}} \left[ \nabla \log \pi(y) \right]

此处是对 π\pi进行求梯度,不含π\pi的项会被约掉, [logπBoN(y)logπ(y)]\nabla{[\log\pi_{BoN}(y) - \log\pi(y)]}变为[logπ(y)]\nabla{[ - \log\pi(y)]}

  • 后向KL估计

KL(ππBoN)=Eyπ[logπ(y)logπBoN(y)]KL(\pi \parallel \pi_{BoN}) = \mathbb{E}_{y \sim \pi} \left[ \log \pi(y) - \log \pi_{BoN}(y) \right]

结合REINFORCE算法变成:

πKL(ππBoN)=N1Eyπ[πlogπ(y)(rBOND(y)βBOND(logπ(y)logπref(y)))]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = -(N-1 )\mathbb{E}_{y \sim \pi} \left[ \nabla_{\pi} \log \pi(y) \left( r_{BOND}(y) - \beta_{BOND} \left( \log \pi(y) - \log \pi_{ref}(y) \right) \right) \right]

此公式跳跃性较大,大致理解如下:

  1. 两边分别求梯度:πKL(ππBoN)=πEyπ[logπ(y)logπBoN(y)]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = \nabla_{\pi} \mathbb{E}_{y \sim \pi} \left[ \log \pi(y) - \log \pi_{BoN}(y) \right]

  2. 梯度展开并且约掉不相关项:πKL(ππBoN)=Eyπ[logπ(y)]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = \mathbb{E}_{y \sim \pi} \left[ \nabla\log \pi(y) \right]

  3. 利用策略梯度方法中的技巧,引入奖励信号来表达梯度。根据 REINFORCE 算法的思想,策略梯度可以表示为: πEyπ[R(y)]=Eyπ[πlogπ(y)R(y)]\nabla_{\pi} \mathbb{E}_{y \sim \pi} [R(y)] = \mathbb{E}_{y \sim \pi} \left[ \nabla_{\pi} \log \pi(y) R(y) \right]

    1. REINFORCE 算法的策略梯度通用表达式为 WE(rW)=E[Wlnπ(as,W)R]\nabla_WE{(r|W)}=\mathbb{E}[\nabla_Wln\pi(a|s,W)\cdot R],其中 W\nabla_W表示对W的梯度 π(as,W)\pi(a|s,W)是状态s下采取动作a的概率,由策略参数W决定,R为奖励信号
    2. 所以上面的公式可以按照REINFORCE的算法来定义,然后 R(y)R(y )是奖励信号
    3. 奖励信号的设计:R(y)=rBOND(y)βBOND(logπ(y)logπref(y))R(y) = r_{BOND}(y) - \beta_{BOND} \left( \log \pi(y) - \log \pi_{ref}(y) \right)
    4. 把奖励信号公式代入得到:
    5.   πKL(ππBoN)=Eyπ[πlogπ(y)(rBOND(y)βBOND(logπ(y)logπref(y)))]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = \mathbb{E}_{y \sim \pi} \left[ \nabla_{\pi} \log \pi(y) \left( r_{BOND}(y) - \beta_{BOND} \left( \log \pi(y) - \log \pi_{ref}(y) \right) \right) \right]
    6. 前面加一个调整系数-(N-1),这个地方加这个参数正好可以和BOND里面自带的1/(N-1)抵消:
    7.   πKL(ππBoN)=N1Eyπ[πlogπ(y)(rBOND(y)βBOND(logπ(y)logπref(y)))]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = -(N-1 )\mathbb{E}_{y \sim \pi} \left[ \nabla_{\pi} \log \pi(y) \left( r_{BOND}(y) - \beta_{BOND} \left( \log \pi(y) - \log \pi_{ref}(y) \right) \right) \right]
  • 实验

    • 基础设置

    任务:Xsum摘要抽取 > > 参考模型: πref\pi_{ref},经过SFT后的T5模型 > > 奖励模型: r()r(\cdot),T5的推理模型 > > BOND方法:损失函数 JJeffreysβJ_{\text{Jeffreys}}^{\beta}β  {0,0.5,1}\beta \in {\{0,0.5,1\}} > > 训练细节: > > 1. 训练阶段:每个提示词使用16个蒙特卡洛采样去估计分位数 > 1. 评估阶段:32 个蒙特卡洛样本估计训练策略与 πBoN\pi_{BoN}分布之间的前向和后向 KL 散度 > > 参数设置: > > N=8,实际证明 β=0.5\beta=0.5能够最小化两种分布的差异,相对比β=1\beta=1(仅最小化后向)β=0\beta=0(仅最小化前向)效果更好。值得注意的是β=0.5\beta=0.5β=1\beta=1相似,也就是和仅最小化后向相似,最小化后向代表的是模式探索能力,也就是注重高质量,减少多样化,而模式化覆盖差一些。

    • 实验结果

    •   N=8时JD的表现图

    N=8时, β  \beta  在不同取值下下的JD表现,可以看左图和中图,β=0.5\beta=0.5时可以同时满足JD都比较低,另外奖励分位数相比仅β=0\beta=0更好。

      1.   BOND迭代

    •   BOND中N的选择是一个困难,N会因为3种情况变得困难。

    • 同RLHF,N有正则化作用,较大的 N 能提升下游任务表现,但若 N 过大,最终会导致奖励过度优化

    • N越大, πBoN\pi_{BoN}估计对分位数估计误差愈敏感,因为 πBoN(y)p(y)N1\pi_{BoN}(y) \propto p_{\leq}(y)^{N-1}

    • 因为前后两个公式成正相关关系,后面的概率值会因为N这个指数变大而导致误差放大,例子如下:

      • 例如,假设真实的 p≤(y)=0.6,而估计值为 p^≤(y)=0.6±0.1(即误差为 0.1)。当 N=8 时:

        • 如果 p^≤(y)=0.7,则 p≤(y)7=0.77≈0.082。
        • 如果 p^≤(y)=0.5,则 p≤(y)7=0.57≈0.0078。
    • 估计前向 KL 散度需要从 πBoN\pi_{BoN}采样,这对于较大的 N 来说计算成本过高(需要推理)

    •   应对上述挑战,提出BOND算法,方法源于事实:从BON分布中进行BON采样等同于从原始分布中进行 BON2BON^2采样。

    •   BoN(BoN(BoN()))M times(πref)BoNM(πref)\underbrace{\text{BoN}(\cdots \text{BoN}(\text{BoN}(\cdot))\cdots)}_{M \text{ times}} (\pi_{\text{ref}}) \equiv \text{BoN}^{M}(\pi_{\text{ref}})

    •   这里存在BOND方法的关键思想:如果我们知道如何蒸馏BON的分布(称之为BOND),那么就可以递归地进行M次BOND,等同于对初始化 πref\pi_{ref}BoNM{BoN}^{M}进行蒸馏。

    •   策略 πt\pi_{t}被训练为从一个移动的锚点中蒸馏出BON(图中n=2),此时并不需要一个较大的N持续提升策略。迭代 BOND 方法带来了更好的训练稳定性和最低的计算复杂度。由于在每次蒸馏步骤中都使用较小的 n,从而降低了计算成本并提高了训练过程的稳定性。

    • 迭代算法

    • 实验结果   迭代BOND(n=2、4)与非迭代BOND(n=4、8、16)对比

    迭代 BOND 方法能够持续提升奖励(左图)和对数分位数(中图),而非迭代 BOND 方法的性能则会趋于饱和(N 越小,饱和越早)。此外,迭代 BOND 方法能够在保持较小的 n 的同时,平滑地从初始策略 πref 进行改进,并实现与非迭代 BOND 方法相同的奖励与 KL 散度的权衡(右图) 总结 迭代 BOND 方法能够实现对任意大的 N 的指数级扩展(实际上,它无需预先设定 N),同时保持较低的样本复杂度和稳定的优化过程。

    1. 奖励信号饱和:

      1. 对于非迭代 BOND 方法,奖励信号早期就会饱和(N 越小,饱和越早)。
      2. 迭代 BOND 方法则持续提升性能(n 越大,提升越快)。
    1.  奖励与 KL 散度权衡:
    
        1.  在最右图中,我们将获得的对数分位数与相对于参考策略的 KL 散度进行了对比。
        1.  结果表明,迭代 BOND 方法在奖励与 KL 散度的权衡上与非迭代 BOND 方法基本一致。
        1.  关键优势在于,迭代方法允许保持较小的 *n*,并能平滑且持续地从初始策略 $$\pi_{ref}$$ 进行改进。
    

J-BOND算法

基于前面的一些结论,下面介绍J-BOND算法。

η\eta:用户挑战EMA的更新速率

β\beta:平衡前向KL与后向KL

γ \gamma :额外的KL正则化系数

yyy1y^ 、_1y2y^ 、_2:分别是当前策略和锚点生成的,用后两个的近似BON策略来估计y的好坏

前向KL:从y1y^ 、_1y2y^ 、_2中选出奖励分数最高的,是对Best-of-2的近似,然后计算KL梯度,目的是让当前 πt\pi_t可以生成类似于分数最高的高质量样本

后向KL:计算 rJBOND(x,y)r_{J-BOND}(x,y)和奖励分数 R(x,y)R(x,y),为了减少估计方差,增加一个基线B,然后完成策略梯度更新、整体策略梯度更新和锚点更新。

  1. 采样:从当前策略和少量锚点策略生成少量样本
  1. 计算梯度:

  1. 前向KL:监督学习方式,模仿锚点生成BON样本
  2. 后向KL:通过策略梯度方式,用稀疏的惩罚信号避免生成最差样本
  1. 合并更新:多个梯度信号加权组合,更新当前策略模型

  2. 移动锚点:用更新后策略模型,通过EMA缓慢更新锚点

  3. 持续迭代

  • 最小样本复杂度

相比于之前,J-BOND具有最小的样本复杂度,对于批次中的每个prompt,它从策略πt\pi_{t} 生成 1 个样本,从锚点πanchort\pi^t_{anchor}生成 2 个样本。虽然更多的锚点样本有助于更准确地估计散度,但自回归采样是在线 RLHF 的主要瓶颈。因此,我们选择了一种实用的方法,即使用少量样本进行操作。

  • 基于2个锚点样本粗略散度估计

通过对两个锚点样本中的较优者进行监督微调,来最小化前向 KL,为了最小化后向 KL可以把πKL(ππBoN)=N1Eyπ[πlogπ(y)(rBOND(y)βBOND(logπ(y)logπref(y)))]\nabla_{\pi} KL(\pi \parallel \pi_{BoN}) = -(N-1 )\mathbb{E}_{y \sim \pi} \left[ \nabla_{\pi} \log \pi(y) \left( r_{BOND}(y) - \beta_{BOND} \left( \log \pi(y) - \log \pi_{ref}(y) \right) \right) \right]中的 rBOND(y) r_{BOND}(y)替换为rJBOND(y) r_{J-BOND}(y),主要原因是仅有2个锚点的时候奖励函数 rBOND(y)=logp^(y)r_{BOND}(y) = \log \hat{p}_{\leq}(y)不具备信息量参考性(蒙特卡洛采样太少效果不好),此处假设yy是策略样本,y1y^ 、_1y2y^ 、_2为锚点样本,定义如下。

rJBOND(y)={log(16)if r(y)<min{r(y1),r(y2)}0otherwiser_{J-BOND}(y) = \begin{cases} - \log(16) & \text{if } r(y) < \min\{r(y'_1), r(y'_2)\} \\ 0 & \text{otherwise} \end{cases}

上述函数定义的原因有二:(1)策略样本比两个锚点样本都差的时候,才施加负奖励用于模拟函数的凹性,主要是凹函数特性是随着输入值增加,斜率会放缓,次数奖励会放缓,从而实现平稳优化(2)选择 log(16)- \log(16)的目的是 p(y)=0.5{p}_{\leq}(y)=0.5Ey1,y2πanchort[rJ-BOND(y)]=logp(y)\mathbb{E}_{y'_1, y'_2 \sim \pi^{t}_{\text{anchor}}} [r_{\text{J-BOND}}(y)] = \log p_{\leq}(y)这个等式成立,具体推理较为复杂,感兴趣请看附录。

  • EMA锚点

这里不是采用定期更新的锚点,而是在每个微调步骤中,将锚点权重θtanchor\theta^{anchor}_t 更新为策略权重θt\theta_t的移动平均值: θt+1anchor(1η)θtanchor+ηθt+1\theta_{t+1}^{\text{anchor}} \leftarrow (1 - \eta) \cdot \theta_{t}^{\text{anchor}} + \eta \cdot \theta_{t+1},这种权重平均方法通过降低方差对训练稳定性有积极影响,并且可以改善 J-BOND 的整体奖励与 KL 散度的权衡。

  1. 实验

这里对J-BOND进行了消融实验展示了EMA 锚点的优势,以及锚点更新速度和额外 KL 正则化的影响,然后将 J-BOND 与使用 REINFORCE 的经典 RLHF 基线进行比较,展示其有效性和更优的性能表现。

  • 实验设置

模型:Gemma的2B和7B

批量:batch=128

优化器:Adam

学习率:学习率为 3×10−6

预热步数:100

β=0.5\beta=0.5

  • 左图:展示了 J-BOND 在 γ=0 和 η∈{0.01,0.05,0.1} 时的表现。图中显示,η 越大,平均奖励增加得越快。这表明较大的 η 值使得策略更快地从参考策略 πref\pi_{ref}中学习,从而更快地提升性能。
  • 中图和右图:展示了 J-BOND 在 η=0.05 和 γ∈{0,0.5,1,2} 时的表现。图中显示,正则化参数 γ 越大,策略 πt\pi_{t}πref\pi_{ref}偏离的速度越慢,从而改善了奖励与 KL 散度的权衡。这意味着较大的 γ 值有助于保持策略的稳定性,避免过度偏离参考策略,从而在提升奖励的同时控制 KL 散度的增长。

J-BOND 在 Gemma 7B 模型上的表现优于标准 REINFORCE 方法,尤其是在奖励提升和训练稳定性方面。这表明 J-BOND 是一种更加高效和稳定的策略优化方法,适用于大型语言模型的微调和优化任务

  • EMA 锚点与定期更新锚点

在 Gemma 7B 上运行 J-BOND,设置 γ=0 和 EMA 系数 η=0.02,并将其与每 50 步更新一次锚点的变体进行比较,两次运行都产生了相同的奖励增长曲线(这并不令人意外,因为 η=0.02 的 EMA 大致对应于 50 步的更新周期)。然而,关键在于,使用 EMA 锚点的 J-BOND 显示出显著更低的 KL 散度增长,因此在奖励与 KL 散度之间实现了更好的权衡。

  • 对比RLHF

将 J-BOND 与标准 RLHF 算法进行了比较,使用了 REINFORCE(Williams,1992),每个提示使用 2 个策略样本,并采用留一法基线来计算策略梯度的优势。对于 J-BOND,我们将锚点混合系数设置为 η=0.02。对于 REINFORCE,我们测试了可能的正则化强度 βRL  {0.001,0.01,0.1,1}\beta_{RL} \in \{0.001,0.01,0.1,1\}

正如预期的那样,REINFORCE 对正则化系数 βRL \beta_{RL} 相当敏感:正则化强度越大,REINFORCE 实现的奖励越低(并且与 πref 的 KL 散度也越低)。这突显了 J-BOND 的一个关键优势:它不需要承诺使用特定的正则化水平,而是持续提升奖励,同时显示出稳定且线性的 KL 增长。此外,右图绘制了相应的奖励与 KL 散度的权衡曲线,表明 J-BOND 在所有 REINFORCE 基线中都实现了更好的奖励与 KL 散度的权衡。

  • J-BOND 用于 Gemma 开源模型

J-BOND 也被用于微调开源权重模型,例如 Gemma 1.1 2B 和 7B(Gemma 团队,2024)、RecurrentGemma 2B 和 9B(Botev 等,2024)以及 CodeGemma 1.1(CodeGemma 团队,2024)。这使得这些模型达到了具有竞争力的性能水平。例如,Gemma 1.1 IT 7B 模型在安全性和指令遵循方面均优于 Mistral 7B v0.2 Instruct。

  1. 总结

BOND,这是一种新颖的基于人类反馈的强化学习(RLHF)方法,通过在线蒸馏BON采样分布来微调策略。我们进一步提出了一个具体的算法—J-BOND,它整合了多个组件以增强其实际应用性和效率;这些组件包括蒙特卡洛分位数估计、前向与后向 KL 散度目标的结合,以及带有指数移动平均锚点的迭代过程。J-BOND 优化了解决方案的 KL-奖励帕累托前沿,并与最先进的基线方法相比表现优异。我们希望这项工作能够帮助改善人工智能系统的对齐性,使其更加安全和可靠。

  1. 示例代码

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from tqdm import tqdm
import copy
import random

# ----------------------------------------------------------------------------
# Helper Functions
# ----------------------------------------------------------------------------

def get_reward(reward_model, tokenizer, texts, device):
    """
    计算一批文本的奖励分数。
    在这个示例中,我们使用情感分析模型,积极情感的分数更高。
    """
    with torch.no_grad():
        inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        outputs = reward_model(**inputs)
        # 我们取“积极”情感(通常是label 1)的logit作为奖励分数
        # 这鼓励模型生成更积极的内容
        reward_scores = outputs.logits[:, 1] 
    return reward_scores

def get_log_probs(policy_model, tokenizer, sequences, prompt_tokens, device):
    """
    计算给定策略模型下,生成序列的对数概率。
    """
    # 将序列和prompt都放到设备上
    sequences = sequences.to(device)
    prompt_tokens = prompt_tokens.to(device)
    
    # 准备模型的输入和标签
    model_inputs = sequences
    labels = sequences.clone()
    
    # 我们只计算生成部分的loss,prompt部分用-100掩码掉
    prompt_len = prompt_tokens.shape[1]
    labels[:, :prompt_len] = -100

    with torch.no_grad(): # 在评估log_prob时不需要计算梯度
        outputs = policy_model(model_inputs, labels=labels)
        # loss是每个token的负对数概率的平均值,乘以token数量得到总和
        # 我们取负值,得到对数概率的总和
        log_probs = -outputs.loss * (sequences.shape[1] - prompt_len)
        
    return log_probs

# ----------------------------------------------------------------------------
# Main J-BOND Training Function (Algorithm 2)
# ----------------------------------------------------------------------------

def train_j_bond(
    # --- 模型和数据 ---
    policy_model_name: str,
    reward_model_name: str,
    prompts: list,
    # --- J-BOND 超参数 (Algorithm 2 Inputs) ---
    beta: float = 0.5,      # Jeffreys散度中的β,平衡前向和后向KL
    eta: float = 0.02,      # EMA更新锚点模型的速率
    gamma: float = 0.1,     # 额外的KL正则化强度
    # --- 训练参数 ---
    epochs: int = 5,
    batch_size: int = 4,
    learning_rate: float = 1e-6,
    # --- 生成参数 ---
    max_length: int = 60,   # 生成文本的最大长度
    n_anchor_samples: int = 2, # J-BOND使用2个锚点样本
    r_j_bond_negative_reward: float = -torch.log(torch.tensor(16.0)) # Eq (17)
):
    """
    使用J-BOND算法训练一个策略模型。
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 1. 初始化 (Initialize policy and anchor)
    tokenizer = AutoTokenizer.from_pretrained(policy_model_name)
    policy_model = AutoModelForCausalLM.from_pretrained(policy_model_name).to(device)
    
    # J-BOND 使用一个移动的锚点模型
    anchor_model = copy.deepcopy(policy_model).to(device)
    
    reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_name).to(device)

    # 设置tokenizer的pad_token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # 对于自回归模型,左填充更方便
    tokenizer.padding_side = 'left'

    optimizer = torch.optim.AdamW(policy_model.parameters(), lr=learning_rate)

    print("--- Starting J-BOND Training ---")

    # 2. 训练循环 (for t = 0, ... do)
    for epoch in range(epochs):
        random.shuffle(prompts)
        
        for i in tqdm(range(0, len(prompts), batch_size), desc=f"Epoch {epoch+1}/{epochs}"):
            batch_prompts = prompts[i:i+batch_size]
            
            # --- a. 采样 (Generate policy and anchor samples) ---
            prompt_tokens = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(device)
            
            # 从当前策略生成1个样本
            policy_outputs = policy_model.generate(
                **prompt_tokens, max_length=max_length, do_sample=True, pad_token_id=tokenizer.pad_token_id
            )

            # 从锚点策略生成N个样本 (J-BOND 中 N=2)
            with torch.no_grad():
                anchor_outputs = anchor_model.generate(
                    **prompt_tokens,
                    max_length=max_length,
                    num_return_sequences=n_anchor_samples,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id
                )
            
            # 解码文本以用于奖励模型
            policy_texts = tokenizer.batch_decode(policy_outputs, skip_special_tokens=True)
            anchor_texts = tokenizer.batch_decode(anchor_outputs, skip_special_tokens=True)

            # --- b. 前向 KL 梯度 (Forward KL Gradient) ---
            # /* Extract Best-of-2 sample */
            # 由于每个prompt有n_anchor_samples个输出,我们需要重塑它们
            reshaped_anchor_texts = [anchor_texts[j:j+n_anchor_samples] for j in range(0, len(anchor_texts), n_anchor_samples)]
            
            best_of_anchor_texts = []
            anchor_rewards_list = []
            for group in reshaped_anchor_texts:
                anchor_rewards = get_reward(reward_model, tokenizer, group, device)
                best_idx = torch.argmax(anchor_rewards)
                best_of_anchor_texts.append(group[best_idx])
                anchor_rewards_list.append(anchor_rewards)

            # /* Compute forward KL gradient */
            # 这等价于在Best-of-N样本上进行监督微调
            # 我们希望最大化生成 best_of_anchor_texts 的对数概率
            best_of_anchor_tokens = tokenizer(best_of_anchor_texts, return_tensors='pt', padding=True, truncation=True).to(device)
            
            # 计算 loss_fw (前向KL散度的梯度)
            fw_logits = policy_model(best_of_anchor_tokens['input_ids']).logits
            fw_labels = best_of_anchor_tokens['input_ids'].clone()
            fw_labels[:, :prompt_tokens.shape[1]] = -100 # 掩码掉prompt部分
            loss_fw = F.cross_entropy(fw_logits.view(-1, fw_logits.size(-1)), fw_labels.view(-1))
            
            # --- c. 后向 KL 梯度 (Backward KL Gradient) ---
            # /* Compute r_J-BOND */
            policy_rewards = get_reward(reward_model, tokenizer, policy_texts, device)
            
            r_j_bond = torch.zeros_like(policy_rewards)
            for j in range(len(batch_prompts)):
                min_anchor_reward = torch.min(anchor_rewards_list[j])
                if policy_rewards[j] < min_anchor_reward:
                    r_j_bond[j] = r_j_bond_negative_reward
            
            # /* Compute return R(x, y) */
            log_probs_policy = get_log_probs(policy_model, tokenizer, policy_outputs, prompt_tokens['input_ids'], device)
            with torch.no_grad():
                log_probs_anchor = get_log_probs(anchor_model, tokenizer, policy_outputs, prompt_tokens['input_ids'], device)
            
            kl_to_anchor = log_probs_policy - log_probs_anchor
            R = r_j_bond.to(device) - kl_to_anchor
            
            # /* Compute baseline B and advantages */
            baseline = R.mean()
            advantages = (R - baseline).detach()
            
            # /* Compute backward KL gradient */
            loss_bw = - (log_probs_policy * advantages).mean()

            # --- d. 额外的 KL 正则化 (Additional KL regularization) ---
            # /* KL regularization gradient */
            # 直接使用kl_to_anchor作为正则项
            loss_reg = kl_to_anchor.mean()

            # --- e. 整体策略更新 (Overall policy update) ---
            # /* Jeffreys divergence + KL regularization */
            total_loss = (1 - beta) * loss_fw + beta * loss_bw + gamma * loss_reg
            
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

            # --- f. 更新锚点 (Update moving anchor with EMA) ---
            with torch.no_grad():
                for policy_param, anchor_param in zip(policy_model.parameters(), anchor_model.parameters()):
                    anchor_param.data.mul_(1 - eta).add_(policy_param.data, alpha=eta)
                    
        print(f"Epoch {epoch+1} finished. Final loss: {total_loss.item():.4f}")
        
    print("--- J-BOND Training Finished ---")
    return policy_model, tokenizer

# ----------------------------------------------------------------------------
# Example Usage
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # --- 配置 ---
    # 策略模型:我们将要微调的模型
    POLICY_MODEL = "distilgpt2" 
    # 奖励模型:一个固定的、提供反馈的模型。这里用情感分析模型模拟。
    # 我们的目标是让 distilgpt2 生成更积极的评论。
    REWARD_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"

    # 用于训练的 prompts
    PROMPTS = [
        "The movie was",
        "I think this product is",
        "The customer service I received was",
        "This new AI technology seems",
        "The restaurant's food was",
        "My holiday trip to the island felt",
        "The book I just read was",
        "Overall, my experience was",
        "I went to the concert and it was",
        "The new software update is",
        "The quality of the camera is",
        "The hotel we stayed at was",
    ] * 5 # 乘以5以增加训练数据量

    # --- 运行训练 ---
    trained_policy_model, trained_tokenizer = train_j_bond(
        policy_model_name=POLICY_MODEL,
        reward_model_name=REWARD_MODEL,
        prompts=PROMPTS,
        beta=0.5,      # 论文推荐的Jeffreys散度
        eta=0.02,      # 锚点更新速率
        gamma=0.1,     # KL正则化强度
        epochs=10,
        batch_size=4,
        learning_rate=2e-5,
        max_length=50
    )

    # --- 测试训练后的模型 ---
    print("\n--- Testing the trained J-BOND model ---")
    test_prompt = "The new phone is"
    inputs = trained_tokenizer(test_prompt, return_tensors="pt").to(trained_policy_model.device)
    
    # 生成一些样本并评估其奖励
    with torch.no_grad():
        # 确保模型处于评估模式
        trained_policy_model.eval()
        
        # 检查是否需要设置pad_token_id
        if trained_tokenizer.pad_token_id is None:
            trained_tokenizer.pad_token_id = trained_tokenizer.eos_token_id
            
        # 生成文本
        generated_outputs = trained_policy_model.generate(
            inputs.input_ids,  # 明确指定input_ids而不是使用**inputs
            attention_mask=inputs.attention_mask if hasattr(inputs, 'attention_mask') else None,
            max_length=50, 
            num_return_sequences=5, 
            do_sample=True, 
            top_k=50,
            pad_token_id=trained_tokenizer.pad_token_id,
            temperature=0.7  # 添加temperature参数控制生成的随机性
        )
    generated_texts = trained_tokenizer.batch_decode(generated_outputs, skip_special_tokens=True)
    rewards = get_reward(reward_model, trained_tokenizer, generated_texts, trained_policy_model.device)

    for text, reward in zip(generated_texts, rewards):
        print(f"Generated: '{text}' | Reward (positivity score): {reward.item():.4f}")

参考资料

  1. www.cnblogs.com/xyz/p/13929…
  2. arxiv.org/pdf/2407.14…