PyTorch Example: MNIST Handwritten Digit Recognition


The overall flow is as follows:

  • 0. Import dependencies
  • 0. PyTorch: hardware device handling
  • 1. Set the batch size
  • 2. Build the training data
  • 3. Build the test data
  • 4. Inspect the dataset
  • 5. Display handwritten digits from the dataset
  • *********
  • 6. Build the model: feature extraction
  • 7. Build the model: classification
  • 8. Build the model: forward pass
  • *********
  • 9. Create the loss function
  • 10. Set the learning rate
  • 11. Set the optimizer
  • *********
  • 12. Build the training function: get the number of batches and the training set size
  • 13. Build the training function: call the model to get predictions
  • 14. Build the training function: call the loss function
  • 15. Build the training function: zero gradients, backpropagate, update parameters
  • 16. Build the training function: compute the accuracy
  • 17. Build the training function: compute the loss
  • *********
  • 18. Build the test function: get the number of batches and the test set size
  • 19. Build the test function: call the model to get predictions
  • 20. Build the test function: call the loss function
  • 21. Build the test function: compute the accuracy and loss
  • *********
  • 22. Model training: set the number of epochs
  • 23. Model training: switch to training mode
  • 24. Model training: call the training function on the training set
  • 25. Model training: record training accuracy and training loss
  • 26. Model training: switch to evaluation mode
  • 27. Model training: call the test function on the test set
  • 28. Model training: record test accuracy and test loss
  • *********
  • 29. Results: accuracy over epochs
  • 30. Results: loss over epochs

0. Import dependencies

import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchinfo import summary
import warnings
from datetime import datetime

0. PyTorch: hardware device handling

# Use CUDA (i.e. the GPU) if this PyTorch install has CUDA available; otherwise fall back to the CPU
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

1. Set the batch size

# Batch size
batch_size = 32

2. Build the training data (note: `transform=torchvision.transforms.ToTensor()` raises an error later if the parentheses are omitted)

# Training data
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      # ToTensor() must be instantiated here; passing ToTensor without parentheses causes an error later
                                      transform=torchvision.transforms.ToTensor(),
                                      download=True)
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)
print(len(train_ds), 'len(train_ds)')
print(len(train_dl), 'len(train_dl)')


3. Build the test data

# Test data
test_ds = torchvision.datasets.MNIST('data',
                                     train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)

test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)

print(len(test_ds), 'len(test_ds)')
print(len(test_dl), 'len(test_dl)')


4. Inspect the dataset

def detail_data_loader(data_loader):

    # Length of the DataLoader (number of batches)
    print(len(data_loader), 'len(data_loader)')

    # len(DataLoader) equals the dataset size divided by the batch size
    print(len(train_ds) / batch_size, 'dataset size / batch size')

    # Print the first sample
    imgs, labels = next(iter(data_loader))
    print(imgs[0], labels[0], "imgs[0], labels[0]")


# Inspect the DataLoader (using the training set as an example)
detail_data_loader(train_dl)

1875 len(data_loader)
1875.0 dataset size / batch size
tensor([[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]]]) tensor(3) imgs[0], labels[0]
(a single 28×28 image tensor with pixel values in [0, 1]; output truncated)

5. Display handwritten digits from the dataset

# Grab one batch of images and labels
imgs, labels = next(iter(train_dl))


def display_data(imgs):
    plt.figure(figsize=(20, 5))

    # Take the first 24 samples
    for i, img in enumerate(imgs[:24]):
        if i == 0:
            # Print the shape of the first sample: [1, 28, 28]
            print(img.numpy().shape, 'shape before squeeze')
            # After squeezing out the channel dimension: [28, 28]
            print(np.squeeze(img.numpy()).shape, 'shape after squeeze')
        npimg = np.squeeze(img.numpy())
        plt.subplot(2, 12, i+1)
        plt.imshow(npimg, cmap=plt.cm.binary)
        plt.axis('off')
    plt.show()


# Visualize the batch
display_data(imgs)


6~8. Build the model

num_classes = 10


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extraction layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)

        # Classification layers
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)

    # Forward pass
    def forward(self, x):
        x = self.pool1(f.relu(self.conv1(x)))
        x = self.pool2(f.relu(self.conv2(x)))

        x = torch.flatten(x, start_dim=1)

        x = f.relu(self.fc1(x))

        x = self.fc2(x)

        return x


model = Model().to(device)
# Print a model summary
summary(model)
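The `in_features=1600` of `fc1` follows from the layer geometry: a 3×3 convolution with stride 1 and no padding shrinks each spatial side from n to n − 2, and a 2×2 max-pool halves it with floor division. A quick sketch of that arithmetic (the helper names here are illustrative, not part of the model):

```python
def conv_out(n, k=3):
    # valid convolution, stride 1: output side = n - k + 1
    return n - k + 1

def pool_out(n, k=2):
    # max-pool with stride == kernel size: floor division
    return n // k

n = 28                        # MNIST images are 28x28
n = pool_out(conv_out(n))     # conv1 -> 26, pool1 -> 13
n = pool_out(conv_out(n))     # conv2 -> 11, pool2 -> 5
flat = 64 * n * n             # 64 channels * 5 * 5 = 1600
print(n, flat)                # 5 1600
```

If the input size or kernel sizes change, the `nn.Linear(1600, 64)` layer has to change with them.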


9~11. Create the loss function, set the learning rate and optimizer

# Loss function
loss_fn = nn.CrossEntropyLoss()
# Learning rate: 0.01
learn_rate = 1e-2
# Optimizer
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)
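Note that `forward` returns raw logits with no softmax: `nn.CrossEntropyLoss` applies log-softmax and negative log-likelihood internally, so adding a softmax inside the model would be wrong. A minimal NumPy sketch of what the loss computes for a single sample (the `cross_entropy` helper and the toy logits are made up for illustration):

```python
import numpy as np

def cross_entropy(logits, label):
    # numerically stable log-softmax, then negative log-likelihood of the true class
    shifted = logits - logits.max()
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return -log_probs[label]

logits = np.array([2.0, 0.5, 0.1])   # made-up scores for a 3-class toy case
print(round(cross_entropy(logits, 0), 4))
```

The loss is smallest when the logit for the true class dominates, which is exactly what training pushes toward.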

12~17. Build the training function

# Training function
def train(dataloader, model, loss_fn, optimizer):
    print('Train')
    # Dataset size: len(dataloader.dataset) = 60000
    size = len(dataloader.dataset)
    # Number of batches: len(dataloader) = 1875 = 60000 / 32
    num_batches = len(dataloader)

    # Running totals for loss and number of correct predictions
    train_loss, train_right_count = 0, 0

    for x, y in dataloader:
        # Move the data to the GPU or CPU
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        # Zero the gradients
        optimizer.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the parameters
        optimizer.step()

        # Track accuracy
        # pred has shape [batch, num_classes]
        # pred.argmax(1) takes the index of the maximum along dimension 1,
        # i.e. the most likely of the 10 classes (digits 0-9) -- the predicted digit
        # pred.argmax(1) == y: the prediction is correct when that index matches the label
        # .type(torch.float) turns correct predictions into 1.0; .sum() counts them
        # .item() converts the resulting tensor into a plain Python scalar
        train_right_count += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    # Accuracy = number of correct predictions / dataset size
    train_acc_rate = train_right_count / size
    # Average loss per batch
    train_loss /= num_batches

    return train_acc_rate, train_loss
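The accuracy bookkeeping above can be seen on a toy batch; this sketch uses NumPy instead of torch so it stands alone, and the logits are invented for illustration:

```python
import numpy as np

# toy batch: 4 samples, 3 classes of made-up logits
pred = np.array([[0.1, 2.0, 0.3],
                 [1.5, 0.2, 0.1],
                 [0.2, 0.1, 3.0],
                 [0.9, 1.1, 0.4]])
y = np.array([1, 0, 2, 0])            # the last label is deliberately missed

# same idea as (pred.argmax(1) == y).type(torch.float).sum().item()
right = float((pred.argmax(axis=1) == y).sum())
print(right, right / len(y))          # 3.0 0.75
```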

18~21. Build the test function

# Test function
def test(dataloader, model, loss_fn):
    print("Test")
    # Test set size
    size = len(dataloader.dataset)

    # Number of test batches
    num_batches = len(dataloader)

    # Test loss and number of correct predictions
    test_loss, test_right_count = 0, 0

    # Disable gradient tracking: avoids interfering with training gradients and saves memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            pred = model(imgs)
            loss = loss_fn(pred, target)

            test_loss += loss.item()
            test_right_count += (pred.argmax(1) == target).type(torch.float).sum().item()

    # Average loss per batch
    test_loss /= num_batches
    # Accuracy
    test_right_rate = test_right_count / size

    return test_right_rate, test_loss

22~28. Model training

epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    # Switch to training mode
    model.train()
    # Call the training function to get the accuracy and loss
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    # Switch to evaluation mode
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))

print('Train Complete')

29. Results: accuracy over epochs

# Suppress warnings
warnings.filterwarnings('ignore')
# Display Chinese labels correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
# Display the minus sign correctly
plt.rcParams['axes.unicode_minus'] = False
# Figure resolution
plt.rcParams['figure.dpi'] = 100

# Current time (used to timestamp the plot)
current_time = datetime.now()

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)


plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)

30. Results: loss over epochs

plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()


Full code

import torch
import torch.nn as nn
import torch.nn.functional as f
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchinfo import summary
import warnings
from datetime import datetime


# Set the hardware device
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)


# Batch size
batch_size = 32

# Training data
train_ds = torchvision.datasets.MNIST('data',
                                      train=True,
                                      # ToTensor() must be instantiated here; passing ToTensor without parentheses causes an error later
                                      transform=torchvision.transforms.ToTensor(),
                                      download=True)
train_dl = torch.utils.data.DataLoader(train_ds,
                                       batch_size=batch_size,
                                       shuffle=True)
print(len(train_ds), 'len(train_ds)')
print(len(train_dl), 'len(train_dl)')


# Test data
test_ds = torchvision.datasets.MNIST('data',
                                     train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)

test_dl = torch.utils.data.DataLoader(test_ds,
                                      batch_size=batch_size)

print(len(test_ds), 'len(test_ds)')
print(len(test_dl), 'len(test_dl)')


def detail_data_loader(data_loader):

    # Length of the DataLoader (number of batches)
    print(len(data_loader), 'len(data_loader)')

    # len(DataLoader) equals the dataset size divided by the batch size
    print(len(train_ds) / batch_size, 'dataset size / batch size')

    # Print the first sample
    imgs, labels = next(iter(data_loader))
    print(imgs[0], labels[0], "imgs[0], labels[0]")


# Inspect the DataLoader (using the training set as an example)
detail_data_loader(train_dl)

# Grab one batch of images and labels
imgs, labels = next(iter(train_dl))


def display_data(imgs):
    plt.figure(figsize=(20, 5))

    # Take the first 24 samples
    for i, img in enumerate(imgs[:24]):
        if i == 0:
            # Print the shape of the first sample: [1, 28, 28]
            print(img.numpy().shape, 'shape before squeeze')
            # After squeezing out the channel dimension: [28, 28]
            print(np.squeeze(img.numpy()).shape, 'shape after squeeze')
        npimg = np.squeeze(img.numpy())
        plt.subplot(2, 12, i+1)
        plt.imshow(npimg, cmap=plt.cm.binary)
        plt.axis('off')
    plt.show()


# Visualize the batch
display_data(imgs)


num_classes = 10


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # Feature extraction layers
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(2)

        # Classification layers
        self.fc1 = nn.Linear(1600, 64)
        self.fc2 = nn.Linear(64, num_classes)

    # Forward pass
    def forward(self, x):
        x = self.pool1(f.relu(self.conv1(x)))
        x = self.pool2(f.relu(self.conv2(x)))

        x = torch.flatten(x, start_dim=1)

        x = f.relu(self.fc1(x))

        x = self.fc2(x)

        return x


model = Model().to(device)
# Print a model summary
summary(model)

# Loss function
loss_fn = nn.CrossEntropyLoss()
# Learning rate: 0.01
learn_rate = 1e-2
# Optimizer
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)


# Training function
def train(dataloader, model, loss_fn, optimizer):
    print('Train')
    # Dataset size: len(dataloader.dataset) = 60000
    size = len(dataloader.dataset)
    # Number of batches: len(dataloader) = 1875 = 60000 / 32
    num_batches = len(dataloader)

    # Running totals for loss and number of correct predictions
    train_loss, train_right_count = 0, 0

    for x, y in dataloader:
        # Move the data to the GPU or CPU
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        # Backpropagation
        # Zero the gradients
        optimizer.zero_grad()
        # Compute the gradients
        loss.backward()
        # Update the parameters
        optimizer.step()

        # Track accuracy
        # pred has shape [batch, num_classes]
        # pred.argmax(1) takes the index of the maximum along dimension 1,
        # i.e. the most likely of the 10 classes (digits 0-9) -- the predicted digit
        # pred.argmax(1) == y: the prediction is correct when that index matches the label
        # .type(torch.float) turns correct predictions into 1.0; .sum() counts them
        # .item() converts the resulting tensor into a plain Python scalar
        train_right_count += (pred.argmax(1) == y).type(torch.float).sum().item()
        train_loss += loss.item()

    # Accuracy = number of correct predictions / dataset size
    train_acc_rate = train_right_count / size
    # Average loss per batch
    train_loss /= num_batches

    return train_acc_rate, train_loss


# Test function
def test(dataloader, model, loss_fn):
    print("Test")
    # Test set size
    size = len(dataloader.dataset)

    # Number of test batches
    num_batches = len(dataloader)

    # Test loss and number of correct predictions
    test_loss, test_right_count = 0, 0

    # Disable gradient tracking: avoids interfering with training gradients and saves memory
    with torch.no_grad():
        for imgs, target in dataloader:
            imgs, target = imgs.to(device), target.to(device)
            # Compute the loss
            pred = model(imgs)
            loss = loss_fn(pred, target)

            test_loss += loss.item()
            test_right_count += (pred.argmax(1) == target).type(torch.float).sum().item()

    # Average loss per batch
    test_loss /= num_batches
    # Accuracy
    test_right_rate = test_right_count / size

    return test_right_rate, test_loss


epochs = 5
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    # Switch to training mode
    model.train()
    # Call the training function to get the accuracy and loss
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)

    # Switch to evaluation mode
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'
    print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss))

print('Train Complete')

# Suppress warnings
warnings.filterwarnings('ignore')
# Display Chinese labels correctly
plt.rcParams['font.sans-serif'] = ['SimHei']
# Display the minus sign correctly
plt.rcParams['axes.unicode_minus'] = False
# Figure resolution
plt.rcParams['figure.dpi'] = 100

# Current time (used to timestamp the plot)
current_time = datetime.now()

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)


plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)


plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()