训练(微调)ChatGPT 初步探索

168 阅读3分钟

前言

...

结果

本次微调的模型为 gpt-3.5-turbo-0613,训练数据由 ChatGPT 生成,共 10 条。训练前后效果明显,但是依然存在乱说的情况,个人推测和训练数据较为随意以及没有正确的编写 Prompt 有很大的关系。

训练数据

{ "messages": [ { "role": "user", "content": "新智村的总人口是多少?" }, { "role": "assistant", "content": "新智村的总人口为5000人。" } ] }
{ "messages": [ { "role": "user", "content": "农业产值是多少?" }, { "role": "assistant", "content": "新智村的农业产值为100万元。" } ] }
{ "messages": [ { "role": "user", "content": "新智村的儿童和孕妇死亡率是多少?" }, { "role": "assistant", "content": "新智村的儿童和孕妇死亡率为每千名儿童和孕妇中有3人。" } ] }
...

训练前

训练前-问题一

训练前-问题一

训练前-问题二

训练前-问题二

训练前-问题三

训练前-问题三

训练后

训练后-问题一

训练后-问题一

训练后-问题一-变

训练后-问题一-变

训练后-问题二

训练后-问题二

训练后-问题二-变

训练后-问题二-变

训练后-问题三

训练后-问题三

训练后-问题三-变

训练后-问题三-变

实现

本次用到的接口包含:

  • 文件上传 POST https://api.openai.com/v1/files
  • 获取所有上传的文件 GET https://api.openai.com/v1/files
  • 创建训练任务 POST https://api.openai.com/v1/fine_tuning/jobs
  • 获取所有训练的任务 GET https://api.openai.com/v1/fine_tuning/jobs
  • 查看训练结果(聊天)POST https://api.openai.com/v1/chat/completions

流程

  1. 准备训练数据,训练数据的格式为 jsonl
  2. 上传训练数据
  3. 获取训练数据对应文件的 ID
  4. 创新训练任务,这里要注意,任务创建成功后并不是立刻完成的,需要排队等候,本次训练大概等了 10 分钟左右
  5. 获取训练任务的 ID,这里需要注意,调用的结果包含训练状态 status
  6. 调用训练后的模型进行测试,这里注意,当训练任务完成后会返回 fine_tuned_model,这是聊天时使用的模型

部分代码

因为时间和环境的原因,本次代码较为散乱,其中上传文件部分的代码运行在国外的服务器中,剩余代码运行在本地代理环境中。所有代码只限参考,不能直接运行。

上传文件:

import fs from 'fs';
import dotenv from 'dotenv';
import OpenAI, { toFile } from 'openai';

dotenv.config();

const openai = new OpenAI({ apiKey: process.env.GENDATA_API_KEY });

const res = await openai.files.create({ file: fs.createReadStream("training_data.jsonl"), purpose: 'fine-tune' });

console.log(`🚀 > file: main.js:15 > res:`, res);

获取文件:

async function getFiles() {
  const res = await axios.get(`/openai/v1/files`);
  console.log(res);
  // {
  //   object: 'list',
  //   data: [
  //     {
  //       object: 'file',
  //       id: 'file-80MBgup8XrhEG0BkH', // 注意:这里后边会用
  //       purpose: 'fine-tune',
  //       filename: 'training_data.jsonl',
  //       bytes: 2585,
  //       created_at: 1694598431,
  //       status: 'processed',
  //       status_details: null
  //     }
  //   ]
  // };
}

创建训练任务:

async function createFineTuning() {
  const res = await axios.post('/openai/v1/fine_tuning/jobs', {
    training_file: 'file-80MBgup8XrhEG0BkH', // 注意:这里是上边获取到的文件id
    model: 'gpt-3.5-turbo-0613'
  });
  console.log(`🚀 > file: main.js:22 > fineTuning > res:`, res);
}

获取训练任务:

async function getFinTuningJobs() {
  const res = await axios.get('/openai/v1/fine_tuning/jobs');
  console.log(`🚀 > file: main.js:13 > getFinTuningJobs > res:`, res);
  // {
  //   object: 'list',
  //   data: [
  //     {
  //       object: 'fine_tuning.job',
  //       id: 'xxxx',
  //       model: 'gpt-3.5-turbo-0613',
  //       created_at: 1694599349,
  //       finished_at: 1694599940,
  //       fine_tuned_model: 'ft:gpt-3.5-turbo-0613:xxxx::7y3UEm', // 注意:这里后边会用
  //       organization_id: 'xxxx',
  //       result_files: ['xxxx'],
  //       status: 'succeeded',
  //       validation_file: null,
  //       training_file: 'file-80MBgup8XrhEG0BkH',
  //       hyperparameters: {
  //         n_epochs: 10
  //       },
  //       trained_tokens: 6760,
  //       error: null
  //     }
  //   ],
  //   has_more: false
  // };
}

聊天测试:

async function startChat() {
  const askDom = document.querySelector('.ask');
  const answerDom = document.querySelector('.answer');

  const url = `${OPENAI_CONFIG.base}/v1/chat/completions`;

  let resText = '';
  await fetchEventSource(url, {
    method: 'POST',

    headers: {
      'Content-Type': 'application/json'
    },

    body: JSON.stringify({
      messages: [{ role: 'user', content: askDom.innerText }],
      model: 'ft:gpt-3.5-turbo-0613:xxxx::7y3UEm', // 注意:这是用的是训练任务给的模型
      stream: true
    }),

    onmessage(ev) {
      if (ev.data === '[DONE]') {
        console.log(`完整回答为:\n${resText}`);
        return;
      }

      const data = JSON.parse(ev.data);
      const { choices } = data;

      const arr = choices.filter((item) => item.finish_reason !== 'stop');

      arr.forEach((item) => {
        resText += item.delta.content;
      });

      answerDom.innerHTML = md.render(resText);
    }
  });
}