机器学习|从0开发大模型之数据预处理

4 阅读3分钟

本文主要介绍数据的预处理。

1、找大模型的数据

前面写了一篇文章《ChatGPT|大语言模型训练有哪些开源数据集? 》(mp.weixin.qq.com/s?__biz=MzA…

不过在开发大模型,需要根据实际的需求可以找到不同的数据,比如如果需要英文预料,那么就需要找到英文的预料,目前我们的 myllm 项目主要是中文小模型,所以找了一些中文相关数据:

如果需要其他数据可以在 huggingface 查找,地址:huggingface.co/datasets,我也…

2、数据预处理

下载数据以后,按照如下流程处理:

  • 提取文件的文本数据
  • 将文本数据进行截断,比如某段文本超过限制的上下文大小(如:512),就需要截断,增加截断标识
  • 将文本转换为token,格式化存储token数据

处理以下格式的数据:

[
    {
        "completion""昭通机场(ZPZT)是位于中国云南昭通的民用机场,始建于1935年,1960年3月开通往返航班“昆明-昭通”,原来属军民合用机场。1986年机场停止使用。1991年11月扩建,于1994年2月恢复通航。是西南地区「文明机场」,通航城市昆明。 机场占地1957亩,飞行区等级为4C,有一条跑道,长2720米,宽48米,可供波音737及以下机型起降。机坪面积6600平方米,停机位2个,航站楼面积1900平方米。位于城东6公里处,民航路与金鹰大道交叉处。\n航点\n客服电话\n昭通机场客服电话:0870-2830004",
        "source""wikipedia.zh2307"
    }
]

处理代码如下:

tokenizer = AutoTokenizer.from_pretrained('./my_tokenizer', use_fast=False)
basepath = "../datasets"

# 截断数据
def split_text(text, n = 512):
    return [text[i: i + n] for i in range(0len(text), n)]

# 整理wikipedia-cn-20230720-filtered数据,下载地址:https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered  
def process_wiki_clean():
    with open(f'{basepath}/wikipedia-cn-20230720-filtered.json''r', encoding='utf-8'as file:
        data = ujson.loads(file.read())
    data_len = len(data)
    doc_ids = []
    for idx, line in enumerate(data):
        text_input = line['completion']
        text_arr = split_text(text_input)
        for text in text_arr:
            text_id = tokenizer(f'{bos_token}{text}{eos_token}').data['input_ids']
            print("text_id: ", text_id, ", text: ", text)
            if len(text_id) > 5:
                doc_ids += text_id
        if idx % (int(data_len / 20)) == 0:
            print(f"[{idx}/{data_len}{text}")
    arr = np.array(doc_ids, dtype=np.uint16)
    with open(f'{basepath}/wikipedia-cn-20230720-filtered.bin''wb'as f:
        f.write(arr.tobytes())

其中 text_id 输出就是从前面训练的Tokenizer中输出的对应的词ID,然后将 doc_ids 通过 numpy 序列化为 wikipedia-cn-20230720-filtered.bin 文件。

3、合并多个数据

可以将多个数据,代码如下:

# 将多个数据合并为一个文件
def pretrain_process():
    process_wiki_clean()

    data_path_list = [
        f'{basepath}/wikipedia-cn-20230720-filtered.bin',
    ]
    data_list = []
    for data_path in data_path_list:
        with open(data_path, 'rb'as f:
            data = np.fromfile(f, dtype=np.uint16)
            data_list.append(data)
    arr = np.concatenate(data_list)
    print(arr.shape)
    with open(f'{basepath}/pretrain_data.bin''wb'as f:
        f.write(arr.tobytes())

最后训练数据是 pretrain_data.bin,数据大小 361M

参考

(1)Wiki中文百科:huggingface.co/datasets/pl…
(2)天工数据集:huggingface.co/datasets/Sk…
(3)github.com/jiahe7ay/MI…