Introduction
Sentiment classification of simple text with a CNN.
Requirements
numpy<2
torch and torchtext versions must be mutually compatible (an example install command is included in the environment setup below)
Code
Create the environment
# Create a new virtual environment
python -m venv myenv
# Activate the virtual environment (Windows)
myenv\Scripts\activate
# Activate the virtual environment (Linux/Mac)
source myenv/bin/activate
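# Install dependencies (illustrative, known-compatible pins, not a project requirement:
# torchtext 0.18.x is the release paired with torch 2.3.x, and numpy must stay below 2)
pip install "numpy<2" torch==2.3.0 torchtext==0.18.0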
main.py
import os
import argparse
import datetime
import torch
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import model
import train
import mydatasets
# Disable the torchtext deprecation warning
torchtext.disable_torchtext_deprecation_warning()
# Argument parsing
parser = argparse.ArgumentParser(description='CNN text classifier')
# Training parameters
parser.add_argument('-lr', type=float, default=0.001, help='learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=256, help='number of training epochs [default: 256]')
parser.add_argument('-batch-size', type=int, default=64, help='batch size [default: 64]')
parser.add_argument('-log-interval', type=int, default=1, help='logging interval in steps [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='evaluation interval in steps [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='checkpoint interval in steps [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='directory for saving checkpoints')
parser.add_argument('-early-stop', type=int, default=1000, help='early-stopping patience in steps')
parser.add_argument('-save-best', type=bool, default=True, help='save the best model on the dev set')
# Data parameters
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the training data')
# Model parameters
parser.add_argument('-dropout', type=float, default=0.5, help='dropout probability [default: 0.5]')
parser.add_argument('-embed-dim', type=int, default=128, help='embedding dimension [default: 128]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of kernels per size')
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes')
parser.add_argument('-static', action='store_true', default=False, help='keep the embeddings fixed')
# Device parameters
parser.add_argument('-device', type=int, default=-1, help='device index (-1 = CPU)')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
# Mode options
parser.add_argument('-snapshot', type=str, default=None, help='path to a model snapshot')
parser.add_argument('-predict', type=str, default=None, help='text to classify')
parser.add_argument('-test', action='store_true', default=False, help='run evaluation only')
def collate_batch(batch):
"""处理批次数据:填充文本并转换为张量"""
texts, labels = zip(*batch)
    # The dataset already returns tensors, so don't re-wrap them
padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
labels = torch.stack(labels)
return padded_texts, labels
if __name__ == '__main__':
args = parser.parse_args()
    # Build the vocabulary
print("\nBuilding vocabulary...")
tokenizer = get_tokenizer(lambda x: x.split())
def yield_tokens(data_iter):
for example in data_iter.examples:
text, _ = example
yield text
    # Identity transforms used only while building the vocabulary
temp_text_transform = lambda x: x
temp_label_transform = lambda x: x
    # Load the raw dataset to build the vocabulary
train_data, _ = mydatasets.MR.splits(temp_text_transform, temp_label_transform, root='.')
vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"])
    # Define the real transforms
text_transform = lambda x: vocab(x)
label_transform = lambda x: 1 if x == 'positive' else 0
    # Load the datasets with the transforms applied
print("\nLoading datasets...")
train_data, dev_data = mydatasets.MR.splits(text_transform, label_transform, root='.', dev_ratio=0.1)
    # Create the DataLoaders
train_iter = torch.utils.data.DataLoader(
train_data,
batch_size=args.batch_size,
collate_fn=collate_batch,
shuffle=args.shuffle
)
dev_iter = torch.utils.data.DataLoader(
dev_data,
batch_size=args.batch_size,
collate_fn=collate_batch
)
    # Update the model parameters derived from the data
args.embed_num = len(vocab)
args.class_num = 2
args.cuda = (not args.no_cuda) and torch.cuda.is_available()
args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
print("\nModel Parameters:")
for attr, value in sorted(args.__dict__.items()):
print(f"\t{attr.upper():<15} = {value}")
    # Initialize the model
cnn = model.CNN_Text(args)
if args.snapshot:
print(f"\nLoading model from {args.snapshot}...")
cnn.load_state_dict(torch.load(args.snapshot))
if args.cuda:
torch.cuda.set_device(args.device)
cnn = cnn.cuda()
    # Run mode
if args.predict:
label = train.predict(args.predict, cnn, text_transform, label_transform, args.cuda)
print(f'\n[Text]\t{args.predict}\n[Label]\t{label}\n')
elif args.test:
try:
train.eval(dev_iter, cnn, args)
except Exception as e:
print(f"\nTest error: {str(e)}")
else:
try:
train.train(train_iter, dev_iter, cnn, args)
except KeyboardInterrupt:
print('\n' + '-'*50)
print('Training stopped by user')
train.py
import os
import sys
import torch
import torch.autograd as autograd
import torch.nn.functional as F
def train(train_iter, dev_iter, model, args):
if args.cuda:
model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
steps = 0
best_acc = 0
last_step = 0
for epoch in range(1, args.epochs + 1):
for batch in train_iter:
model.train()
feature, target = batch
if args.cuda:
feature, target = feature.cuda(), target.cuda()
optimizer.zero_grad()
logit = model(feature)
loss = F.cross_entropy(logit, target)
loss.backward()
optimizer.step()
steps += 1
if steps % args.log_interval == 0:
corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
accuracy = 100.0 * corrects / feature.size(0)
sys.stdout.write(
'\rBatch[{}] - loss: {:.6f} acc: {:.4f}%({}/{})'.format(steps,
loss.item(),
accuracy.item(),
corrects.item(),
feature.size(0)))
if steps % args.test_interval == 0:
dev_acc = eval(dev_iter, model, args)
if dev_acc > best_acc:
best_acc = dev_acc
last_step = steps
if args.save_best:
save(model, args.save_dir, 'best', steps)
else:
if steps - last_step >= args.early_stop:
                        print('early stop by {} steps.'.format(args.early_stop))
                        # Actually stop training once dev accuracy has not improved for early_stop steps
                        return
elif steps % args.save_interval == 0:
save(model, args.save_dir, 'snapshot', steps)
def eval(data_iter, model, args):
model.eval()
corrects, avg_loss = 0, 0
with torch.no_grad():
for batch in data_iter:
feature, target = batch
if args.cuda:
feature, target = feature.cuda(), target.cuda()
logit = model(feature)
loss = F.cross_entropy(logit, target, reduction='sum')
avg_loss += loss.item()
corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
size = len(data_iter.dataset)
avg_loss /= size
accuracy = 100.0 * corrects / size
print('\nEvaluation - loss: {:.6f} acc: {:.4f}%({}/{}) \n'.format(avg_loss,
accuracy,
corrects,
size))
return accuracy
def predict(text, model, text_transform, label_transform, cuda_flag):
assert isinstance(text, str)
model.eval()
    # text_transform expects a list of tokens, so split the raw string first
    x = torch.tensor([text_transform(text.split())], dtype=torch.long)
    if cuda_flag:
        x = x.cuda()
output = model(x)
_, predicted = torch.max(output, 1)
label_mapping = {0: 'negative', 1: 'positive'}
return label_mapping[predicted.item()]
def save(model, save_dir, save_prefix, steps):
if not os.path.isdir(save_dir):
os.makedirs(save_dir)
save_prefix = os.path.join(save_dir, save_prefix)
save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
torch.save(model.state_dict(), save_path)
mydatasets.py
# mydatasets.py
import re
import os
import random
import tarfile
import urllib.request
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer
class TarDataset(Dataset):
@classmethod
def download_or_unzip(cls, root):
path = os.path.join(root, cls.dirname)
if not os.path.isdir(path):
tpath = os.path.join(root, cls.filename)
if not os.path.isfile(tpath):
print('Downloading dataset...')
urllib.request.urlretrieve(cls.url, tpath)
with tarfile.open(tpath, 'r') as tfile:
print('Extracting...')
def is_within_directory(directory, target):
abs_directory = os.path.abspath(directory)
abs_target = os.path.abspath(target)
prefix = os.path.commonprefix([abs_directory, abs_target])
return prefix == abs_directory
def safe_extract(tar, path=".", members=None, numeric_owner=False):
for member in tar.getmembers():
member_path = os.path.join(path, member.name)
if not is_within_directory(path, member_path):
raise Exception("Attempted Path Traversal in Tar File")
tar.extractall(path, members, numeric_owner=numeric_owner)
safe_extract(tfile, root)
return os.path.join(path, '')
class MR(TarDataset):
url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
filename = 'rt-polaritydata.tar.gz'
dirname = 'rt-polaritydata'
def __init__(self, text_transform, label_transform, path=None, examples=None):
self.text_transform = text_transform
self.label_transform = label_transform
        # Tokenizer: clean the string, then split on whitespace
self.tokenizer = get_tokenizer(lambda x: self.clean_str(x).split())
if examples is None:
            # Only download/extract when no data directory is given; splits() already passes the extracted path
            if path is None:
                path = self.download_or_unzip(os.getcwd())
self.examples = []
            # Load the negative reviews (store token lists)
with open(os.path.join(path, 'rt-polarity.neg'), 'r', encoding='latin-1') as f:
for line in f:
tokenized_text = self.tokenizer(line.strip())
self.examples.append((tokenized_text, 'negative'))
            # Load the positive reviews
with open(os.path.join(path, 'rt-polarity.pos'), 'r', encoding='latin-1') as f:
for line in f:
tokenized_text = self.tokenizer(line.strip())
self.examples.append((tokenized_text, 'positive'))
else:
self.examples = examples
@staticmethod
def clean_str(string):
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
string = re.sub(r"\'s", " \'s", string)
string = re.sub(r"\'ve", " \'ve", string)
string = re.sub(r"n\'t", " n\'t", string)
string = re.sub(r"\'re", " \'re", string)
string = re.sub(r"\'d", " \'d", string)
string = re.sub(r"\'ll", " \'ll", string)
string = re.sub(r",", " , ", string)
string = re.sub(r"!", " ! ", string)
string = re.sub(r"\(", " ( ", string)
string = re.sub(r"\)", " ) ", string)
string = re.sub(r"\?", " ? ", string)
return re.sub(r"\s{2,}", " ", string).strip()
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
raw_text, raw_label = self.examples[idx]
        # raw_text is already a list of tokens produced by the tokenizer
        processed_text = self.text_transform(raw_text)
processed_label = self.label_transform(raw_label)
return torch.tensor(processed_text, dtype=torch.long), torch.tensor(processed_label, dtype=torch.long)
@classmethod
def splits(cls, text_transform, label_transform, dev_ratio=0.1, shuffle=True, root='.', **kwargs):
path = cls.download_or_unzip(root)
full_dataset = cls(text_transform, label_transform, path=path, **kwargs)
examples = full_dataset.examples
if shuffle:
random.shuffle(examples)
split_idx = int(len(examples) * (1 - dev_ratio))
return (cls(text_transform, label_transform, examples=examples[:split_idx]),
cls(text_transform, label_transform, examples=examples[split_idx:]))
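A minimal, optional sanity check for the loader above. The file name sanity_check.py and the identity transforms are illustrative only (they mirror the vocabulary-building pass in main.py); running it downloads the MR corpus on first use.
sanity_check.py
import mydatasets
# Identity transforms, as in the vocabulary-building pass of main.py
train_data, dev_data = mydatasets.MR.splits(lambda x: x, lambda x: x, root='.', dev_ratio=0.1)
print(len(train_data), len(dev_data))    # roughly a 90/10 split of the 10662 sentences
# Each example is a (token list, 'positive'/'negative') pair
tokens, label = train_data.examples[0]
print(label, tokens[:8])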
model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class CNN_Text(nn.Module):
def __init__(self, args):
super(CNN_Text, self).__init__()
self.args = args
        # Embedding layer
self.embed = nn.Embedding(args.embed_num, args.embed_dim)
        # Convolution layers, one per kernel size
self.convs = nn.ModuleList([
nn.Conv2d(
in_channels=1,
out_channels=args.kernel_num,
                kernel_size=(k, args.embed_dim)  # kernel size: (height, width)
) for k in args.kernel_sizes
])
        # Classification head
self.dropout = nn.Dropout(args.dropout)
self.fc = nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.class_num)
        # Freeze the embeddings
if args.static:
self.embed.weight.requires_grad = False
def forward(self, x):
        # Input shape: (batch_size, seq_len)
        x = self.embed(x)  # (batch_size, seq_len, embed_dim)
        # Add a channel dimension: (batch_size, 1, seq_len, embed_dim)
x = x.unsqueeze(1)
        # Apply each convolution
conv_outputs = []
for conv in self.convs:
            # Convolution: (batch, Co, seq_len-k+1, 1)
conv_out = F.relu(conv(x))
            # Drop the trailing dimension: (batch, Co, seq_len-k+1)
conv_out = conv_out.squeeze(3)
            # Max-over-time pooling: (batch, Co)
pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
conv_outputs.append(pooled)
        # Concatenate the pooled features
x = torch.cat(conv_outputs, 1)
x = self.dropout(x)
logits = self.fc(x)
return logits
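A minimal shape check for CNN_Text. The embed_num value is a placeholder (in main.py it is set to len(vocab)); the other settings match the script defaults, and the file name shape_check.py is just illustrative.
shape_check.py
from argparse import Namespace
import torch
import model
# embed_num is a placeholder here; in main.py it is set to len(vocab)
args = Namespace(embed_num=1000, embed_dim=128, kernel_num=100,
                 kernel_sizes=[3, 4, 5], class_num=2, dropout=0.5, static=False)
cnn = model.CNN_Text(args)
dummy = torch.randint(0, args.embed_num, (4, 20))  # (batch_size=4, seq_len=20)
print(cnn(dummy).shape)                            # torch.Size([4, 2])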
Training
Run the command python main.py -epochs 30 (training for the default 256 epochs takes too long, so the epoch count is reduced to 30 to shorten training time).
Training process
After training you will get a directory structure similar to the one sketched below. The -snapshot option tells the model which checkpoint to load from; if it is not specified, the model trains from scratch.
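An illustrative layout (the timestamp and step numbers are only examples; best_steps_3800.pt matches the checkpoint used in the test command below):
snapshot/
└── 2025-05-01_14-36-47/
    ├── snapshot_steps_500.pt
    ├── snapshot_steps_1000.pt
    └── best_steps_3800.pt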
Testing
The test command is shown below; it asks the model to judge the sentiment of the sentence "Hello my dear , I love you so much .".
python main.py -predict="Hello my dear , I love you so much ." \
    -snapshot="./snapshot/2025-05-01_14-36-47/best_steps_3800.pt"
Result
The predicted label is positive.