- 🍨 This post is a learning-record blog from the 🔗 365天深度学习训练营 training camp
- 🍖 Original author: K同学啊
The overall workflow is as follows:
- 0. Import dependencies
- 0. PyTorch device setup
- 1. Set the batch size
- 2. Build the training data
- 3. Build the test data
- 4. Inspect the dataset
- 5. Display handwritten digits from the dataset
- *********
- 6. Build the model: feature extraction
- 7. Build the model: classification
- 8. Build the model: forward pass
- *********
- 9. Create the loss function
- 10. Set the learning rate
- 11. Set up the optimizer
- *********
- 12. Training function: get the number of batches and the training-set size
- 13. Training function: call the model to get predictions
- 14. Training function: call the loss function
- 15. Training function: zero gradients, backpropagate, update parameters
- 16. Training function: compute the accuracy
- 17. Training function: compute the loss
- *********
- 18. Test function: get the number of batches and the test-set size
- 19. Test function: call the model to get predictions
- 20. Test function: call the loss function
- 21. Test function: compute the accuracy and loss
- *********
- 22. Model training: set the number of epochs
- 23. Model training: switch to training mode
- 24. Model training: run the training function on the training set
- 25. Model training: record training accuracy and loss
- 26. Model training: switch to evaluation mode
- 27. Model training: run the test function on the test set
- 28. Model training: record test accuracy and loss
- *********
- 29. Results: accuracy over epochs
- 30. Results: loss over epochs
0. Import dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchinfo import summary
import warnings
from datetime import datetime
0. PyTorch device setup
# Use CUDA (i.e. the GPU) if the current PyTorch environment has CUDA available, otherwise use the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)
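To double-check which device was picked up, PyTorch can report the GPU model via torch.cuda.get_device_name; a minimal sketch (an addition, not part of the original post):

if torch.cuda.is_available():
    # Prints the model name of GPU 0, to verify the expected card is in use
    print(torch.cuda.get_device_name(0))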
1. Set the batch size
# Batch size
batch_size = 32
2. Build the training data (note: omitting the parentheses in transform=torchvision.transforms.ToTensor() causes an error later)
# Training data
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      # ToTensor() must keep its parentheses; writing ToTensor raises an error later on
                                      transform=torchvision.transforms.ToTensor(),
                                      download=True)
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)
print(len(train_ds), 'len(train_ds)')
print(len(train_dl), 'len(train_dl)')
3. Build the test data
# Test data
test_ds = torchvision.datasets.MNIST('data',
                                     train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)
print(len(test_ds), 'len(test_ds)')
print(len(test_dl), 'len(test_dl)')
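Note why len(test_dl) prints 313: with the DataLoader's default drop_last=False the final, smaller batch is kept, so the number of batches is the ceiling of the dataset size over the batch size. A quick sketch of the arithmetic (the numbers assume the standard 60000/10000 MNIST split):

import math

print(math.ceil(60000 / 32))  # 1875: 60000 divides evenly by 32, so no partial batch
print(math.ceil(10000 / 32))  # 313: 10000 / 32 = 312.5, and the partial batch counts as one
# With DataLoader(..., drop_last=True) the partial batch would be dropped, giving 312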
4. Inspect the dataset
def detail_data_loader(data_loader):
    # Length of the DataLoader
    print(len(data_loader), 'length of the DataLoader')
    # The DataLoader length equals the dataset size divided by the batch size (rounded up when it does not divide evenly)
    print(len(train_ds) / batch_size, 'training-set size divided by batch size')
    # Print the first sample
    imgs, labels = next(iter(data_loader))
    print(imgs[0], labels[0], "imgs[0], labels[0]")

# Inspect a DataLoader (using the training set as an example)
detail_data_loader(train_dl)
Output (tensor dump truncated): 1875 length of the DataLoader; 1875.0 training-set size divided by batch size; imgs[0] is a 1×28×28 tensor of pixel values in [0, 1], and its label is tensor(3).
5. Display handwritten digits from the dataset
# Fetch one batch of images and labels
imgs, labels = next(iter(train_dl))

def display_data(imgs):
    plt.figure(figsize=(20, 5))
    # Take the first 24 samples
    for i, img in enumerate(imgs[:24]):
        if i == 0:
            # Print the shape of the first sample: (1, 28, 28)
            print(img.numpy().shape, 'shape before')
            # After squeezing out the channel dimension: (28, 28)
            print(np.squeeze(img.numpy()).shape, 'shape after')
        npimg = np.squeeze(img.numpy())
        plt.subplot(2, 12, i + 1)
        plt.imshow(npimg, cmap=plt.cm.binary)
        plt.axis('off')
    plt.show()

# Visualize the data
display_data(imgs)
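As an alternative to the manual subplot loop, torchvision ships a make_grid helper that tiles a batch into a single image; a sketch under the same [batch, 1, 28, 28] input assumption (display_grid is a hypothetical helper, not from the original post):

from torchvision.utils import make_grid

def display_grid(imgs):
    # Tile the first 24 digits into a 2x12 grid; make_grid returns a [3, H, W] tensor
    grid = make_grid(imgs[:24], nrow=12)
    plt.figure(figsize=(20, 5))
    plt.imshow(grid.permute(1, 2, 0))  # channels-last layout for imshow
    plt.axis('off')
    plt.show()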
6~8. Build the model
num_classes = 10

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature-extraction network
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        # Classification network
        # Spatial sizes: 28 -> conv1 -> 26 -> pool1 -> 13 -> conv2 -> 11 -> pool2 -> 5,
        # so the flattened feature size is 64 * 5 * 5 = 1600
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)

    # Forward pass
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Model().to(device)
# Print the model summary
summary(model)
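Called with just the model, summary only lists the layers and parameter counts. torchinfo can additionally trace per-layer output shapes when given an input_size; a sketch, assuming a dummy batch of one 1×28×28 image:

# Trace the model with a dummy input so the summary also reports output shapes
summary(model, input_size=(1, 1, 28, 28))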
9~11. Create the loss function, set the learning rate and the optimizer
# Loss function
loss_fn = nn.CrossEntropyLoss()
# Learning rate, 0.01
learn_rate = 1e-2
# Optimizer
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)
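The post uses plain SGD at lr=0.01. A common variant (an aside, not used in this post) adds momentum so the updates accumulate a running direction:

# Hypothetical variant: SGD with momentum 0.9, not used in this post
opt_with_momentum = torch.optim.SGD(model.parameters(), lr=learn_rate, momentum=0.9)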
12~17. Build the training function
# Training function
def train(dataloader, model, loss_fn, optimizer):
    print('Train')
    # Dataset size: len(dataloader.dataset) = 60000
    size = len(dataloader.dataset)
    # Number of batches: len(dataloader) = 1875 = 60000 / 32
    num_batches = len(dataloader)
    # Accumulators for the loss and the number of correct predictions
    train_loss, train_right_count = 0, 0
    for x, y in dataloader:
        # Move the data to the GPU or CPU
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        # Backpropagation:
        # zero the gradients,
        optimizer.zero_grad()
        # compute new gradients,
        loss.backward()
        # and update the parameters
        optimizer.step()
        # Track accuracy:
        # pred has shape [batch, num_classes]; pred.argmax(1) takes, along dimension 1,
        # the index of the largest of the 10 class scores (digits 0-9), i.e. the predicted digit.
        # pred.argmax(1) == y: a prediction is correct when that index matches the label.
        # .type(torch.float) turns each correct prediction into 1.0, and .sum() adds them up.
        # .item() converts the resulting tensor into a plain Python scalar (a single number).
        train_right_count += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    # Accuracy = number of correct predictions / dataset size
    train_acc_rate = train_right_count / size
    # Average loss per batch
    train_loss /= num_batches
    return train_acc_rate, train_loss
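To make the accuracy bookkeeping concrete, here is a toy run of the same argmax comparison on a fabricated batch of three 10-class predictions (illustrative values only, not real model output):

# Fabricated logits for 3 samples over 10 classes
toy_pred = torch.tensor([[0.1, 2.0, 0.3, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],   # argmax -> 1
                         [0.0, 0.1, 0.2, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],   # argmax -> 3
                         [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]])  # argmax -> 0
toy_y = torch.tensor([1, 3, 7])
# Two of the three argmax indices match their labels, so this prints 2.0
print((toy_pred.argmax(1) == toy_y).type(torch.float).sum().item())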
18~21. Build the test function
# Test function
def test(dataloader, model, loss_fn):
    print("Test")
    # Test-set size
    size = len(dataloader.dataset)
    # Number of test batches
    num_batches = len(dataloader)
    # Accumulators for the test loss and the number of correct predictions
    test_loss, test_right_count = 0, 0
    # Disable gradient tracking: gradients are not needed for evaluation,
    # and skipping them keeps the training gradients clean and saves memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            pred = model(imgs)
            loss = loss_fn(pred, target)
            test_loss += loss.item()
            test_right_count += (pred.argmax(1) == target).type(torch.float).sum().item()
    # Average loss per batch
    test_loss /= num_batches
    # Accuracy
    test_right_rate = test_right_count / size
    return test_right_rate, test_loss
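On PyTorch 1.9 and later, torch.inference_mode() is a drop-in alternative to torch.no_grad() for pure evaluation and can be slightly faster; a sketch of the same loop skeleton (an aside, not from the original post):

def predict_only(dataloader, model):
    # Same skeleton as test(), with inference_mode instead of no_grad
    with torch.inference_mode():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            pred = model(imgs)  # no autograd state is recorded here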
22~28. Model training
epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    # Switch to training mode
    model.train()
    # Run the training function to get the accuracy and loss
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    # Switch to evaluation mode
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print('Train Complete')
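Once training finishes, the usual next step is to persist the weights with torch.save on the state dict (a sketch; the filename is an assumption):

# Save the trained weights (hypothetical filename)
torch.save(model.state_dict(), 'mnist_cnn.pth')
# To restore later into a fresh instance:
# model = Model().to(device)
# model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))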
29. Results: accuracy over epochs
# Suppress warnings
warnings.filterwarnings('ignore')
# Display Chinese labels correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
# Display minus signs correctly
plt.rcParams['axes.unicode_minus'] = False
# Resolution
plt.rcParams['figure.dpi'] = 100
# Current time
current_time = datetime.now()

epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)
30. Results: loss over epochs
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
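To keep a copy of the chart on disk, call plt.savefig before plt.show (after show, the figure may be cleared); a sketch reusing the current_time computed above, with a hypothetical filename format:

# Place before plt.show(): save the figure under a timestamped, hypothetical filename
plt.savefig(f'training_curves_{current_time:%Y%m%d_%H%M%S}.png', bbox_inches='tight')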
Full code
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchinfo import summary
import warnings
from datetime import datetime

# Set up the hardware device
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

# Batch size
batch_size = 32

# Training data
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      # ToTensor() must keep its parentheses; writing ToTensor raises an error later on
                                      transform=torchvision.transforms.ToTensor(),
                                      download=True)
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)
print(len(train_ds), 'len(train_ds)')
print(len(train_dl), 'len(train_dl)')

# Test data
test_ds = torchvision.datasets.MNIST('data',
                                     train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)
print(len(test_ds), 'len(test_ds)')
print(len(test_dl), 'len(test_dl)')

def detail_data_loader(data_loader):
    # Length of the DataLoader
    print(len(data_loader), 'length of the DataLoader')
    # The DataLoader length equals the dataset size divided by the batch size (rounded up when it does not divide evenly)
    print(len(train_ds) / batch_size, 'training-set size divided by batch size')
    # Print the first sample
    imgs, labels = next(iter(data_loader))
    print(imgs[0], labels[0], "imgs[0], labels[0]")

# Inspect a DataLoader (using the training set as an example)
detail_data_loader(train_dl)

# Fetch one batch of images and labels
imgs, labels = next(iter(train_dl))

def display_data(imgs):
    plt.figure(figsize=(20, 5))
    # Take the first 24 samples
    for i, img in enumerate(imgs[:24]):
        if i == 0:
            # Print the shape of the first sample: (1, 28, 28)
            print(img.numpy().shape, 'shape before')
            # After squeezing out the channel dimension: (28, 28)
            print(np.squeeze(img.numpy()).shape, 'shape after')
        npimg = np.squeeze(img.numpy())
        plt.subplot(2, 12, i + 1)
        plt.imshow(npimg, cmap=plt.cm.binary)
        plt.axis('off')
    plt.show()

# Visualize the data
display_data(imgs)

num_classes = 10

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature-extraction network
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)
        # Classification network
        # Spatial sizes: 28 -> conv1 -> 26 -> pool1 -> 13 -> conv2 -> 11 -> pool2 -> 5,
        # so the flattened feature size is 64 * 5 * 5 = 1600
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)

    # Forward pass
    def forward(self, x):
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = Model().to(device)
# Print the model summary
summary(model)

# Loss function
loss_fn = nn.CrossEntropyLoss()
# Learning rate, 0.01
learn_rate = 1e-2
# Optimizer
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)

# Training function
def train(dataloader, model, loss_fn, optimizer):
    print('Train')
    # Dataset size: len(dataloader.dataset) = 60000
    size = len(dataloader.dataset)
    # Number of batches: len(dataloader) = 1875 = 60000 / 32
    num_batches = len(dataloader)
    # Accumulators for the loss and the number of correct predictions
    train_loss, train_right_count = 0, 0
    for x, y in dataloader:
        # Move the data to the GPU or CPU
        x, y = x.to(device), y.to(device)
        pred = model(x)
        loss = loss_fn(pred, y)
        # Backpropagation: zero the gradients,
        optimizer.zero_grad()
        # compute new gradients,
        loss.backward()
        # and update the parameters
        optimizer.step()
        # Track accuracy:
        # pred has shape [batch, num_classes]; pred.argmax(1) takes, along dimension 1,
        # the index of the largest of the 10 class scores (digits 0-9), i.e. the predicted digit.
        # pred.argmax(1) == y: a prediction is correct when that index matches the label.
        # .type(torch.float) turns each correct prediction into 1.0, and .sum() adds them up.
        # .item() converts the resulting tensor into a plain Python scalar (a single number).
        train_right_count += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()
    # Accuracy = number of correct predictions / dataset size
    train_acc_rate = train_right_count / size
    # Average loss per batch
    train_loss /= num_batches
    return train_acc_rate, train_loss

# Test function
def test(dataloader, model, loss_fn):
    print("Test")
    # Test-set size
    size = len(dataloader.dataset)
    # Number of test batches
    num_batches = len(dataloader)
    # Accumulators for the test loss and the number of correct predictions
    test_loss, test_right_count = 0, 0
    # Disable gradient tracking: gradients are not needed for evaluation,
    # and skipping them keeps the training gradients clean and saves memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            pred = model(imgs)
            loss = loss_fn(pred, target)
            test_loss += loss.item()
            test_right_count += (pred.argmax(1) == target).type(torch.float).sum().item()
    # Average loss per batch
    test_loss /= num_batches
    # Accuracy
    test_right_rate = test_right_count / size
    return test_right_rate, test_loss

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []
for epoch in range(epochs):
    # Switch to training mode
    model.train()
    # Run the training function to get the accuracy and loss
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    # Switch to evaluation mode
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)
    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss, epoch_test_acc * 100, epoch_test_loss))
print('Train Complete')

# Suppress warnings
warnings.filterwarnings('ignore')
# Display Chinese labels correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
# Display minus signs correctly
plt.rcParams['axes.unicode_minus'] = False
# Resolution
plt.rcParams['figure.dpi'] = 100
# Current time
current_time = datetime.now()

epochs_range = range(epochs)
plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()