1. Import the required packages
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
2. Load, split, and preprocess the dataset
ds = Dataset.load_from_disk("./nlpcc_2017/")
ds = ds.train_test_split(100, seed=42)  # hold out 100 examples as the test split
tokenizer = AutoTokenizer.from_pretrained("../glm-large-chinese", trust_remote_code=True)
def process_func(examples):
    # prepend the task prompt and append GLM's [MASK] token, which marks the span to be generated
    contents = ["摘要生成: \n" + e + tokenizer.mask_token for e in examples["content"]]
    inputs = tokenizer(contents, max_length=384, truncation=True, padding="max_length", return_tensors="pt")
    # build_inputs_for_generation appends the target titles and builds the position ids and labels GLM needs
    inputs = tokenizer.build_inputs_for_generation(inputs, targets=examples["title"], padding=True, max_gen_length=64)
    return inputs
tokenized_ds = ds.map(process_func, batched=True, remove_columns=ds["train"].column_names)
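As a quick sanity check (this snippet is not part of the original), one processed example can be decoded back to text; the prompt, the [MASK] placeholder, and the appended target title should all be visible, along with the padding:

print(tokenized_ds["train"][0].keys())
print(tokenizer.decode(tokenized_ds["train"][0]["input_ids"]))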
3. Create the model
model = AutoModelForSeq2SeqLM.from_pretrained("../glm-large-chinese", trust_remote_code=True)
4. Create the evaluation function
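The original leaves this section empty, so below is a minimal sketch of a standalone ROUGE helper. It assumes the third-party rouge_chinese and jieba packages (both are assumptions, not part of the original code). Since GLM builds its generation inputs through build_inputs_for_generation, it is simpler to score decoded strings after the fact than to wire a compute_metrics function into the Trainer; the helper is called after batch prediction in step 7.

import jieba
from rouge_chinese import Rouge  # assumed dependency: pip install rouge-chinese

rouge = Rouge()

def compute_rouge(predictions, references):
    # rouge_chinese expects whitespace-separated tokens, so segment the Chinese text with jieba first
    pred_tokens = [" ".join(jieba.cut(p)) for p in predictions]
    ref_tokens = [" ".join(jieba.cut(r)) for r in references]
    scores = rouge.get_scores(pred_tokens, ref_tokens, avg=True)
    return {
        "rouge-1": scores["rouge-1"]["f"],
        "rouge-2": scores["rouge-2"]["f"],
        "rouge-l": scores["rouge-l"]["f"],
    }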
5. Create the TrainingArguments and Trainer
args = Seq2SeqTrainingArguments(
    output_dir="./summary_glm",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=8,  # effective train batch size: 4 * 8 = 32 per device
    logging_steps=8,
    num_train_epochs=1,
)
trainer = Seq2SeqTrainer(
    args=args,
    model=model,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],  # required for trainer.evaluate() below
    tokenizer=tokenizer,
)
6. Train and evaluate the model
trainer.train()
trainer.evaluate()
7. Model prediction
input_text = ds["test"][-1]["content"]
inputs = tokenizer("摘要生成: \n" + input_text + tokenizer.mask_token, return_tensors="pt")
inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=64)
inputs = inputs.to("cuda")
output = model.generate(**inputs, max_new_tokens=64, eos_token_id=tokenizer.eop_token_id, do_sample=True)
tokenizer.decode(output[0].tolist())
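The decoded string contains the prompt followed by the generated span; the summary itself sits between GLM's <|startofpiece|> and <|endofpiece|> markers, and can be cut out the same way the batch prediction below does:

decoded = tokenizer.decode(output[0].tolist())
print(decoded.split("<|startofpiece|>")[1].replace("<|endofpiece|>", "").strip())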
model = model.eval()

def predict_test():
    predict = []
    with torch.inference_mode():
        for d in ds["test"]:
            inputs = tokenizer("摘要生成: \n" + d["content"] + tokenizer.mask_token, return_tensors="pt")
            inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=64)
            inputs = inputs.to("cuda")
            output = model.generate(**inputs, max_new_tokens=64, eos_token_id=tokenizer.eop_token_id, do_sample=True)
            # keep only the generated span between <|startofpiece|> and <|endofpiece|>
            predict.append(tokenizer.decode(output[0].tolist()).split("<|startofpiece|>")[1].replace("<|endofpiece|>", "").strip())
            print("curID:", len(predict))
    return predict
result = predict_test()
result
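Finally, the predictions can be scored against the reference titles with the helper sketched in step 4 (again assuming rouge_chinese and jieba are installed):

compute_rouge(result, ds["test"]["title"])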