Pytorch手写数字识别
内容搬运参考文章部分[1]以及其他文章的内容,主要梳理深度学习的流程
步骤梳理
- 准备数据集(本次使用MINST数据集)和Dataloader
- 构建神经网络模型(CNN卷积神经网络)
- 准备损失函数(交叉熵损失函数)和优化器(SGD)
- 训练神经网络
- 训练成果的保存与读取
- 测试神经网络
0.导入相关依赖
import torch
import torchvision # 计算机视觉相关
from torch.utils.data import DataLoader
from models import RLS_NN # 导入模型
1.准备数据集
- 导入数据
# 2.1 训练集数据
train_data = torchvision.datasets.MNIST(
root='../data', # 数据集保存路径
download=True, # 表示需要从网络上下载
train=True, # 这是训练集
transform=torchvision.transforms.ToTensor() # 将数据转换成tensor格式
)
# 2.2 测试集数据
test_data = torchvision.datasets.MNIST(
root='../data',
download=True,
train=False,
transform=torchvision.transforms.ToTensor()
)
- 创建loader
train_dataloader = DataLoader(train_data, batch_size=64) # batch_size表示一次训练几个数据
test_dataloader = DataLoader(test_data, batch_size=64)
2.构建神经网络模型
import torch.nn as nn
class RLS_NN(nn.Module):
def __init__(self) -> None:
super(RLS_NN, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(
in_channels=1, out_channels=16, # 通道数
kernel_size=(3,3), # 卷积核大小
stride=(1,1), # 步长
padding=1
), # 16*28*28
nn.MaxPool2d(kernel_size=2), # 16*14*14
nn.Conv2d(16, 32, 3, 1, 1), # 32*14*14
nn.MaxPool2d(2), # 32*7*7
nn.Flatten(),
nn.Linear(32*7*7, 16), # 全连接层
nn.ReLU(),
nn.Linear(16, 10) # 10分类任务
)
def forward(self, x):
return self.net(x)
3.准备损失函数
loss_func = torch.nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.SGD(models.parameters(), lr=0.1)
4.训练神经网络
cnt_epochs = 10
for i in range(cnt_epochs):
for imgs, labels in train_dataloader:
imgs = imgs.to(device)
labels = labels.to(device)
outputs = models(imgs)
loss = loss_func(outputs, labels)
optimizer.zero_grad() # 清空优化器的梯度
loss.backward()
optimizer.step()
total_loss = 0
with torch.no_grad(): # 表明不需要自动求导
for imgs, labels in test_dataloader:
imgs = imgs.to(device)
labels = labels.to(device)
outputs = models(imgs)
loss = loss_func(outputs, labels)
loss.to('cpu')
total_loss += loss
print('第{}次打印损失:{}'.format(i, total_loss))
output:
第0次打印损失:21.64451026916504
第1次打印损失:12.503313064575195
第2次打印损失:10.950098037719727
第3次打印损失:10.087798118591309
第4次打印损失:9.635557174682617
第5次打印损失:8.812734603881836
第6次打印损失:8.555665016174316
第7次打印损失:8.165846824645996
第8次打印损失:8.197308540344238
第9次打印损失:8.533809661865234
5.模型保存
torch.save(models, 'my_cnn.nn')
6.测试模型精度
models = torch.load('./my_cnn.nn')
models.to('cpu') # 模型训练的时候用了gpu
models.eval() # 推理模式
test_data = torchvision.datasets.MNIST(
root='../data',
download=False,
train=False,
transform=torchvision.transforms.ToTensor()
)
test_loader = DataLoader(test_data, batch_size=64)
total_correct = 0
for imgs, labels in test_loader:
outputs = models(imgs)
pred = outputs.argmax(dim=1) # 将概率最大的位置索引作为预测值
corret = pred.eq(labels).sum().float().item()
total_correct += corret
total = len(test_loader.dataset)
acc = total_correct / total
print("test acc: ", acc) # 0.9831
7.其他补充
7.1 绘制手写数字图片 随机挑选6张图片进行打印
def plot_image(x, label, name):
"""
显示6张手写数字图片以及对应的数字标签
"""
for i in range(6):
plt.subplot(2, 3, i + 1)
plt.tight_layout()
plt.imshow(x[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
plt.title("{}: {}".format(name, label[i].item()))
plt.xticks([])
plt.yticks([])
plt.show()
# ...参照测试代码
plot_image(imgs, pred, 'image sample')