基于PyTorch的MLP模型的MNIST手写数字识别CIFAR10方案代码

489 阅读9分钟

         本文基于PyTorch框架,采用使用PyTorch的nn.Module模块定义多层感知机(MLP)模型实现MNIST手写数字识别,在GPU上运行,实现高达98%的测试准确率,并完整展示从数据加载到模型部署的全流程。

       PyTorch的模型大致结构普遍相似,也可修改相关参数的更换为其他简单图像分类任务,实测在CIFAR10数据集任务上分类准确率在60%左右,较为复杂的图像需要考虑使用CNN框架,或者YOLO模型会处理较好,实测在使用Resnet50分类冻结部分层情况下CIFAR10任务精度可达到95%以上。

目录

一、MNIST数据集

1.1 数据概述

1.2  数据集构成

二、导入必要库

 2.2 数据集准备

二、可视化部分图例

2.1 批次样本可视化

三、模型架构设计

3.1 MLP网络结构

四、模型训练模块

4.1 训练配置

4.2 训练流程

五、可视化损失,准确度

六、可视化预测结果

一、MNIST数据集

1.1 数据概述

        MNIST数据集是深度学习领域中广泛使用的基准测试数据集,常被称为“Hello World”项目,是许多研究者和学习者进入深度学习领域的起点。MNIST数据集最初由Yann LeCun等研究人员开发,目的是为了方便测试不同的机器学习算法的性能,这个数据集主要用于手写数字识别任务,包含了大量的手写数字图像及其对应的标签。

1.2  数据集构成

        MNIST数据集包含60,000张训练图像和10,000张测试图像,每张图像是28x28像素的灰度图像,代表数字0至9。

导入必要库

        首先,代码导入了必要的Python库,如torchtorchvisionmatplotlib等,用于深度学习模型的构建、数据处理和可视化。接着,检查CUDA是否可用,如果可用则使用GPU进行计算,代码如下:

    import torch
    import time
    import torchvision
    import torch.optim  as optim
    from torchvision import datasets, transforms
    from torch.utils.data  import DataLoader
    import torch.nn  as nn 
    import matplotlib.pyplot as plt
    print(torch.cuda.is_available())
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import torch
import time
import torchvision
import torch.optim  as optim
from torchvision import datasets, transforms
from torch.utils.data  import DataLoader
import torch.nn  as nn 
import matplotlib.pyplot as plt
print(torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

2.2 数据集准备

        定义了数据预处理操作,将图像转换为张量,并设置了批量大小。通过torchvision.datasets.MNIST 下载MNIST训练集和测试集,并使用DataLoader创建数据加载器,方便后续批量处理数据:

#加载数据
transform=transforms.Compose([
    transforms.ToTensor(),
])
batch_size=64
#下载数据集
Train_datast = datasets.MNIST(root='./',train=True,download=True,transform=transform)
test_datast = datasets.MNIST(root='./',train=False,download=True,transform=transform)

train_loader=DataLoader(dataset=Train_datast,batch_size=batch_size,shuffle=True,drop_last=True)
test_loader=DataLoader(dataset=test_datast,batch_size=batch_size,shuffle=True,drop_last=True)
# 打印数据集信息 
print(f"训练集大小: {len(Train_datast)}") 
print(f"测试集大小: {len(test_datast)}")

二、可视化部分图例

2.1 批次样本可视化

        代码从训练数据加载器中取出一个批次的图像和标签,使用torchvision.utils.make_grid 函数将这些图像拼接成一个网格,并使用matplotlib进行可视化展示,以直观地查看数据集的样本情况:

   # 构建图像网格
    batch_images, batch_labels = next(iter(train_loader)) 
    grid_img = torchvision.utils.make_grid( 
        tensor=batch_images,   # 输入张量 [B, C, H, W]
        nrow=8,               # 每行显示图片数(根据批次大小调整)
        padding=2,            # 图片间距 

    )
    print(grid_img.shape)
    print(grid_img.permute(1,  2, 0).shape)
    # 可视化 
    plt.figure(figsize=(12,  6))
    plt.imshow(grid_img.permute(1,  2, 0).squeeze(), cmap='gray')  # 维度调整 
    plt.axis('off') 
    plt.title(f"Batch  Size: {len(batch_images)} | Label Range: {batch_labels.min()}-{batch_labels.max()}") 
    plt.show() 

运行结果如下

三、模型架构设计

3.1 MLP网络结构

        定义了一个多层感知机(MLP)类MLP,继承自nn.Module。该模型包含多个全连接层,中间使用了批量归一化(nn.BatchNorm1d)、ReLU激活函数(nn.ReLU)和Dropout正则化(nn.Dropout),以提高模型的泛化能力。最后一层输出维度为10,对应MNIST数据集中的10个数字类别:

#定义网络架构
    class MLP(nn.Module):
        def __init__(self,hidden_size,height,width):
            super(MLP,self).__init__()
            
            self.layers  = nn.Sequential(
                nn.Flatten(),
                nn.Linear(height*width, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_size, hidden_size//2),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(hidden_size//2, hidden_size//4),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.BatchNorm1d(hidden_size//4),
                nn.Linear(hidden_size//4, 10),   
            )
        def forward(self, x):
            return self.layers(x) 
    # 训练参数
    batch_num,height,width=batch_images[0].shape
    model=MLP(hidden_size=4096,height=height,width=width).to(device)
    print(batch_images[0].shape)
    print(height*width)
    torch.flatten(batch_images[0]).shape

四、模型训练模块

4.1 训练配置

参数作用说明
优化器AdamW结合权重衰减的正则化优化
学习率0.001平衡收敛速度与稳定性
权重衰减1e-4控制模型复杂度
Early Stopping10轮防止过拟合

4.2 训练流程

        设置训练参数,包括训练轮数(epochs)、初始最优准确率(best_accuracy)等。在每个训练轮次中,模型处于训练模式(model.train() ),遍历训练数据加载器,进行前向传播、损失计算(使用交叉熵损失函数criterion)、反向传播和参数更新(使用优化器optimizer,代码中未定义,需补充)。在验证阶段,模型处于评估模式(model.eval() ),遍历测试数据加载器,计算测试集的损失和准确率。同时,记录每个轮次的训练损失、测试损失、训练准确率和测试准确率,并更新最优模型:

# 训练参数
    epochs = 100
    best_accuracy = 0.0
    correct_train=0.0
    loss_train=0
    loss_train_list=[]
    loss_test_list=[]
    train_acc_list = []
    test_acc_list = []
    for epoch in range(epochs):
        since = time.time()
        model.train() 
        print('-' * 20)
        print(f'Epoch {epoch+1}/{epochs}')
        
        total_loss_train=0
        correct_train=0
        k_test=0
        k_train=0
        for images, labels in train_loader:
            images = images.to(device) 
            labels = labels.to(device)
            # 前向传播
            outputs = model(images)
            _, predicted_train = torch.max(outputs,  dim=1)
            loss_train = criterion(outputs, labels)
            
            correct_train += (predicted_train == labels).sum().item()
            # 反向传播和优化
            optimizer.zero_grad() 
            loss_train.backward() 
            optimizer.step() 
            total_loss_train+=loss_train.item() 
            k_train+=1
       # 验证阶段
        model.eval() 
        correct_test = 0
        loss_test = []
        total_loss_test=0
        with torch.no_grad(): 
            for images, labels in test_loader:
                images = images.to(device) 
                labels = labels.to(device) 
                outputs = model(images)
                values_test, predicted_test = torch.max(outputs,  1)
                loss_test =criterion(outputs, labels) 
                total_loss_test+=loss_test.item() 
                correct_test += (predicted_test == labels).sum().item()
                k_test+=1
        #更新保存损失,准确率
        current_train_acc=correct_train/(k_train*batch_size)
        current_test_acc = correct_test/(k_test*batch_size)
        train_acc_list.append(current_train_acc)
        test_acc_list.append(current_test_acc)
        loss_train_list.append(total_loss_train)
        loss_test_list.append(total_loss_test)


        # 实时更新最优模型
        if current_test_acc >= best_accuracy:
            best_accuracy = current_test_acc  
            epochs_since_improvement = 0
            torch.save(model, 'best_model.pkl')
        else:
            epochs_since_improvement += 1
        time_elapsed = time.time() - since
        print(f'训练集交叉熵损失为:{total_loss_train}   测试集交叉熵损失为:{total_loss_test}')
        print(f'训练集准确率为:{current_train_acc}   测试集准确率为:{current_test_acc}')
        print('本轮用时:{:.0f}m {:.0f}s  epochs_since_improvement:{}'.format(time_elapsed // 60, time_elapsed % 60,epochs_since_improvement))
        patience=10
        if epochs_since_improvement >= patience:
            print('-'*20)
            print('测试集损失在',patience, '轮内没有提升,训练结束')
            print('停止在', epoch-patience)
            print(f"最优准确率为:{best_accuracy}")
            time_elapsed = time.time() - since
            print('Training complete in {:.0f}m {:.0f}s '.format(time_elapsed // 60, time_elapsed % 60))
            break

如果在连续patience轮中测试集准确率没有提升,则提前停止训练。 运行结果

五、可视化损失,准确度

        使用matplotlib创建复合图表,分别绘制训练过程中的损失曲线和准确率曲线,直观展示模型的训练效果和泛化能力:

import matplotlib.pyplot  as plt 
    import numpy as np 
     
    # 配置全局绘图参数 
    plt.rcParams['font.sans-serif']  = ['SimHei']  # 支持中文显示 
    plt.rcParams['axes.unicode_minus']  = False    # 解决负号显示异常 
                 # 使用专业主题 
     
    # 创建复合图表 
    fig, (ax1, ax2) = plt.subplots(1,  2, figsize=(18, 6), dpi=100)
     
    # 子图1:损失曲线 
    ax1.plot(np.array(loss_test_list)/batch_size,  'orange', lw=1.5, alpha=0.8, label='测试损失')
    ax1.set_title(' 损失曲线演进分析', fontsize=14, pad=12)
    ax1.set_xlabel(' 训练轮次', fontsize=12)
    ax1.set_ylabel(' 标准化损失值', fontsize=12)
    ax1.legend(frameon=True,  shadow=True, loc='upper right')
     
    # 子图2:准确率曲线 
    ax2.plot(train_acc_list,  'b-', lw=1.5, alpha=0.8, label='训练准确率')
    ax2.plot(test_acc_list,  'orange', lw=1.5, alpha=0.8, label='测试准确率')
    ax2.set_title(' 准确率发展轨迹', fontsize=14, pad=12)
    ax2.set_xlabel(' 训练轮次', fontsize=12)
    ax2.set_ylabel(' 分类准确率', fontsize=12)
    ax2.set_ylim(min(train_acc_list), 1)  # 固定准确率范围 
    ax2.legend(frameon=True,  shadow=True, loc='lower right')
     
    # 添加全局注释 
    plt.suptitle(f'MNIST 分类器训练监控报告 (最优准确率: {best_accuracy:.2%})', 
                y=1.02, fontsize=16, color='#2c3e50')
     

    # 保存与显示 
    plt.tight_layout() 
    plt.savefig('MNIST_training_analysis_MLP.png',  bbox_inches='tight', dpi=300)
    plt.show() 

 结果可视化,准确率在98.6%左右

六、可视化预测结果

        加载保存的最优模型,从测试数据加载器中取出一个批次的图像和标签,使用模型进行预测。将预测结果和真实标签转换为NumPy数组,并使用matplotlib创建图像网格,将预测结果和真实标签标注在每个图像上,直观展示模型的预测效果:

 #加载模型产生预测结果
    best_model=torch.load('./best_model.pkl',weights_only=False,map_location='cuda:0')
    best_model.eval()
    batch_images, batch_labels = next(iter(test_loader)) 
    batch_images=batch_images.to(device)
    values_test, predicted_test = torch.max(best_model(batch_images),  1)
    #转移到cpu numpy格式
    batch_labels=batch_labels.cpu().numpy() 
    predicted_test=predicted_test.cpu().numpy() 
    # print(f'实际标签:{batch_labels}')
    # print(f'预测标签:{predicted_test}')
    # 创建可视化图像网格

    nrow =  16# 与make_grid中的nrow一致
    ncol = batch_size // nrow  # 计算列数
    fig, axes = plt.subplots(nrow, ncol, figsize=(15, 15))
    axes = axes.flatten()
    for i in range(batch_size):
        # 处理图像维度 [C, H, W] -> [H, W]
        img = batch_images[i].cpu().permute(1, 2, 0).squeeze().numpy()
        #显示图像
        axes[i].imshow(img, cmap='gray')
        #隐藏坐标轴
        axes[i].axis('off')
        # 设置标题颜色和内容
        pred = predicted_test[i]
        true = batch_labels[i]
        color = 'green' if pred == true else 'red'
        axes[i].set_title(f'P:{pred} | T:{true}', color=color, fontsize=8)

    # 隐藏多余的子图(如果有)
    for j in range(batch_size, len(axes)):
        axes[j].axis('off')
    plt.tight_layout()
    plt.suptitle(f"Batch Size: {batch_size} | Label Range: {batch_labels.min()}-{batch_labels.max()}", y=1.02)
    plt.show()


运行结果

  • 绿色标注表示预测正确
  • 红色标注表示预测错误
  • 典型错误案例:书写模糊/非常规写法