Text Classification with transformers + datasets + evaluate


1. Import the required packages

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments, pipeline
from datasets import load_dataset
import evaluate

2. Load, split, and preprocess the dataset

# Load the raw CSV (the whole file is exposed as a single "train" split)
dataset = load_dataset('csv', data_files='./ChnSentiCorp_htl_all.csv', split='train')
dataset = dataset.filter(lambda x: x['review'] is not None)  # drop rows with an empty review
# Split into train / test
datasets = dataset.train_test_split(test_size=0.1)
# Tokenize; padding is left to the DataCollatorWithPadding created in step 5
tokenizer = AutoTokenizer.from_pretrained("./rbt3")
def process_function(examples):
    tokenized_examples = tokenizer(examples['review'], truncation=True, max_length=128)
    tokenized_examples['labels'] = examples['label']
    return tokenized_examples

tokenized_dataset = datasets.map(process_function, batched=True, remove_columns=datasets["train"].column_names)
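
As a quick sanity check on the preprocessing, here is a minimal sketch (assuming the column names used above) that inspects one processed example:

# Hypothetical sanity check of the processed dataset
print(tokenized_dataset)                       # DatasetDict with "train" and "test" splits
print(tokenized_dataset["train"][0].keys())    # typically input_ids, attention_mask, labels (plus token_type_ids for BERT-style tokenizers)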

3. Create the model

model = AutoModelForSequenceClassification.from_pretrained("./rbt3")
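
If you prefer to bake the label names into the model from the start (instead of patching model.config.id2label in step 7), the mappings can also be passed to from_pretrained. An alternative sketch using the same two labels:

# Alternative: attach the label mappings when loading the model
model = AutoModelForSequenceClassification.from_pretrained(
    "./rbt3",
    num_labels=2,
    id2label={0: "差评", 1: "好评"},
    label2id={"差评": 0, "好评": 1},
)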

4. Define the evaluation function

acc_metric = evaluate.load("./evaluate-main/metrics/accuracy")
f1_metric = evaluate.load("./evaluate-main/metrics/f1")

def eval_metric(eval_predict):
    predictions, labels = eval_predict
    predictions = predictions.argmax(axis=-1)   # logits -> predicted class ids
    acc = acc_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    acc.update(f1)                              # merge into one dict: {'accuracy': ..., 'f1': ...}
    return acc
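
Before launching a full training run, the metric wiring can be smoke-tested with made-up logits and labels; a minimal sketch (the toy arrays are purely illustrative):

import numpy as np

# Toy logits and labels, purely illustrative
dummy_logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
dummy_labels = np.array([1, 0, 0])
print(eval_metric((dummy_logits, dummy_labels)))  # should print a dict with 'accuracy' and 'f1'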

5. Create the TrainingArguments and Trainer

train_args = TrainingArguments(output_dir="./checkpoints",
                               per_device_train_batch_size=64,
                               per_device_eval_batch_size=128,
                               logging_steps=10,
                               evaluation_strategy="epoch",
                               save_strategy="epoch",
                               save_total_limit=3,
                               learning_rate=2e-5,
                               weight_decay=0.01)
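
If you also want the best checkpoint (by F1) restored at the end of training, the arguments can be extended as below. This is an optional variant, not part of the original configuration:

# Optional variant: keep and reload the best checkpoint by F1
train_args = TrainingArguments(output_dir="./checkpoints",
                               per_device_train_batch_size=64,
                               per_device_eval_batch_size=128,
                               logging_steps=10,
                               evaluation_strategy="epoch",
                               save_strategy="epoch",
                               save_total_limit=3,
                               learning_rate=2e-5,
                               weight_decay=0.01,
                               load_best_model_at_end=True,     # reload the best checkpoint after training
                               metric_for_best_model="f1")      # "f1" matches the key returned by eval_metric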

collator = DataCollatorWithPadding(tokenizer=tokenizer)
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=collator,
    compute_metrics=eval_metric
)

6. Train and evaluate the model

# Train
trainer.train()
# Evaluate on the held-out test split
trainer.evaluate()

7. Prediction and inference

# Predict on the test split
trainer.predict(tokenized_dataset["test"])
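
trainer.predict returns a PredictionOutput bundling the raw logits, the reference labels, and the metrics; a small sketch of turning it into class ids:

import numpy as np

pred_output = trainer.predict(tokenized_dataset["test"])
pred_ids = np.argmax(pred_output.predictions, axis=-1)   # predicted class id per example
print(pred_output.metrics)                               # test_loss, test_accuracy, test_f1, ...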

# Try the fine-tuned model on a single review
sen = "我觉得这家酒店不错,饭很好吃!"
id2_label = {0: "差评", 1: "好评"}   # 0: negative review, 1: positive review
model.config.id2label = id2_label    # so the pipeline returns readable labels instead of LABEL_0 / LABEL_1
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
pipe(sen)
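
The pipeline also accepts a device and a list of inputs; a short sketch (the second review string below is an invented example):

import torch

device = 0 if torch.cuda.is_available() else -1   # -1 keeps the pipeline on CPU
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device)
print(pipe([sen, "房间有点旧,隔音不太好。"]))       # one {'label', 'score'} dict per input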