Implementing Convolutional Neural Networks for Sentence Classification in PyTorch


Introduction

This project uses a CNN to perform sentiment classification on simple text.

Requirements

numpy<2

torch and torchtext must be installed as a compatible version pair


Code

Create the environment

# Create a new virtual environment
python -m venv myenv

# Activate the virtual environment (Windows)
myenv\Scripts\activate

# Activate the virtual environment (Linux/Mac)
source myenv/bin/activate

# Install the dependencies; torch and torchtext must be a compatible pair
# (for example torch 2.3 with torchtext 0.18 -- check the torchtext release notes)
pip install "numpy<2" torch torchtext

main.py

import os
import argparse
import datetime
import torch
import torchtext
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import model
import train
import mydatasets

# Disable the torchtext deprecation warning
torchtext.disable_torchtext_deprecation_warning()

# Argument parsing
parser = argparse.ArgumentParser(description='CNN text classifier')
# Training parameters
parser.add_argument('-lr', type=float, default=0.001, help='learning rate [default: 0.001]')
parser.add_argument('-epochs', type=int, default=256, help='number of training epochs [default: 256]')
parser.add_argument('-batch-size', type=int, default=64, help='batch size [default: 64]')
parser.add_argument('-log-interval', type=int, default=1, help='steps between log messages [default: 1]')
parser.add_argument('-test-interval', type=int, default=100, help='steps between evaluations on the dev set [default: 100]')
parser.add_argument('-save-interval', type=int, default=500, help='steps between model checkpoints [default: 500]')
parser.add_argument('-save-dir', type=str, default='snapshot', help='directory to save checkpoints in')
parser.add_argument('-early-stop', type=int, default=1000, help='stop after this many steps without improvement')
parser.add_argument('-save-best', type=bool, default=True, help='save the best model on the dev set')
# Data parameters
parser.add_argument('-shuffle', action='store_true', default=False, help='shuffle the training data')
# Model parameters
parser.add_argument('-dropout', type=float, default=0.5, help='dropout probability [default: 0.5]')
parser.add_argument('-embed-dim', type=int, default=128, help='word embedding dimension [default: 128]')
parser.add_argument('-kernel-num', type=int, default=100, help='number of kernels per kernel size')
parser.add_argument('-kernel-sizes', type=str, default='3,4,5', help='comma-separated kernel sizes')
parser.add_argument('-static', action='store_true', default=False, help='freeze the word embeddings')
# Device parameters
parser.add_argument('-device', type=int, default=-1, help='device index (-1 = CPU)')
parser.add_argument('-no-cuda', action='store_true', default=False, help='disable the GPU')
# Run options
parser.add_argument('-snapshot', type=str, default=None, help='path of a model snapshot to load')
parser.add_argument('-predict', type=str, default=None, help='text to classify')
parser.add_argument('-test', action='store_true', default=False, help='run in test mode')

def collate_batch(batch):
    """处理批次数据:填充文本并转换为张量"""
    texts, labels = zip(*batch)
    # 直接使用已有张量(避免重复包装)
    padded_texts = pad_sequence(texts, batch_first=True, padding_value=0)
    labels = torch.stack(labels)
    return padded_texts, labels

if __name__ == '__main__':
    args = parser.parse_args()
    
    # Build the vocabulary
    print("\nBuilding vocabulary...")
    tokenizer = get_tokenizer(lambda x: x.split())
    
    def yield_tokens(data_iter):
        for example in data_iter.examples:
            text, _ = example
            yield text

    # Temporary identity transforms, used only while building the vocabulary
    temp_text_transform = lambda x: x
    temp_label_transform = lambda x: x

    # Load the raw dataset to build the vocabulary
    train_data, _ = mydatasets.MR.splits(temp_text_transform, temp_label_transform, root='.')
    vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=["<unk>"])
    vocab.set_default_index(vocab["<unk>"])

    # Define the real transforms
    text_transform = lambda x: vocab(x)
    label_transform = lambda x: 1 if x == 'positive' else 0

    # Load the datasets with the transforms applied
    print("\nLoading datasets...")
    train_data, dev_data = mydatasets.MR.splits(text_transform, label_transform, root='.', dev_ratio=0.1)
    
    # Create the DataLoaders
    train_iter = torch.utils.data.DataLoader(
        train_data, 
        batch_size=args.batch_size, 
        collate_fn=collate_batch,
        shuffle=args.shuffle
    )
    dev_iter = torch.utils.data.DataLoader(
        dev_data,
        batch_size=args.batch_size,
        collate_fn=collate_batch
    )

    # Fill in the model parameters derived from the data and CLI arguments
    args.embed_num = len(vocab)
    args.class_num = 2
    args.cuda = (not args.no_cuda) and torch.cuda.is_available()
    args.kernel_sizes = [int(k) for k in args.kernel_sizes.split(',')]
    args.save_dir = os.path.join(args.save_dir, datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))

    print("\nModel Parameters:")
    for attr, value in sorted(args.__dict__.items()):
        print(f"\t{attr.upper():<15} = {value}")

    # Initialize the model
    cnn = model.CNN_Text(args)
    if args.snapshot:
        print(f"\nLoading model from {args.snapshot}...")
        cnn.load_state_dict(torch.load(args.snapshot))

    if args.cuda:
        torch.cuda.set_device(args.device)
        cnn = cnn.cuda()

    # Run mode: predict, test, or train
    if args.predict:
        label = train.predict(args.predict, cnn, text_transform, label_transform, args.cuda)
        print(f'\n[Text]\t{args.predict}\n[Label]\t{label}\n')
    elif args.test:
        try:
            train.eval(dev_iter, cnn, args)
        except Exception as e:
            print(f"\nTest error: {str(e)}")
    else:
        try:
            train.train(train_iter, dev_iter, cnn, args)
        except KeyboardInterrupt:
            print('\n' + '-'*50)
            print('Training stopped by user')
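
To sanity-check the batching logic, here is a minimal sketch (the token IDs and labels are made up) of what collate_batch produces when sequences of different lengths are padded to a common length:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy (token_ids, label) examples of different lengths
batch = [
    (torch.tensor([5, 2, 9]), torch.tensor(1)),
    (torch.tensor([7, 3]), torch.tensor(0)),
]

texts, labels = zip(*batch)
padded = pad_sequence(texts, batch_first=True, padding_value=0)
print(padded)               # tensor([[5, 2, 9], [7, 3, 0]])
print(torch.stack(labels))  # tensor([1, 0])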

train.py

import os
import sys
import torch
import torch.nn.functional as F


def train(train_iter, dev_iter, model, args):
    if args.cuda:
        model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    steps = 0
    best_acc = 0
    last_step = 0
    for epoch in range(1, args.epochs + 1):
        for batch in train_iter:
            model.train()
            feature, target = batch

            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            optimizer.zero_grad()
            logit = model(feature)
            loss = F.cross_entropy(logit, target)
            loss.backward()
            optimizer.step()

            steps += 1
            if steps % args.log_interval == 0:
                corrects = (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()
                accuracy = 100.0 * corrects / feature.size(0)
                sys.stdout.write(
                    '\rBatch[{}] - loss: {:.6f}  acc: {:.4f}%({}/{})'.format(steps,
                                                                             loss.item(),
                                                                             accuracy.item(),
                                                                             corrects.item(),
                                                                             feature.size(0)))
            if steps % args.test_interval == 0:
                dev_acc = eval(dev_iter, model, args)
                if dev_acc > best_acc:
                    best_acc = dev_acc
                    last_step = steps
                    if args.save_best:
                        save(model, args.save_dir, 'best', steps)
                else:
                    if steps - last_step >= args.early_stop:
                        print('early stop by {} steps.'.format(args.early_stop))
                        return
            elif steps % args.save_interval == 0:
                save(model, args.save_dir, 'snapshot', steps)


def eval(data_iter, model, args):
    model.eval()
    corrects, avg_loss = 0, 0
    with torch.no_grad():
        for batch in data_iter:
            feature, target = batch

            if args.cuda:
                feature, target = feature.cuda(), target.cuda()

            logit = model(feature)
            loss = F.cross_entropy(logit, target, reduction='sum')

            avg_loss += loss.item()
            corrects += (torch.max(logit, 1)[1].view(target.size()).data == target.data).sum()

    size = len(data_iter.dataset)
    avg_loss /= size
    accuracy = 100.0 * corrects / size
    print('\nEvaluation - loss: {:.6f}  acc: {:.4f}%({}/{}) \n'.format(avg_loss,
                                                                       accuracy,
                                                                       corrects,
                                                                       size))
    return accuracy


def predict(text, model, text_transform, label_transform, cuda_flag):
    assert isinstance(text, str)
    model.eval()
    # text_transform (the vocabulary lookup) expects a list of tokens, not a raw string
    token_ids = text_transform(text.split())
    x = torch.tensor([token_ids], dtype=torch.long)
    if cuda_flag:
        x = x.cuda()
    output = model(x)
    _, predicted = torch.max(output, 1)
    label_mapping = {0: 'negative', 1: 'positive'}
    return label_mapping[predicted.item()]


def save(model, save_dir, save_prefix, steps):
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    save_prefix = os.path.join(save_dir, save_prefix)
    save_path = '{}_steps_{}.pt'.format(save_prefix, steps)
    torch.save(model.state_dict(), save_path)
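
As a side note, the per-batch accuracy that train() logs uses the torch.max idiom; the following minimal sketch on toy logits shows how it behaves:

import torch

logit = torch.tensor([[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]])  # toy logits for 3 examples
target = torch.tensor([1, 0, 0])
corrects = (torch.max(logit, 1)[1].view(target.size()) == target).sum()
accuracy = 100.0 * corrects / logit.size(0)
print(corrects.item(), accuracy.item())  # 2 correct, accuracy ~66.67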
    

mydatasets.py

# mydatasets.py
import re
import os
import random
import tarfile
import urllib.request
import torch
from torch.utils.data import Dataset
from torchtext.data.utils import get_tokenizer

class TarDataset(Dataset):
    @classmethod
    def download_or_unzip(cls, root):
        path = os.path.join(root, cls.dirname)
        if not os.path.isdir(path):
            tpath = os.path.join(root, cls.filename)
            if not os.path.isfile(tpath):
                print('Downloading dataset...')
                urllib.request.urlretrieve(cls.url, tpath)
            with tarfile.open(tpath, 'r') as tfile:
                print('Extracting...')
                def is_within_directory(directory, target):
                    abs_directory = os.path.abspath(directory)
                    abs_target = os.path.abspath(target)
                    prefix = os.path.commonprefix([abs_directory, abs_target])
                    return prefix == abs_directory

                def safe_extract(tar, path=".", members=None, numeric_owner=False):
                    for member in tar.getmembers():
                        member_path = os.path.join(path, member.name)
                        if not is_within_directory(path, member_path):
                            raise Exception("Attempted Path Traversal in Tar File")
                    tar.extractall(path, members, numeric_owner=numeric_owner)

                safe_extract(tfile, root)
        return os.path.join(path, '')

class MR(TarDataset):
    url = 'https://www.cs.cornell.edu/people/pabo/movie-review-data/rt-polaritydata.tar.gz'
    filename = 'rt-polaritydata.tar.gz'
    dirname = 'rt-polaritydata'

    def __init__(self, text_transform, label_transform, path=None, examples=None):
        self.text_transform = text_transform
        self.label_transform = label_transform
        # Tokenizer: clean the string, then split on whitespace
        self.tokenizer = get_tokenizer(lambda x: self.clean_str(x).split())

        if examples is None:
            path = self.download_or_unzip(os.getcwd() if path is None else path)
            self.examples = []
            # Load negative reviews (store them as lists of tokens)
            with open(os.path.join(path, 'rt-polarity.neg'), 'r', encoding='latin-1') as f:
                for line in f:
                    tokenized_text = self.tokenizer(line.strip())
                    self.examples.append((tokenized_text, 'negative'))
            # Load positive reviews
            with open(os.path.join(path, 'rt-polarity.pos'), 'r', encoding='latin-1') as f:
                for line in f:
                    tokenized_text = self.tokenizer(line.strip())
                    self.examples.append((tokenized_text, 'positive'))
        else:
            self.examples = examples

    @staticmethod
    def clean_str(string):
        string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
        string = re.sub(r"\'s", " \'s", string)
        string = re.sub(r"\'ve", " \'ve", string)
        string = re.sub(r"n\'t", " n\'t", string)
        string = re.sub(r"\'re", " \'re", string)
        string = re.sub(r"\'d", " \'d", string)
        string = re.sub(r"\'ll", " \'ll", string)
        string = re.sub(r",", " , ", string)
        string = re.sub(r"!", " ! ", string)
        string = re.sub(r"\(", " ( ", string)
        string = re.sub(r"\)", " ) ", string)
        string = re.sub(r"\?", " ? ", string)
        return re.sub(r"\s{2,}", " ", string).strip()

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        raw_text, raw_label = self.examples[idx]
        # raw_text is already a list of tokens, which is what text_transform expects
        processed_text = self.text_transform(raw_text)
        processed_label = self.label_transform(raw_label)
        return torch.tensor(processed_text, dtype=torch.long), torch.tensor(processed_label, dtype=torch.long)

    @classmethod
    def splits(cls, text_transform, label_transform, dev_ratio=0.1, shuffle=True, root='.', **kwargs):
        path = cls.download_or_unzip(root)
        full_dataset = cls(text_transform, label_transform, path=path, **kwargs)
        examples = full_dataset.examples
        if shuffle:
            random.shuffle(examples)
        split_idx = int(len(examples) * (1 - dev_ratio))
        return (cls(text_transform, label_transform, examples=examples[:split_idx]),
                cls(text_transform, label_transform, examples=examples[split_idx:]))
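
As a quick illustration of what the clean_str-based tokenizer produces, the sketch below (the input sentence is made up) shows how contractions and punctuation end up as separate tokens:

from mydatasets import MR

tokens = MR.clean_str("It doesn't work, does it?").split()
print(tokens)
# ['It', 'does', "n't", 'work', ',', 'does', 'it', '?']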

model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

class CNN_Text(nn.Module):
    def __init__(self, args):
        super(CNN_Text, self).__init__()
        self.args = args
        
        # Embedding layer
        self.embed = nn.Embedding(args.embed_num, args.embed_dim)
        
        # Parallel convolution layers, one per kernel size
        self.convs = nn.ModuleList([
            nn.Conv2d(
                in_channels=1, 
                out_channels=args.kernel_num,
                kernel_size=(k, args.embed_dim)  # kernel size: (height, width)
            ) for k in args.kernel_sizes
        ])
        
        # Classification head
        self.dropout = nn.Dropout(args.dropout)
        self.fc = nn.Linear(len(args.kernel_sizes) * args.kernel_num, args.class_num)
        
        # Freeze the word embeddings
        if args.static:
            self.embed.weight.requires_grad = False

    def forward(self, x):
        # Input shape: (batch_size, seq_len)
        x = self.embed(x)  # (batch_size, seq_len, embed_dim)
        
        # Add a channel dimension: (batch_size, 1, seq_len, embed_dim)
        x = x.unsqueeze(1)  
        
        # Apply each convolution
        conv_outputs = []
        for conv in self.convs:
            # Convolution: (batch, Co, seq_len-k+1, 1)
            conv_out = F.relu(conv(x))
            # Drop the trailing width dimension: (batch, Co, seq_len-k+1)
            conv_out = conv_out.squeeze(3)
            # Max-pool over time: (batch, Co)
            pooled = F.max_pool1d(conv_out, conv_out.size(2)).squeeze(2)
            conv_outputs.append(pooled)
        
        # Concatenate the pooled features
        x = torch.cat(conv_outputs, 1)
        x = self.dropout(x)
        logits = self.fc(x)
        return logits
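
A quick shape check of the model on random input (the hyperparameter values below are illustrative and mirror the argparse defaults set in main.py):

import torch
from argparse import Namespace
import model

# Illustrative settings; embed_num normally comes from the vocabulary size
args = Namespace(embed_num=5000, embed_dim=128, class_num=2,
                 kernel_num=100, kernel_sizes=[3, 4, 5],
                 dropout=0.5, static=False)

cnn = model.CNN_Text(args)
x = torch.randint(0, args.embed_num, (4, 20))  # batch of 4 sequences of length 20
print(cnn(x).shape)  # torch.Size([4, 2])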

Training

Run the command python main.py -epochs 30 (training with the default of 256 epochs takes too long, so reduce the number of epochs to 30 to shorten the training time).

Training process


After training you will get a directory structure like the one sketched below. The -snapshot option tells the program which checkpoint to load the model from; if it is not given, the model is trained from scratch.
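
An illustrative layout (the timestamp comes from the example run used in the test command below; yours will differ):

snapshot/
  2025-05-01_14-36-47/
    best_steps_3800.pt
    ...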

Testing

The test command is shown below; it asks the model to judge the sentiment of the sentence "Hello my dear , I love you so much .".

python main.py -predict="Hello my dear , I love you so much ." \
               -snapshot="./snapshot/2025-05-01_14-36-47/best_steps_3800.pt"

Result

The model classifies the sentence as positive.