PyTorch项目实战11——CIFAR10模型训练和测试

247 阅读3分钟

前边定义了网络结构juejin.cn/spost/72430…

现在我们来进行模型训练。

1 导入库

导入pytorch中的优化库

import torch.optim as optim

2 创建训练模型

在main方法下,添加训练模型。

  • 遍历训练集,并从集合中取出对应的输入图像及标签
  • 将图像和标签放到显存中,如果有的话,否则还是由CPU来计算
  • 对输入图像进行正向和反向传播计算
  • 获取过程中的交叉乘损失
  • 获取计算过程中的梯度数据
  • 并在每执行2000次计算后,打印出损失值的平均值
# 训练模型
# 获得交叉乘损失
criterion = nn.CrossEntropyLoss()
# 随机梯度下降
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    # 初始损失
    runing_loss = 0.0
    # 遍历训练集
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        # 将输入数据放入显存中
        inputs, labels = inputs.to(device), labels.to(device)
        # 权重参数梯度清零
        optimizer.zero_grad()
        # 正向和反向传播
        outputs = net(inputs)
        # 获取交叉损失
        loss = criterion(outputs, labels)
        # 获取梯度
        loss.backward()
        # 更新参数
        optimizer.step()
        # 显示损失值
        runing_loss += loss.item()
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, runing_loss/2000))
            runing_loss = 0.0

启动程序执行后,可以看到最终的损失值平均值是0.439,但是在中间执行过程中,损失值是处于波动状态,有些值甚至比最终值还要低一点。

[1, 2000] loss: 2.281

[1, 4000] loss: 2.213

[1, 6000] loss: 1.959

[1, 8000] loss: 1.717

[1, 10000] loss: 1.487

……

[10, 2000] loss: 0.320

[10, 4000] loss: 0.363

[10, 6000] loss: 0.383

[10, 8000] loss: 0.392

[10, 10000] loss: 0.444

[10, 12000] loss: 0.439

3 创建测试模型

使用 torch.no_grad() 方法,即不使用梯度,对训练模型进行测试,并在测试完成后输出整体的测试准确率。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print("Accuracy of the network on the 10000 test images: %d %%" % (100 * correct / total))

执行程序后,运行结果如下:

Accuracy of the network on the 10000 test images: 67 %

所有标签训练的整体结果为67%,可以看到成功率并不是很高。后边还需要继续进行优化。

那么当前算法下,每个分类的训练准确率大概是什么水平呢?

4 计算不同分类的准确率

计算各分类测试准确率时,仍然采用 torch.no_grad() 方法。

增加了 c = (predicted == labels).squeeze() 方法,将获得的结果降低到一维;

并在其下对同一个批次的4张图片,分别取出图像的标签和通过模型计算得到的标签进行比对:

  • 比对一致时,正确数加1;
  • 否则正确数不增加

最后按照不同的类型,将计算得到的准确率,分别输出的控制台。

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        # 向量除维
        c = (predicted == labels).squeeze()
        for i in range(4):
            label = labels[i]
            class_correct[label] += c[i].item()  # 判断正确时增加1
            class_total[label] += 1

for i in range(10):
    print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))

运行程序后,查看控制台输出

Accuracy of airplane : 75 %

Accuracy of automobile : 83 %

Accuracy of bird : 69 %

Accuracy of cat : 47 %

Accuracy of deer : 47 %

Accuracy of dog : 51 %

Accuracy of frog : 70 %

Accuracy of horse : 78 %

Accuracy of ship : 81 %

Accuracy of truck : 70 %

可以看出,不同类别,其识别准确率也不尽相同。