MindSpore MNIST Model Optimization in Practice: Dynamic Learning Rate + Early Stopping


MNIST handwritten digit recognition is a classic entry-level deep learning task. Basic training code written with the MindSpore framework often suffers from slow convergence caused by a fixed learning rate, overfitting due to the lack of early stopping, and a low degree of engineering polish. From a hands-on perspective, this article implements an MNIST optimization scheme that combines a dynamic learning rate, early stopping, and a complete engineering workflow, provides complete runnable code, and walks through the core optimization logic.

I. Core Optimization Goals

Implement dynamic learning rate scheduling to fix the slow, oscillation-prone late-stage convergence caused by a fixed learning rate;

Add an early stopping mechanism to prevent overfitting and automatically save the best weights seen during training;

Fill in engineering details (such as evaluation-mode control and training progress monitoring) so the code is ready for real use;

Work around MindSpore-specific optimizer and gradient-function quirks to guarantee the code actually runs.

II. Complete Optimized Code

1. Environment and Imports

First import the core MindSpore dependencies, covering model construction, data processing, and gradient computation:

import mindspore as ms
import mindspore.nn as nn
from mindspore import value_and_grad
from mindspore.dataset import vision, transforms
from mindspore.dataset import MnistDataset

2. CNN Model Definition

Build a lightweight CNN tailored to MNIST feature extraction, with convolution, pooling, fully connected, and Dropout (anti-overfitting) layers:

class CNN(nn.Cell):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, pad_mode='same')
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, pad_mode='same')
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.flatten = nn.Flatten()
        self.fc1 = nn.Dense(64 * 7 * 7, 1024)
        self.relu3 = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)  # drop probability 0.5 (MindSpore 2.x API; older versions use keep_prob)
        self.fc2 = nn.Dense(1024, 10)

    def construct(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.dropout(self.relu3(self.fc1(x)))
        x = self.fc2(x)
        return x
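The fc1 input size of 64 * 7 * 7 comes from the two pooling stages: the 'same'-padded convolutions preserve the 28x28 spatial size, while each stride-2 MaxPool2d halves it, giving 14x14 and then 7x7 with 64 channels after conv2. A quick sanity check of that arithmetic (plain Python, independent of MindSpore; the helper name is illustrative):

```python
def conv_pool_out_size(size, n_pool_stages, pool_stride=2):
    """Spatial size after 'same'-padded convs and n stride-2 pooling stages."""
    for _ in range(n_pool_stages):
        size = size // pool_stride  # 'same' conv keeps the size; pooling halves it
    return size

side = conv_pool_out_size(28, 2)   # 28 -> 14 -> 7
flat = 64 * side * side            # channels * H * W fed into fc1
print(side, flat)                  # 7 3136
```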

3. Data Loading and Preprocessing

Load the MNIST dataset and apply normalization, layout conversion, and batching so the input matches what the model expects:

def load_data():
    # Load the MNIST dataset (MnistDataset reads local files; download the data to MNIST_Data/ beforehand)
    train_dataset = MnistDataset("MNIST_Data/train")
    test_dataset = MnistDataset("MNIST_Data/test")

    # Preprocessing pipeline
    trans = transforms.Compose([
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=[0.1307], std=[0.3081]),
        vision.HWC2CHW()
    ])
    # Shuffle (before batching, so individual samples are mixed) + batch
    train_dataset = train_dataset.map(trans, input_columns=["image"])
    train_dataset = train_dataset.map(transforms.TypeCast(ms.int32), input_columns=["label"])
    train_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)
    test_dataset = test_dataset.map(trans, input_columns=["image"])
    test_dataset = test_dataset.map(transforms.TypeCast(ms.int32), input_columns=["label"])
    test_dataset = test_dataset.batch(64)
    return train_dataset, test_dataset
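The Rescale + Normalize pair maps raw pixel bytes in [0, 255] to roughly zero-mean, unit-variance values: divide by 255 first, then subtract the MNIST mean 0.1307 and divide by the std 0.3081. A minimal pure-Python equivalent of that per-pixel math (the real pipeline runs these as dataset transform ops; the function name here is illustrative):

```python
def preprocess_pixel(p, mean=0.1307, std=0.3081):
    """Mirror vision.Rescale(1/255, 0) followed by vision.Normalize."""
    rescaled = p * (1.0 / 255.0)       # Rescale: p * rescale + shift, with shift = 0
    return (rescaled - mean) / std     # Normalize: (x - mean) / std

print(round(preprocess_pixel(0), 4))    # black pixel  -> -0.4242
print(round(preprocess_pixel(255), 4))  # white pixel  ->  2.8215
```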

4. Training / Test Loop Functions

Fill in the core training and test logic that the basic code was missing, and add evaluation-mode control:

def train_loop(model, dataset, loss_fn, train_step):
    """Training loop: returns the average loss for this epoch"""
    total_loss = 0.0
    step_num = 0
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)
        total_loss += loss.asnumpy()
        step_num += 1
        # Print training progress every 50 batches (optional)
        if batch % 50 == 0:
            print(f"Batch [{batch}], Loss: {loss.asnumpy():.4f}")
    # Return the average loss
    return total_loss / step_num

def test_loop(model, dataset, loss_fn):
    """Test loop: returns accuracy and average loss"""
    total_correct = 0
    total_loss = 0.0
    total_num = 0
    # Evaluation mode: disable Dropout
    model.set_train(False)
    for data, label in dataset.create_tuple_iterator():
        logits = model(data)
        loss = loss_fn(logits, label)
        # Count correct predictions
        pred = logits.argmax(axis=1)
        total_correct += (pred == label).sum().asnumpy()
        total_loss += loss.asnumpy() * data.shape[0]
        total_num += data.shape[0]
    # Restore training mode
    model.set_train(True)
    # Compute accuracy and average loss
    accuracy = total_correct / total_num
    avg_loss = total_loss / total_num
    return accuracy, avg_loss
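test_loop aggregates correctness batch by batch and weights each batch's loss by its size, so the final numbers are true dataset-level averages even when the last batch is smaller than 64. The same bookkeeping in plain Python, with hypothetical toy batches (no MindSpore required):

```python
def aggregate(batches):
    """batches: list of (correct_count, batch_loss, batch_size) tuples."""
    total_correct = sum(c for c, _, _ in batches)
    total_loss = sum(loss * n for _, loss, n in batches)  # size-weighted loss
    total_num = sum(n for _, _, n in batches)
    return total_correct / total_num, total_loss / total_num

# Two full batches of 64 and a final partial batch of 32
acc, avg_loss = aggregate([(60, 0.2, 64), (62, 0.1, 64), (30, 0.4, 32)])
print(acc, avg_loss)  # 0.95 0.2
```

Averaging the per-batch losses directly would over-weight the short final batch; weighting by batch size avoids that bias.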

5. Dynamic Learning Rate Optimization (Core)

Because a MindSpore optimizer's learning_rate cannot simply be assigned a new value, the dynamic learning rate is implemented by re-initializing the optimizer per epoch, halving the learning rate every 5 epochs:

def get_optimizer(model, epoch):
    """Return an Adam optimizer with the learning rate for the given epoch"""
    if epoch < 5:
        lr = 1e-3
    elif epoch < 10:
        lr = 5e-4
    else:
        lr = 2.5e-4
    return nn.Adam(model.trainable_params(), learning_rate=lr, weight_decay=1e-4)

6. Complete Training Flow (All Optimizations Combined)

Combine the dynamic learning rate, early stopping, and model saving into an end-to-end training flow:

if __name__ == "__main__":
    # Initialize the environment
    ms.set_context(mode=ms.PYNATIVE_MODE)  # dynamic graph mode, good for debugging
    # Load data
    train_dataset, test_dataset = load_data()
    # Initialize the model and loss function
    model = CNN()
    loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

    # Training hyperparameters
    epochs = 15
    best_accuracy = 0.0
    patience = 3
    no_improve_epoch = 0
    # Core training loop
    for t in range(epochs):
        # Step 1: re-create the optimizer with this epoch's learning rate
        optimizer = get_optimizer(model, t)
        current_lr = float(optimizer.learning_rate.asnumpy())  # current learning rate as a plain float
        
        # Step 2: rebuild the grad function and train step each epoch (they depend on the new optimizer)
        def forward_fn(data, label):
            logits = model(data)
            loss = loss_fn(logits, label)
            return loss, logits
        grad_fn = value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)
        def train_step(data, label):
            (loss, _), grads = grad_fn(data, label)
            optimizer(grads)
            return loss
        # Step 3: train + evaluate
        print(f"Epoch {t+1}/{epochs}\n-------------------------------")
        print(f"Current learning rate: {current_lr}")
        train_loss = train_loop(model, train_dataset, loss_fn, train_step)
        test_accuracy, test_loss = test_loop(model, test_dataset, loss_fn)
        
        # Step 4: early stopping
        print(f"Epoch [{t+1}] Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            no_improve_epoch = 0
            ms.save_checkpoint(model, "best_cnn_model.ckpt")
            print(f"✅ Best model saved! Current best accuracy: {(100*best_accuracy):>0.1f}%\n")
        else:
            no_improve_epoch += 1
            print(f"⚠️ No improvement. No improve epoch: {no_improve_epoch}/{patience}\n")
        
        if no_improve_epoch >= patience:
            print(f"🛑 Early stopping at epoch {t+1}! Best accuracy: {(100*best_accuracy):>0.1f}%")
            break
    print("🎉 Training completed!")

Summary:

1. How the dynamic learning rate is implemented

In MindSpore, an optimizer's learning_rate is a Parameter inside a Cell submodule and cannot simply be overwritten by assignment. So instead of mutating an existing optimizer's learning rate, a new optimizer is created each epoch based on the current epoch number, and grad_fn and train_step are rebuilt along with it (the gradient function depends on the optimizer's parameter list). This achieves dynamic learning rate adjustment while avoiding the TypeError that the framework's design would otherwise trigger.
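The three-step schedule inside get_optimizer is equivalent to halving a 1e-3 base rate every 5 epochs, which can be written as one standalone function (a pure-Python sketch matching get_optimizer's thresholds over the 15 training epochs; note that MindSpore also ships native dynamic-learning-rate support, such as passing a LearningRateSchedule cell like nn.ExponentialDecayLR to the optimizer, if per-epoch re-creation is not desired):

```python
def lr_for_epoch(epoch, base_lr=1e-3, decay=0.5, step=5):
    """Halve base_lr every `step` epochs: 1e-3 -> 5e-4 -> 2.5e-4."""
    return base_lr * decay ** (epoch // step)

print([lr_for_epoch(e) for e in (0, 4, 5, 9, 10)])
# [0.001, 0.001, 0.0005, 0.0005, 0.00025]
```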

2. Why early stopping matters

Training for a fixed number of epochs tends to end in late-stage overfitting or wasted computation. The early stopping mechanism monitors test accuracy:

If accuracy keeps improving, the current best model is saved;

If there is no improvement for 3 consecutive epochs, training is terminated immediately, saving resources and preventing overfitting.
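The patience bookkeeping in the main loop can be isolated into a small tracker, which makes the stopping rule easy to unit-test on its own. A minimal framework-agnostic sketch (the class name and API are illustrative, not part of the original code):

```python
class EarlyStopper:
    """Stop after `patience` consecutive epochs without accuracy improvement."""
    def __init__(self, patience=3):
        self.patience = patience
        self.best = 0.0
        self.no_improve = 0

    def update(self, accuracy):
        """Return (improved, should_stop) for this epoch's test accuracy."""
        if accuracy > self.best:
            self.best = accuracy
            self.no_improve = 0      # reset the counter; save a checkpoint here
            return True, False
        self.no_improve += 1
        return False, self.no_improve >= self.patience

stopper = EarlyStopper(patience=3)
for acc in (0.9876, 0.9873, 0.9886, 0.9905, 0.9904, 0.9903, 0.9902):
    improved, stop = stopper.update(acc)
    if stop:
        break
print(stopper.best, stopper.no_improve)  # 0.9905 3
```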

3. Engineering details

Evaluation-mode control: model.set_train(False) is called during testing to disable the Dropout layer, keeping the accuracy measurement faithful; model.set_train(True) is restored afterwards so subsequent training is unaffected;

Batch progress printing: the loss is printed every 50 batches, making it easy to monitor training in real time and catch anomalies early;

Dynamic graph mode: PYNATIVE_MODE is used instead of graph mode, so beginners get more precise error messages while debugging, lowering the cost of troubleshooting.

Run output:

Epoch 1/15
-------------------------------
Current learning rate: 0.001
Batch [0], Loss: 2.3089
Batch [50], Loss: 0.3292
Batch [100], Loss: 0.2275
Batch [150], Loss: 0.1016
Batch [200], Loss: 0.0906
Batch [250], Loss: 0.0743
Batch [300], Loss: 0.1022
Batch [350], Loss: 0.1652
Batch [400], Loss: 0.0151
Batch [450], Loss: 0.0343
Batch [500], Loss: 0.1511
Batch [550], Loss: 0.0202
Batch [600], Loss: 0.1167
Batch [650], Loss: 0.0394
Batch [700], Loss: 0.1037
Batch [750], Loss: 0.0193
Batch [800], Loss: 0.0025
Batch [850], Loss: 0.0865
Batch [900], Loss: 0.0467
Epoch [1] Train Loss: 0.1059, Test Loss: 0.0380, Test Accuracy: 0.9876
✅ Best model saved! Current best accuracy: 98.8%

Epoch 2/15
-------------------------------
Current learning rate: 0.001
Batch [0], Loss: 0.0113
Batch [50], Loss: 0.0550
Batch [100], Loss: 0.1348
Batch [150], Loss: 0.1025
Batch [200], Loss: 0.0653
Batch [250], Loss: 0.1268
Batch [300], Loss: 0.1043
Batch [350], Loss: 0.1079
Batch [400], Loss: 0.0053
Batch [450], Loss: 0.0462
Batch [500], Loss: 0.0368
Batch [550], Loss: 0.0338
Batch [600], Loss: 0.0425
Batch [650], Loss: 0.0698
Batch [700], Loss: 0.0794
Batch [750], Loss: 0.0438
Batch [800], Loss: 0.1844
Batch [850], Loss: 0.0093
Batch [900], Loss: 0.1037
Epoch [2] Train Loss: 0.0532, Test Loss: 0.0386, Test Accuracy: 0.9873
⚠️ No improvement. No improve epoch: 1/3
Epoch 3/15
-------------------------------
Current learning rate: 0.001
Batch [0], Loss: 0.0129
Batch [50], Loss: 0.0368
Batch [100], Loss: 0.0050
Batch [150], Loss: 0.0449
Batch [200], Loss: 0.0210
Batch [250], Loss: 0.0113
Batch [300], Loss: 0.0410
Batch [350], Loss: 0.1138
Batch [400], Loss: 0.0093
Batch [450], Loss: 0.0844
Batch [500], Loss: 0.0089
Batch [550], Loss: 0.0716
Batch [600], Loss: 0.0605
Batch [650], Loss: 0.0150
Batch [700], Loss: 0.0018
Batch [750], Loss: 0.1401
Batch [800], Loss: 0.0548
Batch [850], Loss: 0.1749
Batch [900], Loss: 0.0058
Epoch [3] Train Loss: 0.0440, Test Loss: 0.0332, Test Accuracy: 0.9886
✅ Best model saved! Current best accuracy: 98.9%
Epoch 4/15
-------------------------------
Current learning rate: 0.001
Batch [0], Loss: 0.0179
Batch [50], Loss: 0.0515
Batch [100], Loss: 0.0204
Batch [150], Loss: 0.1003
Batch [200], Loss: 0.0898
Batch [250], Loss: 0.0033
Batch [300], Loss: 0.0014
Batch [350], Loss: 0.0304
Batch [400], Loss: 0.0066
Batch [450], Loss: 0.0095
Batch [500], Loss: 0.0021
Batch [550], Loss: 0.0356
Batch [600], Loss: 0.0337
Batch [650], Loss: 0.1777
Batch [700], Loss: 0.0246
Batch [750], Loss: 0.0499
Batch [800], Loss: 0.0168
Batch [850], Loss: 0.0065
Batch [900], Loss: 0.0378
Epoch [4] Train Loss: 0.0384, Test Loss: 0.0277, Test Accuracy: 0.9910
✅ Best model saved! Current best accuracy: 99.1%

Epoch 5/15
-------------------------------
Current learning rate: 0.001
Batch [0], Loss: 0.0258
Batch [50], Loss: 0.0079
Batch [100], Loss: 0.0395
Batch [150], Loss: 0.0083
Batch [200], Loss: 0.0547
Batch [250], Loss: 0.1103
Batch [300], Loss: 0.0156
Batch [350], Loss: 0.0058
Batch [400], Loss: 0.1414
Batch [450], Loss: 0.0208
Batch [500], Loss: 0.0291
Batch [550], Loss: 0.0680
Batch [600], Loss: 0.0339
Batch [650], Loss: 0.0015
Batch [700], Loss: 0.0270
Batch [750], Loss: 0.0083
Batch [800], Loss: 0.0083
Batch [850], Loss: 0.0397
Batch [900], Loss: 0.0629
Epoch [5] Train Loss: 0.0342, Test Loss: 0.0290, Test Accuracy: 0.9905
⚠️ No improvement. No improve epoch: 1/3
Epoch 6/15
-------------------------------
Current learning rate: 0.0005
Batch [0], Loss: 0.0903
Batch [50], Loss: 0.0157
Batch [100], Loss: 0.0813
Batch [150], Loss: 0.0360
Batch [200], Loss: 0.0210
Batch [250], Loss: 0.0306
Batch [300], Loss: 0.0438
Batch [350], Loss: 0.0032
Batch [400], Loss: 0.0776
Batch [450], Loss: 0.0019
Batch [500], Loss: 0.0730
Batch [550], Loss: 0.0129
Batch [600], Loss: 0.0004
Batch [650], Loss: 0.0447
Batch [700], Loss: 0.0243
Batch [750], Loss: 0.0273
Batch [800], Loss: 0.0031
Batch [850], Loss: 0.0041
Batch [900], Loss: 0.0970
Epoch [6] Train Loss: 0.0207, Test Loss: 0.0206, Test Accuracy: 0.9931
✅ Best model saved! Current best accuracy: 99.3%
Epoch 7/15
-------------------------------
Current learning rate: 0.0005
Batch [0], Loss: 0.0238
Batch [50], Loss: 0.0019
Batch [100], Loss: 0.0120
Batch [150], Loss: 0.0042
Batch [200], Loss: 0.0165
Batch [250], Loss: 0.0277
Batch [300], Loss: 0.0077
Batch [350], Loss: 0.0080
Batch [400], Loss: 0.0494
Batch [450], Loss: 0.0027
Batch [500], Loss: 0.0171
Batch [550], Loss: 0.0333
Batch [600], Loss: 0.0043
Batch [650], Loss: 0.0061
Batch [700], Loss: 0.0355
Batch [750], Loss: 0.0580
Batch [800], Loss: 0.0488
Batch [850], Loss: 0.0089
Batch [900], Loss: 0.0053
Epoch [7] Train Loss: 0.0181, Test Loss: 0.0253, Test Accuracy: 0.9904
⚠️ No improvement. No improve epoch: 1/3
Epoch 8/15
-------------------------------
Current learning rate: 0.0005
Batch [0], Loss: 0.0208
Batch [50], Loss: 0.0071
Batch [100], Loss: 0.0022
Batch [150], Loss: 0.0046
Batch [200], Loss: 0.0195
Batch [250], Loss: 0.0259
Batch [300], Loss: 0.0022
Batch [350], Loss: 0.0033
Batch [400], Loss: 0.0217
Batch [450], Loss: 0.0630
Batch [500], Loss: 0.0074
Batch [550], Loss: 0.0011
Batch [600], Loss: 0.0331
Batch [650], Loss: 0.0086
Batch [700], Loss: 0.0202
Batch [750], Loss: 0.0200
Batch [800], Loss: 0.0273
Batch [850], Loss: 0.0223
Batch [900], Loss: 0.0337
Epoch [8] Train Loss: 0.0176, Test Loss: 0.0268, Test Accuracy: 0.9910
⚠️ No improvement. No improve epoch: 2/3
Epoch 9/15
-------------------------------
Current learning rate: 0.0005
Batch [0], Loss: 0.0014
Batch [50], Loss: 0.0260
Batch [100], Loss: 0.0056
Batch [150], Loss: 0.0038
Batch [200], Loss: 0.0014
Batch [250], Loss: 0.0291
Batch [300], Loss: 0.0065
Batch [350], Loss: 0.0019
Batch [400], Loss: 0.0292
Batch [450], Loss: 0.0603
Batch [500], Loss: 0.0146
Batch [550], Loss: 0.0028
Batch [600], Loss: 0.0132
Batch [650], Loss: 0.0106
Batch [700], Loss: 0.0435
Batch [750], Loss: 0.0085
Batch [800], Loss: 0.0024
Batch [850], Loss: 0.0111
Batch [900], Loss: 0.0028
Epoch [9] Train Loss: 0.0174, Test Loss: 0.0225, Test Accuracy: 0.9922
⚠️ No improvement. No improve epoch: 3/3
🛑 Early stopping at epoch 9! Best accuracy: 99.3%
🎉 Training completed!