前边定义了网络结构juejin.cn/spost/72430…
现在我们来进行模型训练。
1 导入库
导入pytorch中的优化库
import torch.optim as optim
2 创建训练模型
在main方法下,添加训练模型。
- 遍历训练集,并从集合中取出对应的输入图像及标签
- 将图像和标签放到显存中,如果有的话,否则还是由CPU来计算
- 对输入图像进行正向和反向传播计算
- 获取过程中的交叉乘损失
- 获取计算过程中的梯度数据
- 并在每执行2000次计算后,打印出损失值的平均值
# 训练模型
# 获得交叉乘损失
criterion = nn.CrossEntropyLoss()
# 随机梯度下降
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
for epoch in range(10):
# 初始损失
runing_loss = 0.0
# 遍历训练集
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 将输入数据放入显存中
inputs, labels = inputs.to(device), labels.to(device)
# 权重参数梯度清零
optimizer.zero_grad()
# 正向和反向传播
outputs = net(inputs)
# 获取交叉损失
loss = criterion(outputs, labels)
# 获取梯度
loss.backward()
# 更新参数
optimizer.step()
# 显示损失值
runing_loss += loss.item()
if i % 2000 == 1999:
print('[%d, %5d] loss: %.3f' % (epoch+1, i+1, runing_loss/2000))
runing_loss = 0.0
启动程序执行后,可以看到最终的损失值平均值是0.439,但是在中间执行过程中,损失值是处于波动状态,有些值甚至比最终值还要低一点。
[1, 2000] loss: 2.281
[1, 4000] loss: 2.213
[1, 6000] loss: 1.959
[1, 8000] loss: 1.717
[1, 10000] loss: 1.487
……
[10, 2000] loss: 0.320
[10, 4000] loss: 0.363
[10, 6000] loss: 0.383
[10, 8000] loss: 0.392
[10, 10000] loss: 0.444
[10, 12000] loss: 0.439
3 创建测试模型
使用 torch.no_grad() 方法,即不使用梯度,对训练模型进行测试,并在测试完成后输出整体的测试准确率。
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print("Accuracy of the network on the 10000 test images: %d %%" % (100 * correct / total))
执行程序后,运行结果如下:
Accuracy of the network on the 10000 test images: 67 %
所有标签训练的整体结果为67%,可以看到成功率并不是很高。后边还需要继续进行优化。
那么当前算法下,每个分类的训练准确率大概是什么水平呢?
4 计算不同分类的准确率
计算各分类测试准确率时,仍然采用 torch.no_grad() 方法。
增加了 c = (predicted == labels).squeeze() 方法,将获得的结果降低到一维;
并在其下对同一个批次的4张图片,分别取出图像的标签和通过模型计算得到的标签进行比对:
- 比对一致时,正确数加1;
- 否则正确数不增加
最后按照不同的类型,将计算得到的准确率,分别输出的控制台。
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
images, labels = images.to(device), labels.to(device)
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
# 向量除维
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item() # 判断正确时增加1
class_total[label] += 1
for i in range(10):
print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
运行程序后,查看控制台输出
Accuracy of airplane : 75 %
Accuracy of automobile : 83 %
Accuracy of bird : 69 %
Accuracy of cat : 47 %
Accuracy of deer : 47 %
Accuracy of dog : 51 %
Accuracy of frog : 70 %
Accuracy of horse : 78 %
Accuracy of ship : 81 %
Accuracy of truck : 70 %
可以看出,不同类别,其识别准确率也不尽相同。