最近在Kaggle上跑了一个经典的MNIST手写数字识别项目,用PyTorch搭了一个朴素的CNN,效果还不错,准确率能到99%左右。
我现在把整个jupyter notebook 代码贴出来,以供参考:github.com/anjuxi/CNN-…
项目概览
- 数据集:MNIST(6万训练,1万测试,28×28灰度图)
- 框架:PyTorch
- 模型:3层卷积 + 2层全连接,带BatchNorm、Dropout
- 加速:多GPU并行、混合精度训练
- 指标:测试集准确率99%+
整个项目就一个jupyter文件,方便调试。
1. 环境与基础配置
import torch;
import torch.nn as nn;
from torch.utils.data import DataLoader;
import torchvision;
import torchvision.transforms as transforms;
import numpy as np;
import matplotlib.pyplot as plt;
import torch.optim as optim;
import time;
首先导入必备的库。PyTorch那一套不用多说,
torchvision帮我们处理MNIST,
matplotlib用来可视化。
# 设置训练时的批次大小(依次传入模型的数据数量)。
batch_size = 1024*2;
把batch_size设成2048,这是一个相对激进的选择。一般情况下64、128比较常见,但kaggle给了两张T4并行训练,显存够大。
2. 数据处理
2.1 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
ToTensor():把PIL图片转成PyTorch张量,并且像素值从0-255映射到0-1。Normalize:将单通道图像的均值与标准差归一化。MNIST的经验值是均值0.1307,标准差0.3081。这一步很重要,能让模型更容易收敛。
2.2 加载数据集
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
root='./data', train=False, download=True, transform=transform
)
torchvision.datasets.MNIST会自动下载数据到./data目录。在Kaggle环境下 下载很快。
2.3 构建DataLoader
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
persistent_workers=True
)
test_loader = DataLoader(
test_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
persistent_workers=True
)
num_workers=4:开启4个子进程加载数据,在Kaggle这种多核环境里能加速IO。pin_memory:启用锁页内存,能加快CPU到GPU的数据传输。persistent_workers:让worker进程在epoch之间保持存活,避免重复创建的开销。
2.4 数据探索
print(f"训练集样本数量: {len(train_dataset)} !");
print(f"测试集样本数量: {len(test_dataset)} !");
print(f"图片尺寸: {train_dataset[0][0].shape} !");
print(f"类别数量: {len(train_dataset.classes)} !");
输出:
训练集样本数量: 60000 !
测试集样本数量: 10000 !
图片尺寸: torch.Size([1, 28, 28]) !
类别数量: 10 !
3. 模型搭建
CNN结构如下:
# 定义卷积神经网络模型类。
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__();
# 第一层 卷积层:
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1); # 输入1通道(灰度图),输出32通道,卷积核大小3x3。
# 第一个批归一化层,对32个通道进行归一化。
self.bn1 = nn.BatchNorm2d(32);
# ReLU激活函数,引入非线性。
self.relu1 = nn.ReLU();
# 第一个最大池化层,池化核大小2x2,步长2。
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2);
# 第二个卷积层:输入32通道,输出64通道。
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1);
self.bn2 = nn.BatchNorm2d(64);
self.relu2 = nn.ReLU();
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2);
# 第三个卷积层:输入64通道,输出128通道。
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1);
self.bn3 = nn.BatchNorm2d(128);
self.relu3 = nn.ReLU();
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2);
self.dropout = nn.Dropout(0.3);
self.fc1 = nn.Linear(128*3*3, 256);# 第一个全连接层:输入128*3*3=1152,输出256。
self.relu4 = nn.ReLU();
self.dropout2 = nn.Dropout(0.3);
self.fc2 = nn.Linear(256, 10);# 第二个全连接层:输入256,输出10(对应0-9十个数字类别)。
整个网络由三个卷积块 + 两个全连接层组成。
前向传播
# 前向传播函数,定义数据流向。
def forward(self, x):
x = self.pool1(self.relu1(self.bn1(self.conv1(x))));# 第一层卷积+批归一化+激活+池化。
x = self.pool2(self.relu2(self.bn2(self.conv2(x))));
x = self.pool3(self.relu3(self.bn3(self.conv3(x))));
# 将特征图展平为一维向量。
x = x.view(x.size(0), -1);
x = self.dropout(x);
x = self.relu4(self.fc1(x));
x = self.dropout2(x);
x = self.fc2(x);
return x;
层叠写法,简洁直观。
多GPU与设备检测
# 多GPU支持。
if torch.cuda.device_count() > 1:
print(f"使用 {torch.cuda.device_count()} 个GPU进行训练 !");
model = nn.DataParallel(model);
# 将模型移动到指定设备(GPU或CPU)。
model = model.to(device);
torch.cuda.device_count()检测到2,于是用nn.DataParallel包裹模型,自动做数据并行。最后把模型搬到GPU。
4. 训练配置
4.1 损失函数、优化器、学习率调度
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
- 交叉熵损失,多分类标配。
- Adam优化器,初始学习率0.001,能快速收敛。
- 学习率调度:每10个epoch衰减为原来的0.1倍。因为训练到后期loss几乎不降了,这样做可以让模型微调至更优解。我试过不衰减,最终精度会低0.1%左右。
4.2 混合精度加速
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None
混合精度训练可以让计算速度加快,显存占用减少。GradScaler用来缩放loss防止梯度下溢,一般和autocast搭配使用。
5. 训练与测试函数
5.1 训练函数
# 定义训练函数。
def train(model, train_loader, criterion, optimizer, device, scaler=None):
model.train();
running_loss = 0.0;
correct = 0;
total = 0;
for batch_idx, (data, target) in enumerate(train_loader): # 遍历训练集数据生成器。
data, target = data.to(device), target.to(device);
optimizer.zero_grad();
if scaler is not None:
with torch.cuda.amp.autocast():
output = model(data);
loss = criterion(output, target);
scaler.scale(loss).backward();
scaler.step(optimizer);
scaler.update();
else:
output = model(data);
loss = criterion(output, target); # criterion = nn.CrossEntropyLoss();
loss.backward(); # 反向传播,计算梯度。
optimizer.step();# 更新参数。
running_loss += loss.item();
_, predicted = output.max(1);
total += target.size(0);
correct += predicted.eq(target).sum().item();
avg_loss = running_loss / len(train_loader);
accuracy = 100. * correct / total;
return avg_loss, accuracy;
5.2 测试函数
# 定义测试函数。
def test(model, test_loader, criterion, device):
# 将模型设置为评估模式。
model.eval();
running_loss = 0.0; #累计损失。
correct = 0; #正确预测数量。
total = 0; #总样本数量。
# 不计算梯度,节省内存和计算资源。
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device);
output = model(data);
loss = criterion(output, target);
running_loss += loss.item(); # 累加损失值。
# 获取预测结果。
_, predicted = output.max(1);
total += target.size(0); # 累加总样本数。
correct += predicted.eq(target).sum().item(); # 累加正确预测数。
# 计算平均损失。
avg_loss = running_loss / len(test_loader);
# 计算准确率。
accuracy = 100. * correct / total;
# 返回平均损失和准确率。
return avg_loss, accuracy;
6. 训练循环与保存最佳模型
%%time
# 设置训练轮数。
num_epochs = 50;
# 创建列表用于存储训练历史。
train_losses = [];
train_accuracies = [];
test_losses = [];
test_accuracies = [];
best_accuracy = 0.0;
# 损失函数(交叉熵损失)。
criterion = nn.CrossEntropyLoss();
# 优化器(Adam优化器,学习率0.001)。
optimizer = optim.Adam(model.parameters(), lr=0.001);
# 学习率调度器(每10个epoch将学习率乘以0.1)。
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1);
# 混合精度训练scaler。
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None;
# 开始训练循环。
for epoch in range(num_epochs):
epoch_start = time.time();# 計時。
train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, scaler);
test_loss, test_acc = test(model, test_loader, criterion, device);
scheduler.step();
train_losses.append(train_loss);
train_accuracies.append(train_acc);
test_losses.append(test_loss);
test_accuracies.append(test_acc);
epoch_time = time.time() - epoch_start;
print(f"轮次: [{epoch+1}/{num_epochs}] "
f"训练损失: {train_loss:.4f}, 训练准确率: {train_acc:.2f}% "
f"测试损失: {test_loss:.4f}, 测试准确率: {test_acc:.2f}% "
f"耗时: {epoch_time:.2f}s !");
# 保存最佳模型。
if test_acc > best_accuracy:
best_accuracy = test_acc;
if isinstance(model, nn.DataParallel):
torch.save(model.module.state_dict(), 'mnist_cnn_best.pth');
else:
torch.save(model.state_dict(), 'mnist_cnn_best.pth');
print(f"✅最佳模型已保存 (准确率: {best_accuracy:.2f}%) !");
训练输出大致如下:
轮次: [1/50] 训练损失: 0.5928, 训练准确率: 81.79% 测试损失: 0.0999, 测试准确率: 96.99% 耗时: 7.24s !
✅最佳模型已保存 (准确率: 96.99%) !
轮次: [2/50] 训练损失: 0.0967, 训练准确率: 97.10% 测试损失: 0.0454, 测试准确率: 98.53% 耗时: 6.04s !
✅最佳模型已保存 (准确率: 98.53%) !
轮次: [3/50] 训练损失: 0.0653, 训练准确率: 98.05% 测试损失: 0.0529, 测试准确率: 98.41% 耗时: 6.20s !
轮次: [4/50] 训练损失: 0.0516, 训练准确率: 98.42% 测试损失: 0.0375, 测试准确率: 98.71% 耗时: 5.98s !
✅最佳模型已保存 (准确率: 98.71%) !
......
可以看到第1个epoch后测试准确率就有80%+,说明模型学习能力很强。
7. 训练过程可视化
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
# ... 设置标签、图例 ...
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
# ...
画了loss和accuracy的曲线图。从图里明显看出:训练loss平滑下降,测试loss在10epoch左右基本走平,之后学习率衰减让loss又小降了一点。准确率曲线也很漂亮,训练和测试的差距很小,说明过拟合控制得不错。
8. 模型保存与加载
if isinstance(model, nn.DataParallel):
torch.save(model.module.state_dict(), 'mnist_cnn_model.pth')
else:
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
保存最终模型(不一定最佳)。后面如果要用,可以这样加载:
model = CNN()
model.load_state_dict(torch.load('mnist_cnn_model.pth'))
model = model.to(device)
model.eval()
9. 预测结果可视化
def visualize_predictions(model, test_loader, device, num_images=16):
model.eval()
images, labels = next(iter(test_loader))
images = images.to(device)
with torch.no_grad():
outputs = model(images)
_, predicted = outputs.max(1)
images = images.cpu()
predicted = predicted.cpu()
# 画4x4网格
...
随机取了一个batch的前16张图,画成4×4网格,真实标签用黑色,预测标签用绿色(正确)或红色(错误)。
10. 混淆矩阵与分类报告
from sklearn.metrics import confusion_matrix, classification_report
遍历整个测试集,收集所有预测结果和真实标签,然后用confusion_matrix生成矩阵,再用plt.imshow画热力图。主对角线很亮,其他位置基本没啥数字,说明模型各个类别都区分得很好。
分类报告:
precision recall f1-score support
0 0.99 1.00 0.99 980
1 1.00 1.00 1.00 1135
2 0.99 1.00 1.00 1032
3 1.00 1.00 1.00 1010
4 0.99 0.99 0.99 982
5 0.99 0.99 0.99 892
6 1.00 0.99 0.99 958
7 0.99 0.99 0.99 1028
8 1.00 0.99 1.00 974
9 0.99 0.98 0.99 1009
accuracy 0.99 10000
macro avg 0.99 0.99 0.99 10000
weighted avg 0.99 0.99 0.99 10000
所有类别的F1都在0.99左右,模型很均衡。
11. 单个图片预测测试
predict_single_image函数随机抽5张测试图,打印真实标签、预测标签和置信度,并给出Top-3预测概率。
def predict_single_image(model, image, device):
model.eval()
with torch.no_grad():
image = image.unsqueeze(0).to(device)
output = model(image)
probabilities = torch.softmax(output, dim=1)
confidence, predicted = torch.max(probabilities, 1)
return predicted.item(), confidence.item(), probabilities.cpu().numpy()[0]
输出示例:
圖片 8649:
真實標籤: 8
預測標籤: 8
置信度: 1.0000
預測正確: ✅
Top-10 預測:
1. 數字 8: 1.0000
2. 數字 9: 0.0000
3. 數字 3: 0.0000
4. 數字 5: 0.0000
5. 數字 2: 0.0000
6. 數字 6: 0.0000
7. 數字 0: 0.0000
8. 數字 4: 0.0000
9. 數字 7: 0.0000
10. 數字 1: 0.0000
圖片 9388:
真實標籤: 1
預測標籤: 1
置信度: 1.0000
預測正確: ✅
Top-10 預測:
1. 數字 1: 1.0000
2. 數字 7: 0.0000
3. 數字 4: 0.0000
4. 數字 2: 0.0000
5. 數字 8: 0.0000
6. 數字 0: 0.0000
7. 數字 5: 0.0000
8. 數字 9: 0.0000
9. 數字 6: 0.0000
10. 數字 3: 0.0000
圖片 6940:
真實標籤: 3
預測標籤: 3
置信度: 0.9999
預測正確: ✅
Top-10 預測:
1. 數字 3: 0.9999
2. 數字 5: 0.0001
3. 數字 2: 0.0000
4. 數字 9: 0.0000
5. 數字 7: 0.0000
6. 數字 1: 0.0000
7. 數字 8: 0.0000
8. 數字 6: 0.0000
9. 數字 4: 0.0000
10. 數字 0: 0.0000
圖片 7629:
真實標籤: 4
預測標籤: 4
置信度: 1.0000
預測正確: ✅
Top-10 預測:
1. 數字 4: 1.0000
2. 數字 9: 0.0000
3. 數字 7: 0.0000
4. 數字 1: 0.0000
5. 數字 5: 0.0000
6. 數字 6: 0.0000
7. 數字 0: 0.0000
8. 數字 2: 0.0000
9. 數字 8: 0.0000
10. 數字 3: 0.0000
圖片 6803:
真實標籤: 5
預測標籤: 5
置信度: 1.0000
預測正確: ✅
Top-10 預測:
1. 數字 5: 1.0000
2. 數字 3: 0.0000
3. 數字 9: 0.0000
4. 數字 8: 0.0000
5. 數字 6: 0.0000
6. 數字 1: 0.0000
7. 數字 7: 0.0000
8. 數字 0: 0.0000
9. 數字 4: 0.0000
10. 數字 2: 0.0000
本文首发于掘金,作者Ailan Anjuxi,转载请注明出处!