# 1. Import required packages
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, TrainingArguments, Trainer, BloomForCausalLM
# 2. Load, split, and preprocess the dataset
# Model path: domestic (China mirror) link
# Load the pre-filtered Chinese Wikipedia dataset from a local directory.
ds = Dataset.load_from_disk("./wiki_cn_filtered/")
# Split — NOTE(review): no train/eval split is actually performed here,
# so the Trainer below has no eval_dataset; confirm whether evaluation
# on a held-out set was intended.
# Preprocess the data: load the tokenizer matching the model checkpoint.
tokenizer = AutoTokenizer.from_pretrained("./bloom-389m-zh")
def process_func(examples, tok=None):
    """Append the EOS token to every text in a batch and tokenize it.

    Args:
        examples: batch dict produced by ``Dataset.map(batched=True)``;
            must contain a ``"completion"`` column of strings.
        tok: optional tokenizer override. Defaults to the module-level
            ``tokenizer``, so existing callers are unaffected.

    Returns:
        The tokenizer's encoding of the EOS-terminated texts, truncated
        to at most 384 tokens.
    """
    if tok is None:
        tok = tokenizer  # fall back to the globally loaded tokenizer
    # EOS marks document boundaries so the causal LM learns where a text ends.
    contents = [text + tok.eos_token for text in examples["completion"]]
    return tok(contents, max_length=384, truncation=True)
# Tokenize the whole dataset in batches; drop the raw columns so only
# the tokenizer's outputs remain as model inputs.
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds.column_names)
# 3. Create the model
# Load the causal-LM checkpoint (BLOOM-389M zh) from the local directory.
model = AutoModelForCausalLM.from_pretrained("./bloom-389m-zh")
# 4. Create the evaluation function (none is defined below)
# 5. Create TrainingArguments and the Trainer
args = TrainingArguments(
    output_dir="./causal_lm",        # where checkpoints and logs are written
    per_device_train_batch_size=2,   # small per-device batch; with accumulation below,
    gradient_accumulation_steps=16,  # effective batch size = 2 * 16 = 32 per device
    logging_steps=10,                # log training metrics every 10 optimizer steps
    num_train_epochs=1,
    fp16=True                        # NOTE(review): mixed precision — assumes a CUDA GPU; confirm hardware
)
trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_ds,      # no eval_dataset is supplied here
    # mlm=False selects causal-LM collation: no masking, labels copied
    # from input_ids (the model shifts them internally).
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
# 6. Train, evaluate, and predict
# Train the model
trainer.train()
# Evaluate the model. No eval split was ever created and the Trainer was
# built without eval_dataset, so a bare trainer.evaluate() raises a
# ValueError in transformers — pass a dataset explicitly. Evaluating on
# the training set only yields a training-loss sanity check, not a
# generalization metric.
trainer.evaluate(eval_dataset=tokenized_ds)
# 7. Model prediction
# Model prediction — quick generation smoke test.
import torch
from transformers import pipeline

# The original hard-coded device=0, which crashes on CPU-only machines;
# fall back to CPU (-1) when no CUDA device is available.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
pipe("西安交通大学博物馆(Xi'an Jiaotong University Museum)是一座位于西安", max_length=128, do_sample=True)
pipe("下面是一则游戏新闻。小编报道,近日,游戏产业发展的非常", max_length=128, do_sample=True)