Pytorch手写数字识别

225 阅读2分钟

Pytorch手写数字识别

内容搬运参考文章部分[1]以及其他文章的内容,主要梳理深度学习的流程

步骤梳理

  1. 准备数据集(本次使用MINST数据集)和Dataloader
  2. 构建神经网络模型(CNN卷积神经网络)
  3. 准备损失函数(交叉熵损失函数)和优化器(SGD)
  4. 训练神经网络
  5. 训练成果的保存与读取
  6. 测试神经网络

0.导入相关依赖

import torch
import torchvision # 计算机视觉相关
from torch.utils.data import DataLoader
from models import RLS_NN # 导入模型

1.准备数据集

  • 导入数据
# 2.1 训练集数据
train_data = torchvision.datasets.MNIST(
  root='../data', # 数据集保存路径
  download=True, # 表示需要从网络上下载
  train=True, # 这是训练集
  transform=torchvision.transforms.ToTensor() # 将数据转换成tensor格式
)
# 2.2 测试集数据
test_data = torchvision.datasets.MNIST(
  root='../data',
  download=True,
  train=False,
  transform=torchvision.transforms.ToTensor()
)
  • 创建loader
train_dataloader = DataLoader(train_data, batch_size=64) # batch_size表示一次训练几个数据
test_dataloader = DataLoader(test_data, batch_size=64)

2.构建神经网络模型

import torch.nn as nn

class RLS_NN(nn.Module):
  def __init__(self) -> None:
    super(RLS_NN, self).__init__()
    self.net = nn.Sequential(
      nn.Conv2d(
        in_channels=1, out_channels=16, # 通道数
        kernel_size=(3,3), # 卷积核大小
        stride=(1,1), # 步长
        padding=1
      ), # 16*28*28
      nn.MaxPool2d(kernel_size=2), # 16*14*14
      nn.Conv2d(16, 32, 3, 1, 1), # 32*14*14
      nn.MaxPool2d(2), # 32*7*7
      nn.Flatten(),
      nn.Linear(32*7*7, 16), # 全连接层
      nn.ReLU(),
      nn.Linear(16, 10) # 10分类任务
    )

  def forward(self, x):
    return self.net(x)

3.准备损失函数

loss_func = torch.nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.SGD(models.parameters(), lr=0.1)

4.训练神经网络

cnt_epochs = 10

for i in range(cnt_epochs):
  for imgs, labels in train_dataloader:
    imgs = imgs.to(device)
    labels = labels.to(device)
    outputs = models(imgs)
    loss = loss_func(outputs, labels)
    optimizer.zero_grad() # 清空优化器的梯度
    loss.backward()
    optimizer.step()
  
  total_loss = 0
  with torch.no_grad(): # 表明不需要自动求导
    for imgs, labels in test_dataloader:
      imgs = imgs.to(device)
      labels = labels.to(device)
      outputs = models(imgs)
      loss = loss_func(outputs, labels)
      loss.to('cpu')
      total_loss += loss
  print('第{}次打印损失:{}'.format(i, total_loss))
  
 output:
第0次打印损失:21.644510269165041次打印损失:12.5033130645751952次打印损失:10.9500980377197273次打印损失:10.0877981185913094次打印损失:9.6355571746826175次打印损失:8.8127346038818366次打印损失:8.5556650161743167次打印损失:8.1658468246459968次打印损失:8.1973085403442389次打印损失:8.533809661865234

5.模型保存

torch.save(models, 'my_cnn.nn')

6.测试模型精度

models = torch.load('./my_cnn.nn')
models.to('cpu') # 模型训练的时候用了gpu
models.eval() # 推理模式
test_data = torchvision.datasets.MNIST(
  root='../data',
  download=False,
  train=False,
  transform=torchvision.transforms.ToTensor()
)
test_loader = DataLoader(test_data, batch_size=64)
total_correct = 0
for imgs, labels in test_loader:
  outputs = models(imgs)
  pred = outputs.argmax(dim=1) # 将概率最大的位置索引作为预测值
  corret = pred.eq(labels).sum().float().item()
  total_correct += corret

total = len(test_loader.dataset)  
acc = total_correct / total
print("test acc: ", acc) # 0.9831

7.其他补充

7.1 绘制手写数字图片 随机挑选6张图片进行打印

def plot_image(x, label, name):
    """
    显示6张手写数字图片以及对应的数字标签
    """
    for i in range(6):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        plt.imshow(x[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}: {}".format(name, label[i].item()))
        plt.xticks([])
        plt.yticks([])
    plt.show()
 
 # ...参照测试代码
 plot_image(imgs, pred, 'image sample')

image.png

参考文章

  1. 手把手教你用pytorch实现手写数字识别
  2. 一文详解用PyTorch解决手写数字问题
  3. pytorch 实现 minist手写数据集(cpu/gpu)版本