MNIST 图像分类实战1-训练模型

217 阅读3分钟

什么是MNIST?

MNIST 是一个手写数字数据集,包含上万张 0-9 的数字图片;每张图片都预先标注了所代表的数字标签,且均为 28x28 像素的灰度图像。

企业微信截图_1744942075964.png

本文实现

  1. MNIST 数据集加载与预处理
  2. CNN 模型定义
  3. 模型训练与验证
  4. 保存最佳模型参数
  5. 导出ONNX

定义卷积神经网络(CNN)模型

class ConvNet(nn.Module):
    """Small CNN for 28x28 grayscale MNIST digits.

    Two conv+ReLU+max-pool stages extract features; a single
    fully-connected layer maps the flattened feature map to the
    10 digit-class logits.
    """

    def __init__(self):
        super(ConvNet, self).__init__()
        # Feature extractor. Shapes per stage (N = batch size):
        #   input            -> [N, 1, 28, 28]
        #   Conv2d(1->32)    -> [N, 32, 28, 28]  (5x5 kernel, stride 1, pad 2)
        #   MaxPool2d(2)     -> [N, 32, 14, 14]
        #   Conv2d(32->64)   -> [N, 64, 14, 14]  (5x5 kernel, stride 1, pad 2)
        #   MaxPool2d(2)     -> [N, 64, 7, 7]
        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Classifier head: flattened 64*7*7 features -> 10 class logits.
        self.fc = nn.Linear(64 * 7 * 7, 10)

    def forward(self, x):
        """Return class logits of shape [N, 10] for input [N, 1, 28, 28]."""
        features = self.conv(x)
        return self.fc(features.flatten(1))

加载MNIST数据集(训练集+测试集)

     import torch
    
    # Large batches on GPU; small ones keep CPU-only runs responsive.
    BATCH_SIZE = 512 if torch.cuda.is_available() else 12

    # Identical preprocessing for both splits: tensor conversion followed
    # by normalization with the standard MNIST mean/std.
    mnist_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split (downloaded into ./data on first use).
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=True, download=True,
                       transform=mnist_transform),
        batch_size=BATCH_SIZE, shuffle=True)

    # Test split. NOTE(review): shuffle=True is kept from the original;
    # it is unnecessary for evaluation but does not change the metrics.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('data', train=False, transform=mnist_transform),
        batch_size=BATCH_SIZE, shuffle=True)
    

训练模型

    import torch.nn as nn
    from torchvision import datasets, transforms
    from ConvNet import ConvNet
    
    # Select GPU when available, otherwise fall back to CPU.
    # (The original snippet used `device` without ever defining it.)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = ConvNet().to(device)          # CNN defined above, moved to the target device
    optimizer = torch.optim.Adam(model.parameters())  # Adam with default hyper-parameters
    lossf = nn.CrossEntropyLoss()         # expects raw logits + integer class targets
    
    
    # 训练 模型
    def train(model, device, train_loader, optimizer, epoch):
        """Run one training epoch.

        Args:
            model: the network to optimize (already moved to `device` by the caller).
            device: torch device the batches are copied to.
            train_loader: DataLoader yielding (images, labels) batches.
            optimizer: optimizer updating `model`'s parameters.
            epoch: 0-based epoch index (logged 1-based).

        Uses the module-level `lossf` criterion.
        """
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()         # clear gradients from the previous step
            output = model(data)          # forward pass
            loss = lossf(output, target)  # compute loss
            loss.backward()               # backpropagate
            optimizer.step()              # update parameters
            # Log every 30 batches. Use (batch_idx + 1) so the sample count
            # matches the 1-based trigger condition (the original printed
            # batch_idx * len(data), under-counting by one batch). The
            # leftover debug print of data.shape was removed.
            if (batch_idx + 1) % 30 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch + 1, (batch_idx + 1) * len(data), len(train_loader.dataset),
                    100. * (batch_idx + 1) / len(train_loader), loss.item()))
    

保存模型与测试函数

def save_model(model, device, test_loader, epoch):
    """Evaluate `model` on the test set and persist the best checkpoint.

    Computes the average per-sample loss and accuracy over `test_loader`
    (using the module-level `lossf` criterion) and saves the state_dict to
    'model/mnist.pt' whenever accuracy improves on the best seen so far
    (tracked in the module-level `max_acc`).

    Args:
        model: network to evaluate.
        device: torch device batches are copied to.
        test_loader: DataLoader over the evaluation split.
        epoch: 0-based epoch index; `max_acc` is (re)initialized when it is 0.
    """
    # Declare the global up front; the original buried the declaration
    # inside the `if`, which still binds for the whole function but reads
    # as if it were conditional.
    global max_acc
    if epoch == 0:
        max_acc = 0

    model.eval()
    test_loss = 0.0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # lossf returns the *mean* loss of the batch; weight it by the
            # batch size so dividing by the dataset size below yields a true
            # per-sample average. (The original summed batch means and then
            # divided by the sample count, understating the loss, and also
            # accumulated a tensor instead of a float.)
            test_loss += lossf(output, target).item() * data.size(0)
            pred = output.max(1, keepdim=True)[1]  # index of the max logit
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc = correct / len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    # Keep only the best-performing weights on disk.
    if test_acc > max_acc:
        torch.save(model.state_dict(), 'model/mnist.pt')
        max_acc = test_acc

导出 ONNX 模型

def transform2onnx():
    """Export the best checkpoint into ONNX format.

    Loads 'model/mnist.pt' into the module-level `model` and writes
    'model/mnist.onnx' with named input/output tensors.
    """
    # map_location makes a CUDA-saved checkpoint loadable on CPU-only
    # machines (the original torch.load would fail there).
    model.load_state_dict(torch.load("model/mnist.pt", map_location=device))
    # Dummy input matching the MNIST image shape: [batch=1, channel=1, 28, 28].
    dummy_input = torch.randn(1, 1, 28, 28).to(device)
    input_names = ["input_0"]
    output_names = ["output_0"]
    torch.onnx.export(model, dummy_input, 'model/mnist.onnx', verbose=True,
                      input_names=input_names, output_names=output_names)
    

调用流程

    # Training loop: one train pass + one evaluation/checkpoint per epoch.
    # (The original used C-style `//` comments, which are a SyntaxError in
    # Python.)  NOTE(review): EPOCHS is not defined in this snippet — it
    # must be set elsewhere; confirm against the full project.
    for epoch in range(EPOCHS):
        train(model, device, train_loader, optimizer, epoch)
        save_model(model, device, test_loader, epoch)

    # Convert the best saved checkpoint to ONNX.
    transform2onnx()

验证模型

import os
import torch
from PIL import Image
from ConvNet import ConvNet
import matplotlib.pyplot as plt
from torchvision import transforms

path = './image/'
images = []
labels = []

# Apply the same preprocessing as training: ToTensor followed by
# normalization with the MNIST mean/std. The original skipped the
# Normalize step, so inference inputs were distributed differently from
# the training inputs, degrading accuracy.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# File names are expected to start with the digit label, e.g. '3_xxx.png'.
for name in sorted(os.listdir(path)):
    img = Image.open(path + name).convert('L')  # force single-channel grayscale
    images.append(preprocess(img))
    labels.append(int(name[0]))
images = torch.stack(images, 0)  # -> [N, 1, 28, 28]
print(images.shape)

# Load the trained model; the device argument maps a CUDA-saved
# checkpoint onto the available device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ConvNet()
model.load_state_dict(torch.load('model/mnist.pt', device))
model.eval()

# Run inference without tracking gradients.
with torch.no_grad():
    output = model(images)

# Print predictions next to the ground-truth labels.
pred = output.argmax(1)
true = torch.LongTensor(labels)
print(pred)
print(true)

# Plot each image with its predicted and true label. The original
# plt.title line ended with a stray backtick (a SyntaxError), and the
# hard-coded range(10) would raise IndexError with fewer than 10 images.
plt.figure(figsize=(10, 4))
for i in range(min(10, len(images))):
    plt.subplot(2, 5, i + 1)
    plt.title(f'pred {pred[i]} | true {true[i]}')
    plt.axis('off')
    plt.imshow(images[i].squeeze(0), cmap='gray')
plt.show()

最终结果

企业微信截图_17449441732382.png

项目地址: github.com/WardTN/MNIS…