深度强化学习中的DPO算法:理论与实践

742 阅读6分钟

深度强化学习中的DPO算法:理论与实践

引言

随着大型语言模型(LLMs)在自然语言处理领域的卓越表现,如何有效地对其进行微调以使其行为更符合人类偏好成为了一个重要的研究方向。传统的监督微调(SFT)虽然有效,但其目标函数往往无法直接捕捉人类的复杂偏好。而强化学习(RL)提供了一种更直接的方式来优化模型行为。本文将深入探讨一种名为直接偏好优化(Direct Preference Optimization, DPO) 的强化学习算法,它在无需求助复杂的奖励模型(RM)的情况下,也能有效地将人类偏好融入到大型语言模型的训练中。我们将从其数学理论出发,并提供详细的算法代码实现。

1. 强化学习基础回顾

在深入DPO之前,我们先简要回顾强化学习的一些基本概念。

强化学习的核心思想是智能体(agent)通过与环境(environment)的交互学习,以最大化累积奖励。

  • 策略(Policy)π(as)\pi(a|s),表示在状态ss下选择动作aa的概率。
  • 奖励函数(Reward Function)R(s,a,s)R(s, a, s'),表示从状态ss采取动作aa到达状态ss'所获得的奖励。
  • 价值函数(Value Function)Vπ(s)V^\pi(s)Qπ(s,a)Q^\pi(s, a),分别表示在策略π\pi下,从状态ss开始的期望累积奖励,或在状态ss采取动作aa后的期望累积奖励。

在RLHF(Reinforcement Learning from Human Feedback)的背景下,奖励函数通常由一个奖励模型(Reward Model, RM)来学习,该模型根据人类对模型输出的偏好数据进行训练。

2. DPO算法的数学理论

2.1 RLHF的挑战

传统的RLHF流程通常分为几个阶段:

  1. 预训练(Pre-training):在大规模文本数据上训练一个基础语言模型。
  2. 监督微调(Supervised Fine-tuning, SFT):在高质量的人类标注数据上对模型进行微调,使其能够生成符合指令的文本。
  3. 奖励模型训练(Reward Model Training, RM):收集人类偏好数据(例如,对模型生成的两个回复进行比较,选择更偏好的那个),训练一个奖励模型来预测人类偏好。
  4. 强化学习(Reinforcement Learning):使用训练好的奖励模型作为奖励函数,通过PPO(Proximal Policy Optimization)等算法对语言模型进行微调,使其生成更高奖励的文本。

这种多阶段的RLHF流程存在一些挑战:

  • 奖励模型训练的复杂性:训练一个高质量的奖励模型本身就是一项复杂的任务,需要大量的偏好数据和仔细的模型设计。
  • 奖励模型的准确性限制:奖励模型并不能完美地捕捉人类偏好,其误差会累积到策略优化中。
  • RL训练的不稳定性:PPO等RL算法通常难以调参,训练过程可能不稳定,且计算成本较高。

2.2 DPO的核心思想

DPO旨在简化RLHF流程,直接优化策略以满足人类偏好,而无需显式地训练一个奖励模型。DPO的核心思想是:如果我们有一个数据集,其中包含人类偏好的成对样本(x,yw,yl)(x, y_w, y_l),其中xx是提示,ywy_w是人类更偏好的响应,yly_l是人类不偏好的响应,那么我们可以直接通过这些偏好数据来优化策略。

DPO的理论基础是:存在一个最优策略,其在人类偏好数据上的表现是最好的。 DPO的目标是找到一个策略πθ\pi_\theta(例如,一个LLM),使得对于给定的提示xx,模型生成ywy_w的概率高于生成yly_l的概率。

2.3 偏好数据的概率模型

假设我们有一个提示xx,以及两个候选响应y1y_1y2y_2。人类选择ywy_w而非yly_l的概率可以通过Bradley-Terry模型来建模:

P(ywylx)=exp(r(x,yw))exp(r(x,yw))+exp(r(x,yl))P(y_w \succ y_l | x) = \frac{\exp(r(x, y_w))}{\exp(r(x, y_w)) + \exp(r(x, y_l))}

其中r(x,y)r(x, y)是奖励函数,表示响应yy在提示xx下的质量。

DPO的关键洞察在于,对于一个给定的策略πθ\pi_\theta,其生成的响应yy的奖励函数r(x,y)r(x, y)可以由策略πθ\pi_\theta与参考策略πref\pi_{\text{ref}}(通常是SFT后的模型)之间的对数概率比来表示:

r(x,y)logπθ(yx)πref(yx)r(x, y) \propto \log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}

这个比例关系是RL中关于策略优化的一个重要结果,特别是与KL散度惩罚相关的理论。具体来说,当我们在RL中加入KL散度惩罚项,即最大化E[rewardβDKL(πθπref)]E[\text{reward} - \beta D_{KL}(\pi_\theta || \pi_{\text{ref}})]时,最优策略πθ\pi_\theta^*的形式为:

πθ(yx)πref(yx)exp(1βr(x,y))\pi_\theta^*(y|x) \propto \pi_{\text{ref}}(y|x) \exp(\frac{1}{\beta} r(x, y))

从这个式子中,我们可以反推得到奖励函数r(x,y)r(x, y)与策略概率之间的关系。

2.4 DPO目标函数推导

r(x,y)r(x, y)的表达形式代入Bradley-Terry模型,我们得到:

P(ywylx)=exp(1βlogπθ(ywx)πref(ywx))exp(1βlogπθ(ywx)πref(ywx))+exp(1βlogπθ(ylx)πref(ylx))P(y_w \succ y_l | x) = \frac{\exp(\frac{1}{\beta} \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)})}{\exp(\frac{1}{\beta} \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)}) + \exp(\frac{1}{\beta} \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)})}

为了最大化这个概率,DPO的目标函数是最大化所有偏好对的对数似然:

LDPO(θ)=E(x,yw,yl)D[logσ(βlogπθ(ywx)πref(ywx)βlogπθ(ylx)πref(ylx))]\mathcal{L}_{\text{DPO}}(\theta) = \mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}} \left[ \log \sigma \left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right]

其中:

  • D\mathcal{D}是偏好数据集。
  • σ()\sigma(\cdot)是Sigmoid函数。
  • β\beta是控制KL散度惩罚项强度的超参数,类似于PPO中的KL系数。它平衡了模型生成高质量响应和保持与参考模型接近之间的关系。

通过最大化这个目标函数,我们直接鼓励模型生成更偏好的响应,同时惩罚生成不偏好的响应,而无需显式地训练一个奖励模型。这使得DPO成为一个直接且稳定的优化方法。

3. DPO算法实现

3.1 数据准备

DPO需要一个包含偏好对的数据集。每个数据点应包含一个提示xx,一个被选中的响应ywy_w,以及一个被拒绝的响应yly_l

from transformers import AutoTokenizer
import torch

class PreferenceDataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item['prompt']
        chosen = item['chosen']
        rejected = item['rejected']

        # 对prompt, chosen, rejected进行编码
        prompt_inputs = self.tokenizer(prompt, truncation=True, max_length=self.max_length, return_tensors="pt")
        chosen_inputs = self.tokenizer(chosen, truncation=True, max_length=self.max_length, return_tensors="pt")
        rejected_inputs = self.tokenizer(rejected, truncation=True, max_length=self.max_length, return_tensors="pt")

        return {
            "prompt_input_ids": prompt_inputs.input_ids.squeeze(0),
            "prompt_attention_mask": prompt_inputs.attention_mask.squeeze(0),
            "chosen_input_ids": chosen_inputs.input_ids.squeeze(0),
            "chosen_attention_mask": chosen_inputs.attention_mask.squeeze(0),
            "rejected_input_ids": rejected_inputs.input_ids.squeeze(0),
            "rejected_attention_mask": rejected_inputs.attention_mask.squeeze(0),
        }

# 示例数据
sample_data = [
    {"prompt": "写一首关于秋天的诗。", "chosen": "秋风习习,落叶飘零,枫叶如火,染红山林。", "rejected": "夏天到了,天气很热。"},
    {"prompt": "介绍一下DPO算法。", "chosen": "DPO是一种直接偏好优化算法,无需奖励模型。", "rejected": "PPO是一种强化学习算法。"},
]

# 初始化tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") # 假设使用BERT tokenizer,实际LLM需要对应tokenizer

# 创建数据集
# dataset = PreferenceDataset(sample_data, tokenizer)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=2)

3.2 DPO 模型的构建

DPO 算法需要两个模型:

  • 策略模型(Policy Model):这是我们要优化的 LLM,它将生成响应。
  • 参考模型(Reference Model):通常是经过 SFT 后的模型,其参数在 DPO 训练过程中保持不变。它用于计算 KL 散度惩罚。

我们将使用 transformers 库中的 AutoModelForCausalLM 来构建这两个模型。

from transformers import AutoModelForCausalLM, AdamW
import torch.nn.functional as F

class DPOTrainer:
    def __init__(self,
                 policy_model_path: str,
                 ref_model_path: str,
                 tokenizer_path: str,
                 beta: float = 0.1,
                 learning_rate: float = 5e-5,
                 device: str = "cuda" if torch.cuda.is_available() else "cpu"):
        
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
        self.policy_model = AutoModelForCausalLM.from_pretrained(policy_model_path).to(device)
        # 参考模型通常是policy_model的一个副本,且不进行梯度更新
        self.ref_model = AutoModelForCausalLM.from_pretrained(ref_model_path).to(device)
        self.ref_model.eval() # 设置为评估模式,不进行梯度更新
        
        self.beta = beta
        self.learning_rate = learning_rate
        self.device = device

        self.optimizer = AdamW(self.policy_model.parameters(), lr=self.learning_rate)

    def compute_dpo_loss(self,
                         prompt_input_ids: torch.Tensor,
                         prompt_attention_mask: torch.Tensor,
                         chosen_input_ids: torch.Tensor,
                         chosen_attention_mask: torch.Tensor,
                         rejected_input_ids: torch.Tensor,
                         rejected_attention_mask: torch.Tensor):
        
        # 计算chosen响应的对数概率
        # 注意:这里需要计算整个序列的对数概率,而不仅仅是最后一个token
        # 我们需要 masks 来正确地处理填充,并只计算生成部分的概率
        
        # 将prompt和chosen拼接起来作为policy model的输入
        chosen_full_input_ids = torch.cat((prompt_input_ids, chosen_input_ids), dim=-1)
        chosen_full_attention_mask = torch.cat((prompt_attention_mask, chosen_attention_mask), dim=-1)

        # 计算policy model对chosen responses的logits
        # outputs.logits 形状 (batch_size, sequence_length, vocab_size)
        policy_chosen_logits = self.policy_model(
            input_ids=chosen_full_input_ids,
            attention_mask=chosen_full_attention_mask
        ).logits

        # 计算reference model对chosen responses的logits
        ref_chosen_logits = self.ref_model(
            input_ids=chosen_full_input_ids,
            attention_mask=chosen_full_attention_mask
        ).logits

        # 计算chosen log probabilities
        # log_softmax 应用在最后一个维度 (vocab_size)
        policy_chosen_log_probs = F.log_softmax(policy_chosen_logits, dim=-1)
        ref_chosen_log_probs = F.log_softmax(ref_chosen_logits, dim=-1)

        # 接下来,我们只关注生成部分的概率,即chosen responses的部分
        # 我们需要根据prompt的长度来截取
        prompt_len = prompt_input_ids.shape[1]
        
        # 计算生成部分的log_probs
        # chosen_log_probs_policy: (batch_size, chosen_seq_len)
        chosen_log_probs_policy = torch.gather(policy_chosen_log_probs[:, prompt_len-1:-1], 2, chosen_input_ids.unsqueeze(-1)).squeeze(-1)
        chosen_log_probs_ref = torch.gather(ref_chosen_log_probs[:, prompt_len-1:-1], 2, chosen_input_ids.unsqueeze(-1)).squeeze(-1)

        # 计算每个token的mask,排除padding和prompt部分
        chosen_mask = (chosen_input_ids != self.tokenizer.pad_token_id).float()
        
        # 乘以mask并求和,得到整个序列的log_prob
        chosen_log_probs_policy = (chosen_log_probs_policy * chosen_mask).sum(dim=-1)
        chosen_log_probs_ref = (chosen_log_probs_ref * chosen_mask).sum(dim=-1)

        # 对rejected响应做同样的操作
        rejected_full_input_ids = torch.cat((prompt_input_ids, rejected_input_ids), dim=-1)
        rejected_full_attention_mask = torch.cat((prompt_attention_mask, rejected_attention_mask), dim=-1)

        policy_rejected_logits = self.policy_model(
            input_ids=rejected_full_input_ids,
            attention_mask=rejected_full_attention_mask
        ).logits
        ref_rejected_logits = self.ref_model(
            input_ids=rejected_full_input_ids,
            attention_mask=rejected_full_attention_mask
        ).logits

        policy_rejected_log_probs = F.log_softmax(policy_rejected_logits, dim=-1)
        ref_rejected_log_probs = F.log_softmax(ref_rejected_logits, dim=-1)

        rejected_log_probs_policy = torch.gather(policy_rejected_log_probs[:, prompt_len-1:-1], 2, rejected_input_ids.unsqueeze(-1)).squeeze(-1)
        rejected_log_probs_ref = torch.gather(ref_rejected_log_probs[:, prompt_len-1:-1], 2, rejected_input_ids.unsqueeze(-1)).squeeze(-1)

        rejected_mask = (rejected_input_ids != self.tokenizer.pad_token_id).float()
        
        rejected_log_probs_policy = (rejected_log_probs_policy * rejected_mask).sum(dim=-1)
        rejected_log_probs_ref = (rejected_log_probs_ref * rejected_mask).sum(dim=-1)

        # 计算ratio
        pi_log_ratio = chosen_log_probs_policy - rejected_log_probs_policy
        ref_log_ratio = chosen_log_probs_ref - rejected_log_probs_ref

        # 计算DPO损失
        # 这里的log_prob_diff是论文中的 r_theta(x, y_w) - r_theta(x, y_l)
        # 对应 beta * (log(pi_theta(y_w|x)/pi_ref(y_w|x)) - log(pi_theta(y_l|x)/pi_ref(y_l|x)))
        # 也就是 beta * (chosen_log_probs_policy - chosen_log_probs_ref - (rejected_log_probs_policy - rejected_log_probs_ref))
        # 简化为 beta * (pi_log_ratio - ref_log_ratio)
        
        logits = self.beta * (pi_log_ratio - ref_log_ratio)
        loss = -F.logsigmoid(logits).mean() # 对数似然

        return loss

    def train(self, dataloader, num_epochs: int = 1):
        self.policy_model.train()
        for epoch in range(num_epochs):
            total_loss = 0
            for batch in dataloader:
                # 移动数据到设备
                prompt_input_ids = batch["prompt_input_ids"].to(self.device)
                prompt_attention_mask = batch["prompt_attention_mask"].to(self.device)
                chosen_input_ids = batch["chosen_input_ids"].to(self.device)
                chosen_attention_mask = batch["chosen_attention_mask"].to(self.device)
                rejected_input_ids = batch["rejected_input_ids"].to(self.device)
                rejected_attention_mask = batch["rejected_attention_mask"].to(self.device)

                self.optimizer.zero_grad()
                loss = self.compute_dpo_loss(
                    prompt_input_ids, prompt_attention_mask,
                    chosen_input_ids, chosen_attention_mask,
                    rejected_input_ids, rejected_attention_mask
                )
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {epoch+1}, Average Loss: {total_loss / len(dataloader):.4f}")

    def save_model(self, path: str):
        self.policy_model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        print(f"Policy model saved to {path}")

3.3 训练过程解释

  • 数据加载PreferenceDataset 负责加载和编码偏好数据。它将每个提示、选定响应和拒绝响应转化为模型可以处理的 token ID。
  • 模型初始化DPOTrainer 初始化两个模型:
    • policy_model 是我们要训练的模型。
    • ref_model 是参考模型(通常是经过 SFT 的模型,其参数在 DPO 训练期间被冻结)。
  • DPO 损失计算compute_dpo_loss 是 DPO 算法的核心。
    • 它首先分别计算 policy_modelref_modelchosenrejected 响应的对数概率。
    • 关键在于,我们只关注生成部分的概率,即 chosenrejected 相对于 prompt 部分的概率。
    • 然后,它计算 chosenrejected 响应的对数概率比,以及这些比率的差值。
    • 最后,将这个差值乘以 beta,并通过 Sigmoid 函数,计算负对数似然作为损失函数。
  • 优化器:使用 AdamW 优化器更新 policy_model 的参数。
  • 训练循环:在每个 epoch 中,遍历数据加载器,计算损失,执行反向传播和梯度下降。