生成模型(四): InstructGPT

217 阅读2分钟

前面我们已经陆陆续续介绍了 gpt3、in-context learning等内容,但提到生成模型,不可不提目前火爆程度不亚于《狂飙》的 ChatGPT,而本文将介绍 ChatGPT 所基于的模型: InstructGPT。

介绍

虽然 GPT3/GPT3.5 模型参数和训练数据规模不断提升,但是我们发现这些大模型并不会因为参数的增加,而更好地读懂大家的心意, 例如目前 GPT3 依然可能会生成一些不置信、令人感觉不舒服的答案,因此,openAI 希望利用用户的反馈信息使得模型能生成令人满意的答案。训练数据来自用户人工标注的 prompt,并通过 open-ai接口搜集到这些labele data,做监finetuning,使得1.3B 的模型表现比 175B的 GPT要更好。结果显示,也许未来利用用户的反馈信息finetuning才是对结果保真的方向。

训练方法

如下图所示,InstructGPT 整体训练分为三个部分: 截屏2023-02-21 下午8.09.35.png

  • 有监督微调 SFT

    • 在这一步中,首先从 gpt3.5 的训练样本中随机sample一些 prompt;
    • 然后,标注人员书写对应的 prompt,期望输出;
    • 最后,用上面的监督数据 finetuning gpt3;
  • RM 训练 这一步被叫做奖励模型,具体过程:

    • 用不同的模型,输入同一个 prompt,得到不同的输出;
    • 标注者将上述模型结果从 好 到 不好 进行排序;
    • 这些数据将用于训练得到一个 reward model;
  • 利用强化学习,以RM为reward,优化迭代 SFT

    • 随机sample 一些 prompt
    • 用gpt3.5 生成对应的 output内容
    • 用 RM 模型计算 reward,然后再根据这个 reward 更新 gpt3.5 模型。

训练数据

上面已经大概介绍了整个训练思路,但是对应的每一步的训练数据如何得到,可以详细了解一下。

  • SFT 数据集
    • 这一步的数据,除了上面提到的 open ai 团队找了一个 40人的标注团队标注之外,还有一部分来自于 playground 游戏用户。其中人工标注有一些标准: a) 给一些简单但是需要有多样性的任务; b) prompt 和对应的答案对; c) 从线上的接口获取 prompt,然后根据这些 prompt 写答案;
  • RM
    • 人工对模型的输出排序,没有什么特殊的,不过这里值得一提的是,标准会打压模型输出的一些涉及偏见、有负面引导性、人类不喜欢的内容;
  • PPO
    • 从数据集中sample的,没有标注,但是可以关注一下数据分布: 生成任务占 45.6%, QA任务占 12.4%, 创意型 11.2%, 对话任务 8.4%

开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 27 天,点击查看活动详情