【AI实战】从零实现CIFAR-10图像分类,准确率80%+(附完整代码)

2 阅读5分钟

💡 本教程是《AI 入门 30 天挑战》系列的项目实战部分


🎯 项目简介

这是一个完整的 CIFAR-10 图像分类项目,使用 PyTorch 实现 CNN 模型,达到 80%+ 准确率。

你将学到:

  • ✅ 完整的深度学习项目流程
  • ✅ CNN 模型设计与优化
  • ✅ 数据增强技巧
  • ✅ 模型训练与评估
  • ✅ 可视化结果分析

项目特点:

  • 🚀 模块化设计,代码清晰
  • 📊 完整的训练曲线和混淆矩阵
  • 🎨 预测结果可视化
  • 📝 详细的中文注释

📂 项目结构

cifar10-classification/
├── main.py              # 主程序入口
├── model.py             # CNN 模型定义
├── train.py             # 训练脚本
├── evaluate.py          # 评估脚本
├── utils.py             # 可视化工具
├── requirements.txt     # 依赖包
└── README.md            # 详细说明

🚀 快速开始

1. 安装依赖

pip install torch torchvision matplotlib seaborn scikit-learn

2. 运行训练

cd projects/cifar10-classification
python main.py --mode train

训练完成后会生成:

  • cifar_best.pth - 最佳模型权重
  • training_curves.png - 训练曲线图
  • confusion_matrix.png - 混淆矩阵
  • predictions.png - 预测示例

3. 评估模型

python main.py --mode evaluate

🔍 核心代码解析

CNN 模型架构

class SimpleCNN(nn.Module):
    def __init__(self, num_classes=10):
        super(SimpleCNN, self).__init__()
        
        # 特征提取层
        self.features = nn.Sequential(
            # 第一层卷积块
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
            
            # 第二层卷积块
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
            
            # 第三层卷积块
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(0.25),
        )
        
        # 分类器
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

大白话解释:

  • 这个模型就像人的眼睛,分3层看图片
  • 每层先看细节(卷积),再总结特征(池化)
  • 最后用全连接层做判断
  • Dropout 防止死记硬背(过拟合)

📊 训练结果

准确率曲线

训练过程中会生成训练曲线图:

  • 训练集准确率:85%+
  • 测试集准确率:80%+
  • 无明显过拟合

混淆矩阵

混淆矩阵显示各类别的分类情况:

容易混淆的类别:

  • Cat vs Dog(猫和狗)
  • Truck vs Automobile(卡车和汽车)

这是正常的,因为这些类别本身就很相似。


💡 优化技巧

1. 数据增强

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.RandomRotation(10),      # 随机旋转
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

为什么有效?

  • 让模型看到更多变化的图片
  • 提高泛化能力
  • 相当于增加了训练数据

2. 学习率调度

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

作用:

  • 每 10 个 epoch 学习率减半
  • 后期微调,避免震荡
  • 帮助模型收敛到更好的位置

3. Batch Normalization

nn.BatchNorm2d(32)

好处:

  • 加速训练
  • 允许使用更大的学习率
  • 有一定的正则化效果

🎨 可视化预测

项目提供了完整的可视化工具:

def predict_image(model, image_path):
    """预测单张图片"""
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    
    input_tensor = transform(image).unsqueeze(0)
    
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
        confidence, predicted_idx = torch.max(probabilities, 0)
    
    return CLASSES[predicted_idx.item()], float(confidence)

🤔 常见问题

Q1: 为什么准确率只有 80%?

A: CIFAR-10 本身就有难度,有些图片连人都分不清。80% 对初学者来说已经不错了。想提高可以:

  • 使用更深的网络(ResNet)
  • 更多的数据增强
  • 迁移学习

Q2: 训练太慢怎么办?

A:

  • 使用 GPU(速度提升 10-50 倍)
  • 减少 batch size
  • 减少 epoch 数量

Q3: 如何用自己的图片测试?

A: 修改 predict_image 函数中的图片路径即可。


📚 相关教程

这是《AI 入门 30 天挑战》的项目实战部分,前置知识:

完整 30 天教程:


🎉 总结

通过这个实战项目,你学会了:

  1. ✅ 完整的深度学习项目流程
  2. ✅ CNN 模型设计与实现
  3. ✅ 数据增强和正则化技巧
  4. ✅ 模型评估和可视化

下一步:

  • ⭐ Star GitHub 获取完整代码
  • ➕ 关注专栏查看更多项目
  • 💬 评论区分享你的训练结果

其他项目实战:


🎉 恭喜你完成今天的学习!

📚 学习路径导航

上一篇当前下一篇
Day 14 - Week2 综合项目项目实战 - CIFAR-10项目实战 - 文本生成

🔗 资源汇总

💬 互动时间

思考题:你觉得还有哪些方法可以提高 CIFAR-10 的分类准确率?

欢迎在评论区分享你的想法或疑问!👇

❤️ 如果有帮助

  • 👍 点赞:让更多人看到这篇教程
  • Star GitHub:获取完整代码和项目
  • 关注专栏:不错过后续更新
  • 🔄 分享给朋友:一起学习进步

明天见!继续下一个项目实战~ 🚀


本文是《AI 入门 30 天挑战》系列的项目实战篇 完整代码已开源,欢迎 Star 支持!