本文主要介绍数据的预处理。
1、找大模型的数据
前面写了一篇文章《ChatGPT|大语言模型训练有哪些开源数据集? 》(mp.weixin.qq.com/s?__biz=MzA…
不过在开发大模型,需要根据实际的需求可以找到不同的数据,比如如果需要英文预料,那么就需要找到英文的预料,目前我们的 myllm
项目主要是中文小模型,所以找了一些中文相关数据:
- Wiki中文百科:huggingface.co/datasets/pl…
- 天工数据集:huggingface.co/datasets/Sk…
如果需要其他数据可以在 huggingface
查找,地址:huggingface.co/datasets,我也…
2、数据预处理
下载数据以后,按照如下流程处理:
- 提取文件的文本数据
- 将文本数据进行截断,比如某段文本超过限制的上下文大小(如:512),就需要截断,增加截断标识
- 将文本转换为token,格式化存储token数据
处理以下格式的数据:
[
{
"completion": "昭通机场(ZPZT)是位于中国云南昭通的民用机场,始建于1935年,1960年3月开通往返航班“昆明-昭通”,原来属军民合用机场。1986年机场停止使用。1991年11月扩建,于1994年2月恢复通航。是西南地区「文明机场」,通航城市昆明。 机场占地1957亩,飞行区等级为4C,有一条跑道,长2720米,宽48米,可供波音737及以下机型起降。机坪面积6600平方米,停机位2个,航站楼面积1900平方米。位于城东6公里处,民航路与金鹰大道交叉处。\n航点\n客服电话\n昭通机场客服电话:0870-2830004",
"source": "wikipedia.zh2307"
}
]
处理代码如下:
tokenizer = AutoTokenizer.from_pretrained('./my_tokenizer', use_fast=False)
basepath = "../datasets"
# 截断数据
def split_text(text, n = 512):
return [text[i: i + n] for i in range(0, len(text), n)]
# 整理wikipedia-cn-20230720-filtered数据,下载地址:https://huggingface.co/datasets/pleisto/wikipedia-cn-20230720-filtered
def process_wiki_clean():
with open(f'{basepath}/wikipedia-cn-20230720-filtered.json', 'r', encoding='utf-8') as file:
data = ujson.loads(file.read())
data_len = len(data)
doc_ids = []
for idx, line in enumerate(data):
text_input = line['completion']
text_arr = split_text(text_input)
for text in text_arr:
text_id = tokenizer(f'{bos_token}{text}{eos_token}').data['input_ids']
print("text_id: ", text_id, ", text: ", text)
if len(text_id) > 5:
doc_ids += text_id
if idx % (int(data_len / 20)) == 0:
print(f"[{idx}/{data_len}] {text}")
arr = np.array(doc_ids, dtype=np.uint16)
with open(f'{basepath}/wikipedia-cn-20230720-filtered.bin', 'wb') as f:
f.write(arr.tobytes())
其中 text_id
输出就是从前面训练的Tokenizer中输出的对应的词ID,然后将 doc_ids
通过 numpy
序列化为 wikipedia-cn-20230720-filtered.bin
文件。
3、合并多个数据
可以将多个数据,代码如下:
# 将多个数据合并为一个文件
def pretrain_process():
process_wiki_clean()
data_path_list = [
f'{basepath}/wikipedia-cn-20230720-filtered.bin',
]
data_list = []
for data_path in data_path_list:
with open(data_path, 'rb') as f:
data = np.fromfile(f, dtype=np.uint16)
data_list.append(data)
arr = np.concatenate(data_list)
print(arr.shape)
with open(f'{basepath}/pretrain_data.bin', 'wb') as f:
f.write(arr.tobytes())
最后训练数据是 pretrain_data.bin
,数据大小 361M
。
参考
(1)Wiki中文百科:huggingface.co/datasets/pl…
(2)天工数据集:huggingface.co/datasets/Sk…
(3)github.com/jiahe7ay/MI…