手写数字识别零基础实战:基于PyTorch的CNN完整拆解

37 阅读9分钟

最近在Kaggle上跑了一个经典的MNIST手写数字识别项目,用PyTorch搭了一个朴素的CNN,效果还不错,准确率能到99%左右。

我现在把整个jupyter notebook 代码贴出来,以供参考:github.com/anjuxi/CNN-…

项目概览

  • 数据集:MNIST(6万训练,1万测试,28×28灰度图)
  • 框架:PyTorch
  • 模型:3层卷积 + 2层全连接,带BatchNorm、Dropout
  • 加速:多GPU并行、混合精度训练
  • 指标:测试集准确率99%+

整个项目就一个jupyter文件,方便调试。


1. 环境与基础配置

import torch;
import torch.nn as nn;
from torch.utils.data import DataLoader;
import torchvision;
import torchvision.transforms as transforms;
import numpy as np;
import matplotlib.pyplot as plt;
import torch.optim as optim;
import time;

首先导入必备的库。PyTorch那一套不用多说,

torchvision帮我们处理MNIST,

matplotlib用来可视化。

# 设置训练时的批次大小(依次传入模型的数据数量)。
batch_size = 1024*2;

batch_size设成2048,这是一个相对激进的选择。一般情况下64、128比较常见,但kaggle给了两张T4并行训练,显存够大。


2. 数据处理

2.1 数据预处理

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
  • ToTensor():把PIL图片转成PyTorch张量,并且像素值从0-255映射到0-1。
  • Normalize:将单通道图像的均值与标准差归一化。MNIST的经验值是均值0.1307,标准差0.3081。这一步很重要,能让模型更容易收敛。

2.2 加载数据集

train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

torchvision.datasets.MNIST会自动下载数据到./data目录。在Kaggle环境下 下载很快。

2.3 构建DataLoader

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)
  • num_workers=4:开启4个子进程加载数据,在Kaggle这种多核环境里能加速IO。
  • pin_memory:启用锁页内存,能加快CPU到GPU的数据传输。
  • persistent_workers:让worker进程在epoch之间保持存活,避免重复创建的开销。

2.4 数据探索

print(f"训练集样本数量: {len(train_dataset)} !");
print(f"测试集样本数量: {len(test_dataset)} !");
print(f"图片尺寸: {train_dataset[0][0].shape} !");
print(f"类别数量: {len(train_dataset.classes)} !");

输出:

训练集样本数量: 60000 !
测试集样本数量: 10000 !
图片尺寸: torch.Size([1, 28, 28]) !
类别数量: 10 !

3. 模型搭建

CNN结构如下:

# 定义卷积神经网络模型类。
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__();
        # 第一层 卷积层:
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1); # 输入1通道(灰度图),输出32通道,卷积核大小3x3。
        # 第一个批归一化层,对32个通道进行归一化。
        self.bn1 = nn.BatchNorm2d(32);
        # ReLU激活函数,引入非线性。
        self.relu1 = nn.ReLU();
        # 第一个最大池化层,池化核大小2x2,步长2。
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2);
        
        # 第二个卷积层:输入32通道,输出64通道。
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1);
        self.bn2 = nn.BatchNorm2d(64);
        self.relu2 = nn.ReLU();
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2);
        
        # 第三个卷积层:输入64通道,输出128通道。
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1);
        self.bn3 = nn.BatchNorm2d(128);
        self.relu3 = nn.ReLU();
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2);
        
        self.dropout = nn.Dropout(0.3);
        self.fc1 = nn.Linear(128*3*3, 256);# 第一个全连接层:输入128*3*3=1152,输出256。
        self.relu4 = nn.ReLU();
        self.dropout2 = nn.Dropout(0.3);
        self.fc2 = nn.Linear(256, 10);# 第二个全连接层:输入256,输出10(对应0-9十个数字类别)。

整个网络由三个卷积块 + 两个全连接层组成。

前向传播

    # 前向传播函数,定义数据流向。
    def forward(self, x): 
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))));# 第一层卷积+批归一化+激活+池化。
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))));
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))));

        # 将特征图展平为一维向量。
        x = x.view(x.size(0), -1);

        x = self.dropout(x);
        x = self.relu4(self.fc1(x));
        x = self.dropout2(x);
        x = self.fc2(x);

        return x;

层叠写法,简洁直观。

多GPU与设备检测

# 多GPU支持。
if torch.cuda.device_count() > 1:
    print(f"使用 {torch.cuda.device_count()} 个GPU进行训练 !");
    model = nn.DataParallel(model);
# 将模型移动到指定设备(GPU或CPU)。
model = model.to(device);

torch.cuda.device_count()检测到2,于是用nn.DataParallel包裹模型,自动做数据并行。最后把模型搬到GPU。


4. 训练配置

4.1 损失函数、优化器、学习率调度

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
  • 交叉熵损失,多分类标配。
  • Adam优化器,初始学习率0.001,能快速收敛。
  • 学习率调度:每10个epoch衰减为原来的0.1倍。因为训练到后期loss几乎不降了,这样做可以让模型微调至更优解。我试过不衰减,最终精度会低0.1%左右。

4.2 混合精度加速

scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None

混合精度训练可以让计算速度加快,显存占用减少。GradScaler用来缩放loss防止梯度下溢,一般和autocast搭配使用。


5. 训练与测试函数

5.1 训练函数

# 定义训练函数。
def train(model, train_loader, criterion, optimizer, device, scaler=None):
    model.train();
    running_loss = 0.0;
    correct = 0;
    total = 0;
    
    for batch_idx, (data, target) in enumerate(train_loader): # 遍历训练集数据生成器。
        data, target = data.to(device), target.to(device);
        
        optimizer.zero_grad();
        
        if scaler is not None:
            with torch.cuda.amp.autocast():
                output = model(data);
                loss = criterion(output, target);
            scaler.scale(loss).backward();
            scaler.step(optimizer);
            scaler.update();
        else:
            output = model(data);
            loss = criterion(output, target); # criterion = nn.CrossEntropyLoss();
            loss.backward(); # 反向传播,计算梯度。
            optimizer.step();# 更新参数。
        
        running_loss += loss.item();
        _, predicted = output.max(1);
        total += target.size(0);
        correct += predicted.eq(target).sum().item();
    
    avg_loss = running_loss / len(train_loader);
    accuracy = 100. * correct / total;
    return avg_loss, accuracy;

5.2 测试函数

# 定义测试函数。
def test(model, test_loader, criterion, device):
    # 将模型设置为评估模式。
    model.eval();

    running_loss = 0.0; #累计损失。
    correct = 0; #正确预测数量。 
    total = 0; #总样本数量。
    
    # 不计算梯度,节省内存和计算资源。
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device);
            
            output = model(data);
        
            loss = criterion(output, target);

            running_loss += loss.item(); # 累加损失值。
            # 获取预测结果。
            _, predicted = output.max(1);

            total += target.size(0); # 累加总样本数。
            correct += predicted.eq(target).sum().item(); # 累加正确预测数。
    
    # 计算平均损失。
    avg_loss = running_loss / len(test_loader);
    # 计算准确率。
    accuracy = 100. * correct / total;
    # 返回平均损失和准确率。
    return avg_loss, accuracy;

6. 训练循环与保存最佳模型

%%time

# 设置训练轮数。
num_epochs = 50;

# 创建列表用于存储训练历史。
train_losses = [];
train_accuracies = [];
test_losses = [];
test_accuracies = [];

best_accuracy = 0.0;

# 损失函数(交叉熵损失)。
criterion = nn.CrossEntropyLoss();
# 优化器(Adam优化器,学习率0.001)。
optimizer = optim.Adam(model.parameters(), lr=0.001);
# 学习率调度器(每10个epoch将学习率乘以0.1)。
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1);

# 混合精度训练scaler。
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None;

# 开始训练循环。
for epoch in range(num_epochs):
    epoch_start = time.time();# 計時。
    
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, scaler);
    test_loss, test_acc = test(model, test_loader, criterion, device);
    
    scheduler.step();
    
    train_losses.append(train_loss);
    train_accuracies.append(train_acc);
    test_losses.append(test_loss);
    test_accuracies.append(test_acc);
    
    epoch_time = time.time() - epoch_start;
    
    print(f"轮次: [{epoch+1}/{num_epochs}] "
          f"训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}% "
          f"测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.2f}% "
          f"耗时: {epoch_time:.2f}s !");
    
    # 保存最佳模型。
    if test_acc > best_accuracy:
        best_accuracy = test_acc;
        if isinstance(model, nn.DataParallel):
            torch.save(model.module.state_dict(), 'mnist_cnn_best.pth');
        else:
            torch.save(model.state_dict(), 'mnist_cnn_best.pth');
        print(f"✅最佳模型已保存 (准确率: {best_accuracy:.2f}%) !");

训练输出大致如下:

轮次: [1/50] 训练损失: 0.5928, 训练准确率: 81.79% 测试损失: 0.0999, 测试准确率: 96.99% 耗时: 7.24s !
✅最佳模型已保存 (准确率: 96.99%) !
轮次: [2/50] 训练损失: 0.0967, 训练准确率: 97.10% 测试损失: 0.0454, 测试准确率: 98.53% 耗时: 6.04s !
✅最佳模型已保存 (准确率: 98.53%) !
轮次: [3/50] 训练损失: 0.0653, 训练准确率: 98.05% 测试损失: 0.0529, 测试准确率: 98.41% 耗时: 6.20s !
轮次: [4/50] 训练损失: 0.0516, 训练准确率: 98.42% 测试损失: 0.0375, 测试准确率: 98.71% 耗时: 5.98s !
✅最佳模型已保存 (准确率: 98.71%) !
......

可以看到第1个epoch后测试准确率就有80%+,说明模型学习能力很强。

7. 训练过程可视化

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
# ... 设置标签、图例 ...
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
# ...

画了loss和accuracy的曲线图。从图里明显看出:训练loss平滑下降,测试loss在10epoch左右基本走平,之后学习率衰减让loss又小降了一点。准确率曲线也很漂亮,训练和测试的差距很小,说明过拟合控制得不错。


8. 模型保存与加载

if isinstance(model, nn.DataParallel):
    torch.save(model.module.state_dict(), 'mnist_cnn_model.pth')
else:
    torch.save(model.state_dict(), 'mnist_cnn_model.pth')

保存最终模型(不一定最佳)。后面如果要用,可以这样加载:

model = CNN()
model.load_state_dict(torch.load('mnist_cnn_model.pth'))
model = model.to(device)
model.eval()

9. 预测结果可视化

def visualize_predictions(model, test_loader, device, num_images=16):
    model.eval()
    images, labels = next(iter(test_loader))
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)
        _, predicted = outputs.max(1)
    images = images.cpu()
    predicted = predicted.cpu()
    # 画4x4网格
    ...

随机取了一个batch的前16张图,画成4×4网格,真实标签用黑色,预测标签用绿色(正确)或红色(错误)。


10. 混淆矩阵与分类报告

from sklearn.metrics import confusion_matrix, classification_report

遍历整个测试集,收集所有预测结果和真实标签,然后用confusion_matrix生成矩阵,再用plt.imshow画热力图。主对角线很亮,其他位置基本没啥数字,说明模型各个类别都区分得很好。

分类报告:

              precision    recall  f1-score   support

           0       0.99      1.00      0.99       980
           1       1.00      1.00      1.00      1135
           2       0.99      1.00      1.00      1032
           3       1.00      1.00      1.00      1010
           4       0.99      0.99      0.99       982
           5       0.99      0.99      0.99       892
           6       1.00      0.99      0.99       958
           7       0.99      0.99      0.99      1028
           8       1.00      0.99      1.00       974
           9       0.99      0.98      0.99      1009

    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

所有类别的F1都在0.99左右,模型很均衡。


11. 单个图片预测测试

predict_single_image函数随机抽5张测试图,打印真实标签、预测标签和置信度,并给出Top-3预测概率。

def predict_single_image(model, image, device):
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        output = model(image)
        probabilities = torch.softmax(output, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
        return predicted.item(), confidence.item(), probabilities.cpu().numpy()[0]

输出示例:


圖片 8649:
  真實標籤: 8
  預測標籤: 8
  置信度: 1.0000
  預測正確: 
  Top-10 預測:
    1. 數字 8: 1.0000
    2. 數字 9: 0.0000
    3. 數字 3: 0.0000
    4. 數字 5: 0.0000
    5. 數字 2: 0.0000
    6. 數字 6: 0.0000
    7. 數字 0: 0.0000
    8. 數字 4: 0.0000
    9. 數字 7: 0.0000
    10. 數字 1: 0.0000

圖片 9388:
  真實標籤: 1
  預測標籤: 1
  置信度: 1.0000
  預測正確: 
  Top-10 預測:
    1. 數字 1: 1.0000
    2. 數字 7: 0.0000
    3. 數字 4: 0.0000
    4. 數字 2: 0.0000
    5. 數字 8: 0.0000
    6. 數字 0: 0.0000
    7. 數字 5: 0.0000
    8. 數字 9: 0.0000
    9. 數字 6: 0.0000
    10. 數字 3: 0.0000

圖片 6940:
  真實標籤: 3
  預測標籤: 3
  置信度: 0.9999
  預測正確: 
  Top-10 預測:
    1. 數字 3: 0.9999
    2. 數字 5: 0.0001
    3. 數字 2: 0.0000
    4. 數字 9: 0.0000
    5. 數字 7: 0.0000
    6. 數字 1: 0.0000
    7. 數字 8: 0.0000
    8. 數字 6: 0.0000
    9. 數字 4: 0.0000
    10. 數字 0: 0.0000

圖片 7629:
  真實標籤: 4
  預測標籤: 4
  置信度: 1.0000
  預測正確: 
  Top-10 預測:
    1. 數字 4: 1.0000
    2. 數字 9: 0.0000
    3. 數字 7: 0.0000
    4. 數字 1: 0.0000
    5. 數字 5: 0.0000
    6. 數字 6: 0.0000
    7. 數字 0: 0.0000
    8. 數字 2: 0.0000
    9. 數字 8: 0.0000
    10. 數字 3: 0.0000

圖片 6803:
  真實標籤: 5
  預測標籤: 5
  置信度: 1.0000
  預測正確: 
  Top-10 預測:
    1. 數字 5: 1.0000
    2. 數字 3: 0.0000
    3. 數字 9: 0.0000
    4. 數字 8: 0.0000
    5. 數字 6: 0.0000
    6. 數字 1: 0.0000
    7. 數字 7: 0.0000
    8. 數字 0: 0.0000
    9. 數字 4: 0.0000
    10. 數字 2: 0.0000


本文首发于掘金,作者Ailan Anjuxi,转载请注明出处!