NLP实战高手课学习笔记(15):文本分类实践2--模型训练与评估、改进与建议

741 阅读4分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第30天,点击查看活动详情

说明

本系列博客将记录自己学习的课程:NLP实战高手课,链接为:time.geekbang.org/course/intr… 本篇为27-28节的课程笔记,主要介绍Pytorch中使用torchtext进行文本分类的示例,本篇博客将介绍如何在上一篇博客记录的模型定义后进行训练和评估,最后导师给出了一些改进的尝试方向。

评估函数的建立

IMDB数据集是一个典型的2分类数据集。为此,我们使用准确率作为评估指标,该函数的定义如下:

def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc

模型训练

接下来,我们介绍如何定义模型的训练,其主要包含以下几个模块:

  • 设置模型为train模式
  • 从Dataloader中逐个batch的加载数据
  • 清零优化器的梯度
  • 进行模型的forward操作,得到输出预测
  • 计算输出预测与真实值之间的loss
  • loss进行反向传播将梯度进行回传
  • 优化器进行优化

按照上面的结构,我们将代码书写如下:

def train(model, iterator, optimizer, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.train()
    
    for batch in iterator:        
        optimizer.zero_grad()        
        text, text_lengths = batch.text        
        predictions = model(text, text_lengths).squeeze(1)        
        loss = criterion(predictions, batch.label)        
        acc = binary_accuracy(predictions, batch.label)        
        loss.backward()        
        optimizer.step()        
        epoch_loss += loss.item()
        epoch_acc += acc.item()     
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

在这里,我们还计算了训练的loss和训练集上的预测准确率作为参考。

模型的验证

模型的验证模块主要衡量一个训练后的模型能够在验证集上的表现如何。其整理代码结构与训练相仿,但必须要注意的是,进行模型验证前一定要把模型设置为验证模式,此时,模型在计算时的梯度将不会保留。

其对应的代码如下:

def evaluate(model, iterator, criterion):
    
    epoch_loss = 0
    epoch_acc = 0
    
    model.eval()
    
    with torch.no_grad():    
        for batch in iterator:
            text, text_lengths = batch.text            
            predictions = model(text, text_lengths).squeeze(1)      
            loss = criterion(predictions, batch.label)            
            acc = binary_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
        
    return epoch_loss / len(iterator), epoch_acc / len(iterator)

训练与评估

完成好以上的定义后,我们终于可以开始训练和评估了,我们设置训练的epoch为5,在每个epoch结束时对模型进行评估。

N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()
    
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')

打印输出如下:

Epoch: 01 | Epoch Time: 1m 50s
	Train Loss: 0.558 | Train Acc: 70.57%
	 Val. Loss: 0.444 |  Val. Acc: 79.54%
Epoch: 02 | Epoch Time: 1m 50s
	Train Loss: 0.393 | Train Acc: 82.70%
	 Val. Loss: 0.383 |  Val. Acc: 83.21%
Epoch: 03 | Epoch Time: 1m 50s
	Train Loss: 0.287 | Train Acc: 88.10%
	 Val. Loss: 0.300 |  Val. Acc: 88.08%
Epoch: 04 | Epoch Time: 1m 50s
	Train Loss: 0.161 | Train Acc: 94.26%
	 Val. Loss: 0.314 |  Val. Acc: 87.84%
Epoch: 05 | Epoch Time: 1m 50s
	Train Loss: 0.122 | Train Acc: 95.53%
	 Val. Loss: 0.367 |  Val. Acc: 87.17%

可以看到,随着训练的进行,模型在训练集上的loss稳定下降,准确性也在逐步提高;而在验证集上准确性提高到一定数值后就不在提升了,甚至有所下降,这可能是模型过拟合导致。

得到验证集上的最好表现的模型后,我们在测试集上进行一步测试:

model.load_state_dict(torch.load('model.pt'))

test_loss, test_acc = evaluate(model, test_iterator, criterion)

print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')

输出如下:

Test Loss: 0.321 | Test Acc: 87.01%

可以看到,我们的模型在测试集上依旧表现良好,达到了87%的准确性。

改进与建议

以上的代码虽然可以取得较好的结果,但由于该问题比较简单,所以还有很多地方可以改进,主要的改进点如下:

  • 超参数设置:如学习率的大小、变化方式、需不需要warmup;batch_size的大小;
  • 初始化:Embedding、weight的初始化方式不同将会带来不同的表现;
  • 数据集清洗:我们可以去除掉某些乱码或者HTML标签等;
  • 分词优化:可以尝试其他tokenizer改进分词效果。
  • ……

总结

本篇和上一篇博客为大家介绍了一个完整的NLP中的文本分类任务所涉及的各个模块,由于篇幅所限,未能展示全部代码,有兴趣的读者可以参考gitee.com/geektime-ge…