AI 炼丹系列(Python): 深度学习从练气到化神,是否结丹看你的(2)

205 阅读3分钟

问题:

  1. 深度学习训练的常规步骤
  2. 微调模型的几个思考方向 → 把次序调后以便连贯
  3. 构建自己的数据集用于微调大模型
  4. 各种微调方式的实现

编写代码

模型训练的数据保存格式有很多种,datasets库提供了相关的方法从网上下载并加载导入,所以就不需要重复造轮子了。datasets是HuggingFace开发的,所以默认了从其网站下载数据的功能。在这里我们实现本地加载数据功能。

我们使用jsonl格式,以方便编辑和浏览。它和json的区别就是,每行都是一个json对象(也可以理解为python中的dict对象)。

按咱的惯例先定义路径和数据格式:

cur_path = os.getcwd()

class DataPathEnum(str, Enum):
    GPT2XL_TRAIN_DATA_DIR = "gpt2xl"

    def __str__(self):
        return os.path.join(cur_path, "data", self.value)

数据文件路径和文件名为:

2_path.jpg

question, novel, konwledge是数据分类的名称,代表我要训练问答、小说、知识之类的内容。后面带_e的文件代表的是验证数据文件,两种文件里的格式一致。 question.jsonl其中的数据格式如下:(你完全可以自定义)

{"input":"who am i", "context":"I am YS-SRT's Robet", "summary":["I am a Robet, belong to YS-SRT"]}

question_e.jsonl中

{"input":"who is he", "context":"he is YS-SRT's brother", "summary":["I am your brother, mr. YS-SRT"]}

下面直接继承datasets.GeneratorBasedBuilder,来实现本地数据的加载功能:

_DESCRIPTION = "data for gpt2xl train and verify"

_LOCALPATH = str(DataPathEnum.GPT2XL_TRAIN_DATA_DIR)
# 示例,远程下载的数据路径
_URL="https://HuggingFace.com/YS-SRT/gpt2xl/data.zip"

# 数据类别,以jsonl文件名区分
category_list = [
    "question",
    "novel",
    "knowledge"
]

# 配置类自定义方法,用于设定可以调用的数据类别
class GPT2XLDataConfig(datasets.BuilderConfig):
    def __init__(self, **kwargs):
        super().__init__(version=datasets.Version("2.16.1"), **kwargs)


class GPT2XLData(datasets.GeneratorBasedBuilder):
    # 基类配置,用于数据的导入,在这里给它赋值
    BUILDER_CONFIGS = [
        GPT2XLDataConfig(
            name=category_name,
        )
        for category_name in category_list
    ]
    
    # 数据文件格式定义,与上面数据文件中的格式一致
    def _info(self):
        features = datasets.Features(
            {
                "input": datasets.Value("string"), 
                "context": datasets.Value("string"), 
                "summary": [datasets.Value("string")],
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION, 
            features=features,
        )

    # 生成多个可分开的数据集
    def _split_generators(self, dl_manager=None):
        # 网路下载数据包并解压,从本地导入的话不需要它
        # data_dir = dl_manager.download_and_extract(_URL) 
        
        # self.config 在基类中由上的self.BUILDER_CONFIGS而来
        category_name = self.config.name
        # 这里可以导入同类别的多个文件
        return [datasets.SplitGenerator(
                name=datasets.Split.TRAIN, #训练数据
                gen_kwargs={
                    "filepath": os.path.join(
                        #如果是下载来的数据,就使用data_dir构建路径。这里用本地路径
                        _LOCALPATH, f"{category_name}.jsonl"
                    ),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST, #测试数据
                gen_kwargs={
                    "filepath": os.path.join(
                        _LOCALPATH, f"{category_name}_e.jsonl"
                    ),
                },
            )]
        
    # 直接从文件路径导入的生成器,以便查看几个G大小的数据文件
    def _generate_examples(self, filepath):
        with open(filepath, encoding="utf-8") as f:
            for idx, line in enumerate(f):
                key = f"{self.config.name}-{idx}"
                item = json.loads(line)
                yield key, {
                    "input": item["input"],
                    "context": item["context"],
                    "summary": item["summary"],
                }

按惯例,写个ipynb文件测试,datasets相关API可以查看其文档

from datasets import load_dataset
# 注意本地导入的话,传递的第一个位置参数是包含自定义数据Builder类的py文件,第二个参数就是类别
# 可以指定split="train" 来单独加载Split.Train标识的数据集
# 返回 DatasetDict|Dataset
dd = load_dataset("gpt2xl_data_build.py", "question", trust_remote_code=True)
dd.items() # DatasetDict的方法
# 查看训练数据集的第一条
dd["train"][0]
# 查看测试数据集的第一条 
dd["test"][0]

image.png

如果想保存为arrow格式

import os
from utils import DataPathEnum
data_path = os.path.join(str(DataPathEnum.GPT2XL_TRAIN_DATA_DIR), "question_data")

question_dd = load_dataset("gpt2xl_data_build.py", "question", trust_remote_code=True)
#保存为arrow数据文件
question_dd.save_to_disk(data_path)

#加载arrow数据文件, 返回的是Dataset
question_ds = load_from_disk(data_path)
#显示数据结构
question_ds.data # Dataset的方法

# question_ds.data["train"]
# question_ds.data["test"]

2_save_load.jpg

保存的文件格式是:

2_arrow_path.jpg

完整代码地址

my_github.jpg