Fine-Tuning a Large Model for Review Classification


1. Goal

Fine-tune the GPT-2 model so that it can classify movie reviews as positive or negative more accurately.

2. Local Environment

(Screenshot: local environment details)

3. Inference Service with GPT-2 Before Fine-Tuning

Code

Change model_path to point at your local copy of the model.
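
If you do not have a local copy of gpt2 yet, one way to fetch it is via huggingface_hub (a minimal sketch, not part of the original post; the target directory is just an example):

from huggingface_hub import snapshot_download

# Download the gpt2 checkpoint into a local directory (example path)
snapshot_download(repo_id="gpt2", local_dir="./model/gpt2-openai")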

from transformers import GPT2ForSequenceClassification, GPT2Tokenizer
import torch

# Set the local model path
model_path = "/Users/GGGxie/Documents/bigmodel/model/gpt2-openai"

# Load the model and tokenizer
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2ForSequenceClassification.from_pretrained(model_path)

# Use the end-of-sequence token as the padding token
tokenizer.pad_token = tokenizer.eos_token
# Keep the model config in sync: GPT2ForSequenceClassification uses
# pad_token_id to locate the last non-padding token in each sequence
model.config.pad_token_id = model.config.eos_token_id

# Put the model in evaluation mode
model.eval()

content = {"0": "negative", "1": "positive"}
def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return content[str(predicted_class_id)]

if __name__ == "__main__":
    while True:
# Read input from the terminal
        text = input("Enter text to predict (or type 'exit' to quit): ")
        if text.lower() == 'exit':
            break
        prediction = predict(text)
        print(f"Text: {text}")
        print(f"Prediction: {prediction}")

Run result

The predictions are completely unreliable! This is expected: the base gpt2 checkpoint ships without a sequence-classification head, so GPT2ForSequenceClassification attaches a randomly initialized one, and the outputs are essentially random guesses.

(Screenshot: run output)
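
This is also visible at load time: from_pretrained warns that score.weight of GPT2ForSequenceClassification was newly initialized. A quick way to confirm it (a sketch reusing the model loaded above):

# The classification head is a freshly initialized linear layer on top of the
# transformer; its weights stay random until fine-tuning
print(model.score)  # e.g. Linear(in_features=768, out_features=2, bias=False)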

4. Fine-Tuning

Dataset: cornell-movie-review-data/rotten_tomatoes

Code

  1. Change model_path to point at your local copy of the model.
  2. The fine-tuned model weights, config files, and tokenizer artifacts are saved to ./results.
# Load the dataset
from datasets import load_dataset

# Load each split of the Rotten Tomatoes dataset
# Remote loading from Hugging Face may require a proxy; the dataset can also be downloaded locally
train_dataset = load_dataset("rotten_tomatoes", split="train")
validation_dataset = load_dataset("rotten_tomatoes", split="validation")
test_dataset = load_dataset("rotten_tomatoes", split="test")
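
# Optional check (not in the original script): each example has a "text"
# field and a "label" field, where 0 = negative and 1 = positive
print(train_dataset)     # split size and columns
print(train_dataset[0])  # a sample review with its label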


# Data preprocessing
from transformers import AutoTokenizer

model_path = "/Users/GGGxie/Documents/bigmodel/model/gpt2-openai"
tokenizer = AutoTokenizer.from_pretrained(model_path)

# gpt2 has no pad token; reuse the eos token rather than adding a new [PAD]
# token (a brand-new token would enlarge the vocabulary and require resizing
# the model's embedding matrix with model.resize_token_embeddings)
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

# Tokenize and preprocess the train, validation, and test splits
tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True)
tokenized_validation_dataset = validation_dataset.map(tokenize_function, batched=True)
tokenized_test_dataset = test_dataset.map(tokenize_function, batched=True)
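
# Optional check (not in the original script): map() adds "input_ids" and
# "attention_mask" columns alongside the existing "text" and "label"
print(tokenized_train_dataset.column_names)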


# Load the model and its configuration
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels=2)
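# The classification head attached by num_labels=2 is freshly initialized;
# point pad_token_id at the eos token so GPT-2 can find the last real token
# in padded sequences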
model.config.pad_token_id = model.config.eos_token_id


# Set up the data collator
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
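
# Note: tokenize_function above already pads every example to max_length, so
# this collator has little left to do; dropping padding="max_length" there and
# letting the collator pad each batch dynamically would be more efficient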

# Set the training hyperparameters (in older transformers releases,
# eval_strategy was spelled evaluation_strategy)
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
)

# Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_validation_dataset,
    data_collator=data_collator,
)

# Run training
trainer.train()

# Save the model and tokenizer
trainer.save_model("./results")
tokenizer.save_pretrained("./results")
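
The Trainer above logs only the loss during evaluation. To also track accuracy, a compute_metrics hook can be passed to the Trainer, and the held-out test split can be scored after training. A minimal sketch (not part of the original script; assumes numpy is installed):

import numpy as np

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair produced by the Trainer
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": (predictions == labels).mean()}

# Pass compute_metrics=compute_metrics when constructing the Trainer, then:
# metrics = trainer.evaluate(tokenized_test_dataset)
# print(metrics)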

Fine-tuning result

(Screenshot: training logs)

5. Inference Service with GPT-2 After Fine-Tuning

Code

from transformers import AutoTokenizer, GPT2ForSequenceClassification
import torch

# Load the fine-tuned model and tokenizer
model_path = "./results"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = GPT2ForSequenceClassification.from_pretrained(model_path)

# Make sure the model is in evaluation mode
model.eval()

content = {"0": "negative", "1": "positive"}
# Prediction function
def predict(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return content[str(predicted_class_id)]

if __name__ == "__main__":
    while True:
# Read input from the terminal
        text = input("Enter text to predict (or type 'exit' to quit): ")
        if text.lower() == 'exit':
            break
        prediction = predict(text)
        print(f"Text: {text}")
        print(f"Prediction: {prediction}")

Run result

(Screenshot: run output)