DeepSpeed-Chat RLHF 阶段代码解读(0) —— 原始 PPO 代码解读

440 阅读4分钟

为了理解 DeepSpeed-Chat RLHF 的 RLHF 全部过程,这个系列会分三篇文章分别介绍: 原始 PPO 代码解读RLHF 奖励函数代码解读RLHF PPO 代码解读 这是系列的第一篇文章,我们来一步一步的看 PPO 算法的代码实现,对于 PPO 算法原理不太了解的同学,可以参考之前的文章:

  1. 深度强化学习(DRL)算法 2 —— PPO 之 Clipped Surrogate Objective 篇 - 掘金 (juejin.cn)
  2. 深度强化学习(DRL)算法 2 —— PPO 之 GAE 篇 - 掘金 (juejin.cn)

Clipped Surrogate 函数实现

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/ppo.py
for start in range(0, args.batch_size, args.minibatch_size):
    end = start + args.minibatch_size
    mb_inds = b_inds[start:end]

    _, newlogprob, entropy, newvalue = agent.get_action_and_value(b_obs[mb_inds], b_actions.long()[mb_inds])
    logratio = newlogprob - b_logprobs[mb_inds]
    ratio = logratio.exp()

    mb_advantages = b_advantages[mb_inds]
    if args.norm_adv:
        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

    # Policy loss
    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * torch.clamp(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
    pg_loss = torch.max(pg_loss1, pg_loss2).mean()

Clipped Surrogate 函数的实现很简单,这里不再赘述,理解算法原理,代码自然而然就可以看懂,核心是 get_action_and_value 函数的理解。

def get_action_and_value(self, x, action=None):
    logits = self.actor(x)
    # probs 相当于计算 softmax
    probs = Categorical(logits=logits)
    if action is None:
        action = probs.sample()
    # probs.log_prob(action) 计算的是 p(a|s) 的 log 形式,方便计算 Clipped Surrogate 函数里的 ratio
    return action, probs.log_prob(action), probs.entropy(), self.critic(x) 

GAE 实现

直接来看 gae 可能比较抽象,我们先来看蒙特卡洛方法实现的优势估计,对蒙特卡洛方法不熟悉的同学,可以参考之前的文章。 深度强化学习(DRL)算法 附录 3 —— 蒙特卡洛方法(MC)和时序差分(TD) - 掘金 (juejin.cn) 两种方法都采用了反向迭代(因为反向迭代更好计算)的方式来实现优势估计。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

returns = torch.zeros_like(rewards).to(device)
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        next_return = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        next_return = returns[t+1]
    returns[t] = rewards[t] + args.gamma * nextnonterminal * next_return
advantages = returns - values

上面的代码做了什么事情呢,last_value 对应最后的 step(对应 step t) 产生的期望回报,如果 step t-1 整个流程没有结束,那么 t-1 时刻的期望回报就是 reward(t-1) + args.gamma * nextnonterminal * next_return,这样一步一步往后推,就可以计算每一个 step 的期望回报,从而得到每一步的优势,还没理解的话,看下面每个时间步的拆解。关于 last_value 的使用,这里由于没有后续的回报可以累积,所以直接使用 last_value 作为最后一个时间步的回报。关于下面为啥用 return[t-1] 替换原始公式的 value[t-1],这样计算的话就相当于蒙特卡洛方法的优势估计,如果next_return = returns[t+1] 改成 next_value = values[t+1] 就相当于 TD(1) 的优势估计。

# t
return(t) = v(t)
# t - 1
return(t-1) = reward(t-1) + gamma * return(t) = reward(t-1) + gamma * return(t)
# t - 2
return(t-2) = reward(t-2) + gamma * return(t-1) = reward(t-2) + gamma * (reward(t-1) + gamma * return(t))
......
# 我们可以看到一步一步往前推,最后就得到蒙特卡洛方法的优势估计

理解了上面讲的蒙特卡洛方法实现的优势估计,再来看 gae 的实现,我们可以看到代码实现上十分的相似,只是多了 delta 的计算,这里的 delta 对应的就是之前 PPO GAE 篇里介绍的 delta。

# code from cleanrl: https://github.com/vwxyzjn/cleanrl/commit/b7088a41e5e6f0f5f6940fd29054a35118083b28
last_value = agent.get_value(next_obs.to(device)).reshape(1, -1)

advantages = torch.zeros_like(rewards).to(device)
lastgaelam = 0
for t in reversed(range(args.num_steps)):
    if t == args.num_steps - 1:
        nextnonterminal = 1.0 - next_done
        nextvalues = last_value
    else:
        nextnonterminal = 1.0 - dones[t+1]
        nextvalues = values[t+1]
    delta = rewards[t] + args.gamma * nextvalues * nextnonterminal - values[t]
    advantages[t] = lastgaelam = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam
returns = advantages + values

这里通过反向迭代的方式计算 GAE advantage,可能理解上比较抽象,举个例子,就很好理解了:

# advantage(t)
adv[t] = lastgaelam = rewards[t] + gamma * values[t+1] - values[t]
# t-1
adv[t-1] = lastgaelam = rewards[t-1] + gamma * values[t] - values[t-1] + gamma * lambda * lastgaelam
# t-2
adv[t-2] = lastgaelam = rewards[t-2] + gamma * values[t-1] - values[t-2] + gamma * lambda * lastgaelam
...

可以看到,逐项展开,每一时刻的 GAE Advantage 和 PPO GAE 篇里介绍的公式是一模一样的,这里 GAE 就是一种数学公式表达,核心思想是 n step 的优势估计的加权平均,通过数学技巧恰好是上面的形式。

参考

  1. The 37 Implementation Details of Proximal Policy Optimization · The ICLR Blog Track (iclr-blog-track.github.io)
  2. HIGH-DIMENSIONAL CONTINUOUS CONTROL USING GENERALIZED ADVANTAGE ESTIMATION