LLM中的强化学习:PPO、DPO、KTO等

2,594 阅读20分钟

为什么要用强化学习训练

注重优化长期回报,而非即时准确性。 例如,在对话系统中,这意味着对模型的训练不仅仅是为了提供正确的下一个单词预测,而是为了进行连贯且与上下文相适应的对话,从而在整个会话过程中最大限度地提高用户满意度。

另外,有研究表明,使用RLHF训练,比SFT训练,效果好很多。

强化学习系列算法

zhuanlan.zhihu.com/p/627747410 北大王选研究所,介绍的比较简洁清楚

0. 概述

强化学习是一种利用反馈来学习策略的范式。具体而言,如下图所示,强化学习的模型(Agent)与环境(Environment)交互,对于每个给定状态st采取动作at并从环境获得奖励rt,同时进入下一状态s[t+1],这一过程循环往复。在积累了这一系列交互经验后,模型通过调整自己的策略以让交互过程得到的奖励最大化。这样一来Agent就学习到了在给定状态下采取有益的动作的策略,实现了强化学习的目标。

  1. Policy Gradient RL,策略梯度,强化学习的基础算法

如何调整策略以最大化奖励呢?换言之,如何设计一个可学习的优化目标,使奖励最大化呢?

具体而言,我们首先与环境进行一系列交互,从初始状态s1开始到结束状态s_final,模型依次做出了动作a1~an并分别获得了奖励r1~rn,在每步做决策时,模型都会给出概率分布π(at|st)。这一从开始到结束的交互过程我们称之为一条轨迹,将这条轨迹的所有奖励求和即可得到轨迹的总奖励****R(π)

不难发现,上图中的公式就衡量了给定策略下采样若干轨迹所能得到的期望奖励,利用梯度上升优化这一目标即可使得策略往“奖励更大”的方向优化,即得到更优的策略。

这一直接的思路被称为策略梯度,是RL的基础方法之一。

  1. Q-learning 与 Actor-Critic

策略梯度方法虽然直观,但在实践中往往难以取得效果,这是因为每条轨迹的奖励本身具有较大的方差,可能导致训练难以收敛

具体而言,如果有些较大价值的轨迹没有被采样到,根据现有优化目标,模型可能反而会提升一些价值较小的轨迹的策略概率。

...

我们希望直接估计某条轨迹的“价值”,亦即它的总奖励的期望,这便是接下来即将介绍的Actor-Critic的思想。

我们希望估计 R(τ)-b 的期望。从另一个角度理解,这个值也可以被视为 st 上采取动作 at 后,未来的期望收益能带来多大的提升。 我们将这个值的期望写作**A(st,at)** (未来期望收益的提升) ,这个A是优势函数(Advantage)的缩写。那么,这个优势函数又该如何计算呢?

先简单介绍一下Q-learning相关的概念。现在我们 V(s) 代表采取某种策略下,状态 s 后所能获得的期望收益。如果我们能准确估计V(s),那优势函数即可被表示为 A(st,at) = rt+V(s[t+1])-V(st)

最后,我们只需要考虑如何估计V函数即可。值得注意的是,对V的估计本身就是强化学习中另一主流分支Q-learning的研究课题,在这一领域内有大量的研究( DQN 、Double-DQN等) ,在这里我们只介绍最基础、最直觉的V估计方法。

如下图所示,顾名思义,既然V(s)衡量状态s以后的期望收益,我们可以直接采样轨迹,平均所有采样到的s后收益来衡量V,这一方法被称为Monte-Carlo(MC)

完成了对V的估计,我们便可以利用优势函数训练策略网络了。在训练中我们会同时训练两个模型,一个是策略网络(Actor),一个是估计V的估值网络(Critic),因此这类方法也被称为Actor-Critic(AC)。

延伸:Actor-critic 和 GAN的异同

本回答由chat-gpt生成

强化学习中的actor-critic方法和生成对抗网络(GAN)在设计思想上有相似之处,但也存在本质的不同。下面将详述这两种方法的共同点和区别。

共同点:

  1. 两部分组成:actor-critic和GAN都由两部分组成。在actor-critic中,"actor"代表采取行动的策略,而"critic"评估这些行动的好坏;在GAN中,"generator"生成数据,而"discriminator"评估数据的真实性。
  2. 对抗性思想:这两种方法都包含了一种对抗性的思想。Actor-critic中的"critic"评估"actor"的行动并提供反馈,这和GAN中"discriminator"试图识别"generator"生成的假数据有着相似的过程。
  3. 不断的迭代学习:actor-critic方法和GAN都依靠不断的迭代学习过程。在actor-critic中,"actor"和"critic"相互学习和适应;在GAN中,"generator"和"discriminator"互相对抗以提升性能。

区别:

  1. 目标不同:actor-critic旨在解决决策问题,通过学习策略来最大化长期奖励;而 GAN 旨在生成逼真的数据样本,学习真实数据的分布。
  2. 学习信号的来源:actor-critic中,"critic"基于环境反馈学习如何评估"actor"的行动;GAN中的"discriminator"则是通过区分真实数据与生成数据来训练自身。
  3. 训练过程:尽管两者都涉及迭代学习,但它们的迭代机制不同。Actor-critic模型通常同时更新"actor"和"critic";而在 GAN 中,"generator"和"discriminator"通常是交替更新。
  4. 环境交互:actor-critic需要与环境互动,而环境对"actor"采取的行动提供了实时的奖励;GAN并不直接与外部环境互动,它学习的是内部"generator"和"discriminator"之间的博弈。
  5. 训练稳定性:GAN著名的挑战之一是训练不稳定性,这需要精心的网络设计和训练过程调整。相比之下,actor-critic方法虽然也面临决策方面的挑战,但通常情况下训练过程更加稳定。
  6. 输出类型:"actor"的输出是行动或者行动的概率,旨在与环境交互;"generator"的输出是新的数据样本,旨在模仿真实数据的特征。

总的来说,actor-critic方法和GAN在概念上都包含了两个网络或模型,分别负责生成输出和评估这些输出。尽管它们在以这种方式训练网络方面有着相似的思想,但actor-critic用于强化学习中的决策任务,而GAN用于生成数据任务,两者的目标、激励机制和应用场景有着本质的差异。

  1. Proximal Policy Optimization(PPO)

训练AC时需要与环境交互来采样很多轨迹,然后利用这些轨迹训练Actor和Critic;然而,这一过程是十分费时的,这可能导致我们无法高效的采集大量数据,进而充分的训练模型。因此,我们考虑是否能将已有的轨迹数据复用以提高训练效率。

这一思路将我们指向了off-policy RL的道路。具体而言,我们希望有两个策略网络π1和π2,其中π1不断与环境交互收集数据,这些数据可以重复使用以训练π2的参数。

有了这些铺垫,我们终于得到了一个可以高效训练的RL算法:Proximal Policy Optimization(PPO),近期获得很大关注的InstructGPT、ChatGPT便在底层使用了PPO进行强化学习。PPO是一种对上述Off-policy RL目标的实现,分析其优化目标不难发现,它首先最大化原始优化目标A*π2/π1,其次又防止π2/π1偏离1太多,即控制了两个分布的差距。

PPO在RLHF中的用法

PPO在RL中的地位,相当于BERT和 GPT 在NLP中的地位

  1. 收集人类反馈,人工标注数据

    1. 以summary任务为例,随机从数据集中抽取问题,对于每个问题,生成多个不同的回答
    2. 人工标注,判断哪个回答更符合人类期望,给出排名
  2. 训练奖励模型(reward model, RM)

    1. 对多个排序结果,两两组合,形成多个训练数据对
    2. 奖励模型接受一对输入输出数据,给出评价:回答质量分数 (标量奖励,数值上表示人的偏好)
    3. 调节参数使得高质量回答的打分比低质量的打分要高。
  3. 采用PPO强化学习,优化策略(Proximal Policy Optimization,近端策略优化)

    1. 从数据集中抽取问题,使用 PPO 模型(包括ref model、actor model)生成回答(即不需要人工标注)并利用第二阶段训练好的 奖励模型 打分
    2. 把奖励分数依次传递,由此产生策略梯度,通过强化学习的方式更新 PPO 模型参数,训练目标是使得生成的文本要在 奖励模型 上获得尽可能高的得分。

PPO模型训练的细节:

奖励模型RM的训练
  1. RM的训练:我们只需人工标注一些偏好数据(例如对于一个输入,我们让模型给出若干输出,并由标注人员对这些输出的好坏程度进行排序),并通过对比学习让RM最大化好输出与坏输出的分数差

    1. pairwise ranking loss: log(σ((x,yw)−(x,yl)))
    2. RM 模型的目标是使得排序高的答案yw对应的标量分数要高于排序低的答案yl对应的标量分数,且越高越好,也就是使得损失函数中的rθ(x,yw)−rθ(x,yl)这个差值越大越好。将相减后的分数通过 sigmoid 函数,差值变成 - 1 到 1 之间,由于 sigmoid 函数是单调递增的函数,因此σ(rθ(x,yw)−rθ(x,yl))越大越好。σ(rθ(x,yw)−rθ(x,yl))越接近 1,表示ywyl排序高,属于 1 这个分类,反正属于 - 1 这个分类,所以这里也可以看成是一个二分类问题。
    3. 奖励模型 中每个问题对应的答案数量即**K**值为什么选 9 更合适,而不是选择 4 呢?
    4. 进行标注的时候,需要花很多时间去理解问题,但答案之间比较相近,假设 4 个答案进行排序要 30 秒时间,那么 9 个答案排序可能就 40 秒就够了。9 个答案与 4 个答案相比生成的问答对多了 5 倍,从效率上来看非常划算;
    5. K=9时,每次计算 loss 都有 36 项rθ(x,y)需要计算,RM 模型的计算所花时间较多,但可以通过重复利用之前算过的值(也就是只需要计算 9 次即可),能节约很多时间。
    6. 奖励模型 的损失函数为什么会比较答案的排序,而不是去对每一个答案的具体分数做一个回归?
    7. 每个人对问题的答案评分都不一样,无法使用一个统一的数值对每个答案进行打分。如果采用对答案具体得分回归的方式来训练模型,会造成很大的误差。但是,每个人对答案的好坏排序是基本一致的。通过排序的方式避免了人为的误差。

生成模型训练
  1. 生成模型的训练:我们可以将“输入-生成模型输出- RM 反馈”作为一个只有一步的轨迹(输入是s1,输出是a1,RM的反馈是奖励),并在这些轨迹上利用 PPO 进行强化学习。如下图所示,我们只需最大化PPO的优化目标即可实现对生成模型的训练。

    1. 训练过程中,policy model 会不断更新,为了不让它偏离SFT阶段的模型太远,OpenAI在训练过程中增加了KL离散度约束,保证模型在得到更好的结果同时不会跑偏这是因为Comparison Data不是一个很大的数据集,不会包含全部的回答,对于任何给定的提示,都有许多可能的回答,其中绝大多数是 RM 以前从未见过的。对于许多未知(提示、响应)对,RM 可能会错误地给出极高或极低的分数。如果没有这个约束,模型可能会偏向那些得分极高的回答,它们可能不是好的回答。

    2. RL 模型的优化目标是使得RL模型生成的文本在 奖励模型 中的得分越高越好,损失函数可以分为三个部分,打分部分、KL 散度部分以及预训练部分。

      1. 打分部分: 将 RL 模型的问题数据集x,通过π_ϕRL模型得到答案y,然后再把这对(x,y)代入 RW 模型进行打分,即损失函数公式中的rθ(x,y)该分数越高,代表模型生成的答案越好。
      2. KL 散度部分:在每次更新参数后,π_ϕRL会发生变化,x通过π_ϕRL生成的y也会发生变化,而rθ(x,y)奖励模型是根据π_SFT模型的数据训练而来。如果**π_ϕRLπSFT差的太多,则会导致rθ(x,y)** 的分数估算不准确。因此需要通过 KL 散度来计算,π_ϕRL生成的答案分布和πSFT生成的答案分布之间的距离,使得两个模型之间不要差的太远。损失函数公式中的log(π_ϕRL(y∣x)/πSFT(y∣x))就是在计算 KL 散度。由于 KL 散度是越小越好,而训练目标是损失函数越大越好,因此在前面需要加上一个负号。
      3. 预训练部分:预训练部分对应损失函数中的Ex∼Dpretrain[log(πϕRL(x))]。如果没有该项,那么模型最终可能只对这一个任务能够做好,在别的任务上会发生性能下降。因此,需要将预训练阶段的目标函数加上,使得前面两个部分在新的数据集上做拟合的同时保证原始的数据也不会丢弃。

  1. DPO

DPO 的核心思想

  • 跳过了奖励建模步骤,直接使用偏好数据优化语言模型;
  • 解决三个阶段的训练(SFT->RM->PPO)过程较长,更新迭代较慢的问题
  • juejin.cn/post/730299…

DPO (Direct Preference Optimization) 提出了一种使用二进制交叉熵目标来精确优化 LLM 的方法,以替代基于 RLHF 的优化目标,从而大大简化偏好学习 pipeline。也就是说,完全可以直接优化语言模型以实现人类的偏好,而不需要明确的 奖励模型 或强化学习。

与现有的算法一样,DPO 也依赖于理论上的偏好模型(如 Bradley-Terry 模型),以此衡量给定的奖励函数与经验偏好数据的吻合程度。然而,现有的方法使用偏好模型定义偏好损失来训练奖励模型,然后训练优化所学奖励模型的策略,而 DPO 使用变量的变化来直接定义偏好损失作为策略的一个函数。 鉴于人类对模型响应的偏好数据集,DPO 因此可以使用一个简单的二进制交叉熵目标来优化策略,而不需要明确地学习奖励函数或在训练期间从策略中采样。

DPO 的Loss 设计

  1. KTO: Model Alignment as Prospect Theoretic Optimization

zhuanlan.zhihu.com/p/693163438

DPO依赖特殊的训练数据:

- 问题
  - 期望回答
  - 拒绝回答

KTO避免了这个问题:只要告诉我**这个回答是不是所期望的**就行。

具体做法也很简单,把DPO的loss拆分成正负两部分:

  • 如果只有正样本,那就只计算正样本的loss

  • 如果只有负样本,那就只计算负样本的loss

  • 如果正负样本都有(像DPO那种数据),那就都计算

r是DPO、RLHF中的loss,z是KL散度

  1. Direct Preference Optimization with an Offset

出发点是并非所有的偏好对都是相等程度的:在某些情况下,首选响应只比不受欢迎的响应稍微好一点,而在另一些情况下,对一个响应的偏好可能更强烈,例如,当另一个响应包含有害或有毒内容时。

本文提出了DPO的一种泛化形式,称为带偏移的DPO(ODPO), 微调 过程中不会平等对待每个偏好对。 直观地说,ODPO要求首选响应和不受欢迎响应之间的可能性差异大于一个偏移值。偏移值的确定基于一个响应相对于另一个响应的偏好程度。

实现也很简单,类似于triplet loss,加了一个offset

  • 方法一,手动指定绝对值
  • 方法二,指定比例
# https://github.com/rycolab/odpo/blob/main/trainers.py
def preference_loss(policy_chosen_logps: torch.FloatTensor,
                    policy_rejected_logps: torch.FloatTensor,
                    reference_chosen_logps: torch.FloatTensor,
                    reference_rejected_logps: torch.FloatTensor,
                    beta: float,
                    label_smoothing: float = 0.0,
                    ipo: bool = False,
                    reference_free: bool = False,
                    offset: bool = False,
                    ratio: bool = False,
                    alpha: float = 1.,
                    chosen_rewards: torch.FloatTensor = None,
                    rejected_rewards: torch.FloatTensor = None) -> Tuple[
    torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the DPO loss for a batch of policy and reference model log probabilities.

    Args:
        policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
        policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
        reference_chosen_logps: Log probabilities of the reference model for the chosen responses. Shape: (batch_size,)
        reference_rejected_logps: Log probabilities of the reference model for the rejected responses. Shape: (batch_size,)
        beta: Temperature parameter for the DPO loss, typically something in the range of 0.1 to 0.5. We ignore the reference model as beta -> 0.
        label_smoothing: conservativeness for DPO loss, which assumes that preferences are noisy (flipped with probability label_smoothing)
        ipo: If True, use the IPO loss instead of the DPO loss.
        reference_free: If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.

    Returns:
        A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
        The losses tensor contains the DPO loss for each example in the batch.
        The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
    """
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = reference_chosen_logps - reference_rejected_logps

    if reference_free:
        ref_logratios = 0

    logits = pi_logratios - ref_logratios  # also known as h_{\pi_\theta}^{y_w,y_l}

    if ipo:
        losses = (logits - 1 / (
                    2 * beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
    elif offset:
 if ratio:
logging.warning( "using ratio" )
margin = torch.stack(chosen_rewards, dim= 0 ) / torch.stack(rejected_rewards, dim= 0 )
 else :
margin = torch.stack(chosen_rewards, dim= 0 ) - torch.stack(rejected_rewards, dim= 0 )
margin = torch.log(margin.to(logits.device))

losses = -F.logsigmoid(beta * logits - alpha * margin)
    else:
        # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
        losses = -F.logsigmoid(beta * logits) * (1 - label_smoothing) - F.logsigmoid(
            -beta * logits) * label_smoothing

    chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

    return losses, chosen_rewards, rejected_rewards

7. ### ORPO

mp.weixin.qq.com/s/kp4LQMygO… (解释的很清晰,从PPO到DPO到ORPO)

ORPO insight

ORPO是一种创新的方法,旨在通过单步过程直接优化语言模型,而不需要独立的奖励模型或参考模型。这种方法通过修改传统的有监督微调(SFT)阶段的损失函数实现,以便更直接地对偏好进行学习和优化。

  • key sights:SFT可以有效提升正确样本的生成概率,但是SFT之后正样本和负样本的生成概率有同时上升的现象,从而导致SFT之后还需要一个进一步的纠正。
  • 该现象的原因:SFT交叉损失熵: L=−m1k=1∑mi=1∑∣Vyi(k)⋅log(pi(k))

  • 其中 yi 是一个布尔值,表示词汇表V中的第i个标记是否为标签标记, pi 表示第 i 个标记出现的概率,m表示序列的长度。单独使用交叉熵不会对non-answer令牌的对数进行直接惩罚 ,因为 yi 将被设置为0,没有机制来惩罚被拒绝的响应,因此被拒绝的响应中令牌的对数概率随着选择的响应而增加,这从偏好对齐的角度来看是不希望的

SFT loss对于rejected data没有惩罚项, 所以SFT阶段不仅使得chosen data的Log Probability增加了,同时也增加了rejected data的Log Probability。通过实验观察,验证了这个假设。可以看到只用正例做SFT的同时观察负例的生成概率,会得到结论两者是同时上升的。通过ORPO的loss改进,这个问题得到了解决。

modelDPOORPO
pipelineSFT + Policy learningPolicy learning
公式DPO是RLHF的一个无损推理加入惩罚项来加速模型更新
计算负荷在两个模型上发生4次前向传递无需参考模型-内存友好;2次传递
实践模型:pythia 2.8B硬件:4张80GB A100数据集:Anthropic-HH train-161k test-8.55kSFT:1小时30分钟DPO:2 小时 45 分钟模型:Gemma 2B硬件:i9-13900HX/32GB,GPU 4090/16GB数据集:argilla/dpo-mix-7k运行了3个多小时
缺点相比之下硬件资源、时间花费更多训练使用的模型较小;对比的模型不够广泛
  1. simPO

相比DPO,simPO不需要reference model,并且有更好的效果。simPO的另一个好处是,能够保持生成结果在较短长度下的质量。

DPO的缺陷

理论上,DPO的优化目标和RLHF是一致的,但是DPO有两个缺陷:

  • 仍然需要一个reference model,这样依然有比较大的内存和计算开销
  • 训练过程中优化的reward和推理时的生成指标存在差异,也就是训练和推理的目标不完全对齐

第二点怎么理解呢?模型在自回归生成response时,理论上是寻找最大化所有 token 平均 log likelihood的组合,即

当然实际上这个组合空间太大了,没法直接遍历寻找,因此会使用一些解码策略来寻找局部最优解,比如greedy decoding、beam search或者top-k sampling等,不过我们还是可以按这个公式近似计算。另外这个公式还是可用在多个response/多选题的排序上的。

可以看到推理时的这个目标和DPO的reward差了个referenc model。那么在DPO里,满足

的偏好数据并不一定意味着

SimPO的实现

从上面这个分析,我们自然就想到要把训练的目标往推理目标上靠拢对齐。那么最直接的做法,就是把reward从

r(x,y)=βlogπθ(yx)πref(yx)\begin{aligned}r^*(x,y)=\beta\log\frac{\pi_\theta(y\mid x)}{\pi_\text{ref}(y\mid x)}\end{aligned}(这里省略了配分函数Z)

变成$$\begin{aligned}r_{\text{SimPO}}(x,y)=\frac{\beta}{|y|}\log\pi_\theta(y\mid x)=\frac{\beta}{|y|}\sum_{i=1}^{|y|}\log\pi_\theta(y_i\mid x,y_{

注意这里有个长度归一化项,这个很重要,没有这一项的话,模型会倾向于生成长度更长但是低质量的内容。

除了修改reward的计算,simPO和IPO、ODPO一样,引入了一个reward margin,这是一个固定的超参,要求winning response和losing response的reward差值要大于reward margin

p(ywylx)=σ(r(x,yw)r(x,yl)γ)p(y_w\succ y_l\mid x)=\sigma\left(r(x,y_w)-r(x,y_l)-\gamma\right)

按已有的经验,增大这个margin有助于提高模型泛化能力,但是太大的margin也会导致模型的退化。

至此我们得到了simPO的损失函数

LSimPO(πθ)=E(x,yw,yl)D[logσ(βywlogπθ(ywx)βyllogπθ(ylx)γ)]\mathcal{L}_{\text{SimPO}}(\pi_\theta)=-\mathbb{E}_{(x,y_w,y_l)\thicksim\mathcal{D}}\left[\log\sigma\left(\frac{\beta}{|y_w|}\log\pi_\theta(y_w|x)-\frac{\beta}{|y_l|}\log\pi_\theta(y_l|x)-\gamma\right)\right]

各种RL loss 对比

参考