什么是MNIST?
MNIST 是一个手写数字数据集,包含 70,000 张 0–9 的手写数字图片(其中训练集 60,000 张,测试集 10,000 张)。每张图片都预先标注了它所代表的数字标签,且均为 28×28 像素的灰度图像。
本文实现了以下内容:
- MNIST 数据集加载与预处理
- CNN 模型定义
- 模型训练与验证
- 保存最佳模型参数
- 导出ONNX
定义卷积神经网络(CNN)模型
import torch.nn as nn


class ConvNet(nn.Module):
    """Small two-stage CNN classifier for 28x28 single-channel images.

    Layer pipeline (shapes follow from the layer parameters below):

        [N, 1, 28, 28]
        -> Conv2d(1, 32, k=5, s=1, p=2) -> ReLU -> MaxPool2d(2)   => [N, 32, 14, 14]
        -> Conv2d(32, 64, k=5, s=1, p=2) -> ReLU -> MaxPool2d(2)  => [N, 64, 7, 7]
        -> flatten                                                => [N, 3136]
        -> Linear(3136, 10)                                       => [N, 10]

    Convolution layers extract local image features; each kernel's weights
    are shared across all spatial positions, which keeps the parameter count
    small. Max-pooling halves the spatial size of the feature maps, which
    shrinks the input (and parameter count) of the final linear layer.

    ``forward`` returns 10 unnormalized scores (logits) per sample; no softmax
    is applied here.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # [N, 1, 28, 28] -> [N, 32, 28, 28]
            # A 5x5 kernel with stride 1 and padding 2 preserves height/width.
            nn.Conv2d(1, 32, 5, 1, 2),
            nn.ReLU(),
            # [N, 32, 28, 28] -> [N, 32, 14, 14]
            nn.MaxPool2d(2),
            # [N, 32, 14, 14] -> [N, 64, 14, 14]
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            # [N, 64, 14, 14] -> [N, 64, 7, 7]
            nn.MaxPool2d(2),
        )
        # Fully connected layer mapping the flattened conv features to 10 class logits.
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        """Return class logits of shape [N, 10].

        28x28 inputs (as used throughout this file) produce the 64 * 7 * 7
        flattened features that ``self.fc`` expects.
        """
        x = self.conv(x)
        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)
        return self.fc(x)
加载MNIST数据集(训练集+测试集)
import torch
# These were previously only imported later in the file, which made this
# snippet fail with NameError when run on its own.
from torchvision import datasets, transforms

# Use a large batch when a GPU is available, a small one otherwise.
BATCH_SIZE = 512 if torch.cuda.is_available() else 12

# Preprocessing shared by both splits: PIL image -> float tensor in [0, 1],
# then standardize with mean 0.1307 / std 0.3081 (commonly used MNIST stats).
MNIST_TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Training set: stored under ./data (downloaded if missing), shuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=MNIST_TRANSFORM),
    batch_size=BATCH_SIZE, shuffle=True)

# Test set: used only for evaluation, so a fixed order is sufficient.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False, download=True, transform=MNIST_TRANSFORM),
    batch_size=BATCH_SIZE, shuffle=False)
训练模型
import torch.nn as nn
from torchvision import datasets, transforms
from ConvNet import ConvNet
# `device` was used here without being defined; pick the GPU when available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet().to(device)
# Adam with PyTorch's default hyper-parameters.
optimizer = torch.optim.Adam(model.parameters())
# CrossEntropyLoss takes raw logits and integer class targets.
lossf = nn.CrossEntropyLoss()
# Train the model for one epoch.
def train(model, device, train_loader, optimizer, epoch, criterion=None, log_interval=30):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network to train; switched to training mode here.
        device: device each batch is moved to (should match the model's).
        train_loader: DataLoader yielding ``(data, target)`` batches.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: zero-based epoch index (printed as ``epoch + 1``).
        criterion: loss function; defaults to the module-level ``lossf``.
        log_interval: print a progress line every this many batches.
    """
    if criterion is None:
        criterion = lossf
    model.train()
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)
    processed = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()             # clear gradients from the previous step
        output = model(data)              # forward pass
        loss = criterion(output, target)  # compute the loss
        loss.backward()                   # back-propagate
        optimizer.step()                  # update parameters
        # Count actual samples so a smaller final batch is reported correctly.
        processed += len(data)
        # Log after every `log_interval` completed batches; the counts include
        # the batch that was just processed.
        if (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch + 1, processed, num_samples,
                100. * (batch_idx + 1) / num_batches, loss.item()))
保存模型与测试函数
def save_model(model, device, test_loader, epoch, criterion=None, save_path='model/mnist.pt'):
    """Evaluate ``model`` on ``test_loader`` and save its weights when accuracy improves.

    The best accuracy so far is kept in the module-level global ``max_acc``.
    It is reset when ``epoch == 0`` (or if it has never been set), so a run
    that starts at epoch 0 begins a fresh comparison.

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each batch is moved to.
        test_loader: DataLoader yielding ``(data, target)`` batches.
        epoch: zero-based epoch index; 0 resets ``max_acc``.
        criterion: loss used for the reported average; defaults to ``lossf``.
            Assumed to return the batch mean (``nn.CrossEntropyLoss()``'s
            default reduction); it is re-weighted by batch size below.
        save_path: file the best ``state_dict`` is written to; its parent
            directory is created if missing.
    """
    import os

    global max_acc
    if epoch == 0 or 'max_acc' not in globals():
        max_acc = 0
    if criterion is None:
        criterion = lossf

    model.eval()
    num_samples = len(test_loader.dataset)
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # criterion returns the batch *mean*; multiply by the batch size so
            # dividing by the dataset size below gives a true per-sample average
            # (summing raw means and dividing by N under-reports the loss).
            test_loss += criterion(output, target).item() * data.size(0)
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= num_samples
    test_acc = correct / num_samples
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, num_samples,
        100. * correct / num_samples))

    # Keep only the best-performing weights.
    if test_acc > max_acc:
        directory = os.path.dirname(save_path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(model.state_dict(), save_path)
        max_acc = test_acc
导出 ONNX 模型
def transform2onnx(model_path="model/mnist.pt", onnx_path="model/mnist.onnx"):
    """Load saved weights into the global ``model`` and export it to ONNX.

    Relies on the module-level ``model`` and ``device`` set up for training.

    Args:
        model_path: ``state_dict`` file to load (the training loop saves the
            best weights to ``model/mnist.pt``).
        onnx_path: output path for the exported ONNX graph.
    """
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()  # export in inference mode
    # Dummy input with the MNIST image size [batch=1, channels=1, 28, 28].
    # Without dynamic_axes the exported graph keeps this fixed batch size of 1.
    dummy_input = torch.randn(1, 1, 28, 28).to(device)
    input_names = ["input_0"]
    output_names = ["output_0"]
    torch.onnx.export(model, dummy_input, onnx_path, verbose=True,
                      input_names=input_names, output_names=output_names)
调用流程
# Train: each epoch runs one training pass, then evaluates and keeps the best checkpoint.
# NOTE(review): EPOCHS is not defined anywhere in this excerpt; define it
# before running this loop.
for epoch in range(EPOCHS):
    train(model, device, train_loader, optimizer, epoch)
    save_model(model, device, test_loader, epoch)
# Convert the saved checkpoint to ONNX.
transform2onnx()
验证模型
import os
import torch
from PIL import Image
from ConvNet import ConvNet
import matplotlib.pyplot as plt
from torchvision import transforms
# Directory of hand-made test digits. Each file name must start with its
# label digit (e.g. "3_xxx.png"); other files are skipped.
path = './image/'

# Use the same normalization as the training data loaders; without it the
# inputs would be on a different scale from what the model was trained on.
preprocess = transforms.Compose([
    transforms.Resize((28, 28)),  # the network's fc layer expects 28x28 inputs
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
# NOTE(review): MNIST digits are light strokes on a dark background; images
# drawn dark-on-light may need inverting before prediction — confirm.

images = []
labels = []
for name in sorted(os.listdir(path)):
    if name[0] not in '0123456789':
        continue  # no label prefix (e.g. hidden/system files)
    img = Image.open(os.path.join(path, name)).convert('L')
    images.append(preprocess(img))
    labels.append(int(name[0]))
if not images:
    raise RuntimeError(f'no labelled digit images found in {path!r}')
images = torch.stack(images, 0)
print(images.shape)

# Load the model onto the chosen device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet().to(device)
model.load_state_dict(torch.load('model/mnist.pt', map_location=device))
model.eval()

# Run inference; `images` itself stays on the CPU for plotting.
with torch.no_grad():
    output = model(images.to(device))

# %% Print results
pred = output.argmax(1).cpu()
true = torch.LongTensor(labels)
print(pred)
print(true)

# Plot up to the first 10 images with predicted vs. true labels.
# imshow scales the colormap to the data range, so normalized tensors display fine.
plt.figure(figsize=(10, 4))
for i in range(min(len(images), 10)):
    plt.subplot(2, 5, i + 1)
    plt.title(f'pred {pred[i].item()} | true {true[i].item()}')
    plt.axis('off')
    plt.imshow(images[i].squeeze(0), cmap='gray')
plt.show()
最终结果
项目地址: github.com/WardTN/MNIS…