Rust 实现mnist数据集标签分类模型

88 阅读6分钟
use burn::{

    data::{

        dataloader::{DataLoaderBuilder, batcher::Batcher},

        dataset::vision::{MnistDataset, MnistItem},

    },

    nn::{

        Dropout, DropoutConfig, Linear, LinearConfig, Relu,

        conv::{Conv2d, Conv2dConfig},

        loss::CrossEntropyLossConfig,

        pool::{AdaptiveAvgPool2d, AdaptiveAvgPool2dConfig},

    },

    optim::AdamConfig,

    prelude::*,

    record::{CompactRecorder, Recorder},

    tensor::backend::AutodiffBackend,

    train::{

        ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,

        metric::{AccuracyMetric, LossMetric},

    },

    prelude::*,

};

  


// 定义一个用于处理MNIST数据集的批量器结构体,适用于任何后端B

#[derive(Clone)]

pub struct MnistBatcher<B: Backend> {

    // 设备类型,用于指示后端B在哪个设备上执行操作

    device: B::Device,

}

  


// 实现MNIST批量器的构造函数

impl<B: Backend> MnistBatcher<B> {

    // 创建一个新的MNIST批量器实例

    pub fn new(device: B::Device) -> Self {

        // 构造函数体,简单地将设备参数赋值给实例

        Self { device }

    }

}

  


// 定义MNIST批量数据结构体,包含图像和目标标签

#[derive(Clone, Debug)]

pub struct MnistBatch<B: Backend> {

    // 图像张量,三维表示通常为批量大小、图像高度和图像宽度

    pub images: Tensor<B, 3>,

    // 目标标签张量,一维表示,包含每个图像对应的目标类别,数据类型为Int

    pub targets: Tensor<B, 1, Int>,

}

// 为 MnistBatcher 实现 Batcher 特性,定义如何将多个 MnistItem 实例批量化为 MnistBatch。

impl<B: Backend> Batcher<MnistItem, MnistBatch<B>> for MnistBatcher<B> {

    /// 将一组 MnistItem 转换为 MnistBatch。

    ///

    /// # 参数

    /// - `items`: 包含多个 MnistItem 的向量。

    ///

    /// # 返回值

    /// 返回一个包含图像和标签的 MnistBatch。

    fn batch(&self, items: Vec<MnistItem>) -> MnistBatch<B> {

        // 处理图像数据:转换、重塑并归一化

        let images = items

            .iter()

            .map(|item| TensorData::from(item.image).convert::<B::FloatElem>())

            .map(|data| Tensor::<B, 2>::from_data(data, &self.device))

            .map(|tensor| tensor.reshape([1, 28, 28]))

            // 归一化:将像素值缩放到 [0, 1] 并标准化为均值为 0,标准差为 1

            // 均值 0.1307 和标准差 0.3081 来自 PyTorch MNIST 示例

            // https://github.com/pytorch/examples/blob/54f4572509891883a947411fd7239237dd2a39c3/mnist/main.py#L122

            .map(|tensor| ((tensor / 255) - 0.1307) / 0.3081)

            .collect();

  


        // 处理标签数据:将每个标签转换为张量

        let targets = items

            .iter()

            .map(|item| {

                Tensor::<B, 1, Int>::from_data(

                    [(item.label as i64).elem::<B::IntElem>()],

                    &self.device,

                )

            })

            .collect();

  


        // 将所有图像和标签张量拼接成一个批量张量,并移动到指定设备

        let images = Tensor::cat(images, 0).to_device(&self.device);

        let targets = Tensor::cat(targets, 0).to_device(&self.device);

  


        MnistBatch { images, targets }

    }

}

  


/// 定义一个用于处理 MNIST 数据集的模型结构。

#[derive(Module, Debug)]

pub struct Model<B: Backend> {

    conv1: Conv2d<B>,

    conv2: Conv2d<B>,

    pool: AdaptiveAvgPool2d,

    dropout: Dropout,

    linear1: Linear<B>,

    linear2: Linear<B>,

    activation: Relu,

}

  


impl<B: Backend> Model<B> {

    /// 模型前向传播函数。

    ///

    /// # Shapes

    ///   - Images [batch_size, height, width]

    ///   - Output [batch_size, num_classes]

    ///

    /// # 参数

    /// - `images`: 输入图像张量,形状为 [batch_size, height, width]。

    ///

    /// # 返回值

    /// 返回一个形状为 [batch_size, num_classes] 的输出张量。

    pub fn forward(&self, images: Tensor<B, 3>) -> Tensor<B, 2> {

        let [batch_size, height, width] = images.dims();

  


        // 在第二维度上创建通道

        let x = images.reshape([batch_size, 1, height, width]);

  


        // 第一层卷积 + Dropout

        let x = self.conv1.forward(x); // [batch_size, 8, _, _]

        let x = self.dropout.forward(x);

  


        // 第二层卷积 + Dropout + 激活函数

        let x = self.conv2.forward(x); // [batch_size, 16, _, _]

        let x = self.dropout.forward(x);

        let x = self.activation.forward(x);

  


        // 自适应平均池化

        let x = self.pool.forward(x); // [batch_size, 16, 8, 8]

  


        // 展平特征图并传递给全连接层

        let x = x.reshape([batch_size, 16 * 8 * 8]);

        let x = self.linear1.forward(x);

        let x = self.dropout.forward(x);

        let x = self.activation.forward(x);

  


        // 最终的全连接层输出分类结果

        self.linear2.forward(x) // [batch_size, num_classes]

    }

}

// 为Model实现前向分类功能

impl<B: Backend> Model<B> {

    /// 执行分类任务的前向传播

    ///

    /// # 参数

    ///

    /// - `images`: 输入的图像数据,三维张量

    /// - `targets`: 目标标签,一维张量,数据类型为Int

    ///

    /// # 返回

    ///

    /// 返回一个包含损失、输出和目标标签的ClassificationOutput对象

    pub fn forward_classification(

        &self,

        images: Tensor<B, 3>,

        targets: Tensor<B, 1, Int>,

    ) -> ClassificationOutput<B> {

        let output = self.forward(images);

        let loss = CrossEntropyLossConfig::new()

            .init(&output.device())

            .forward(output.clone(), targets.clone());

  


        ClassificationOutput::new(loss, output, targets)

    }

}

  


// 为Model实现训练步骤

impl<B: AutodiffBackend> TrainStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {

    /// 执行一个训练步骤

    ///

    /// # 参数

    ///

    /// - `batch`: 一个包含图像和标签的MnistBatch对象

    ///

    /// # 返回

    ///

    /// 返回一个包含模型、损失的梯度和分类输出的TrainOutput对象

    fn step(&self, batch: MnistBatch<B>) -> TrainOutput<ClassificationOutput<B>> {

        let item = self.forward_classification(batch.images, batch.targets);

  


        TrainOutput::new(self, item.loss.backward(), item)

    }

}

  


// 为Model实现验证步骤

impl<B: Backend> ValidStep<MnistBatch<B>, ClassificationOutput<B>> for Model<B> {

    /// 执行一个验证步骤

    ///

    /// # 参数

    ///

    /// - `batch`: 一个包含图像和标签的MnistBatch对象

    ///

    /// # 返回

    ///

    /// 返回分类输出

    fn step(&self, batch: MnistBatch<B>) -> ClassificationOutput<B> {

        self.forward_classification(batch.images, batch.targets)

    }

}

  


// 定义训练配置结构体

#[derive(Config)]

pub struct TrainingConfig {

    // 模型配置

    pub model: ModelConfig,

    // 优化器配置(Adam)

    pub optimizer: AdamConfig,

    // 默认训练轮数为10

    #[config(default = 10)]

    pub num_epochs: usize,

    // 默认批量大小为64

    #[config(default = 64)]

    pub batch_size: usize,

    // 默认工作线程数为4

    #[config(default = 4)]

    pub num_workers: usize,

    // 默认随机种子为42

    #[config(default = 42)]

    pub seed: u64,

    // 默认学习率为1.0e-4

    #[config(default = 1.0e-4)]

    pub learning_rate: f64,

}

  


// 创建用于存储训练工件(如模型和配置文件)的目录

fn create_artifact_dir(artifact_dir: &str) {

    // 删除现有工件以确保准确的训练总结

    std::fs::remove_dir_all(artifact_dir).ok();

    // 创建新的工件目录

    std::fs::create_dir_all(artifact_dir).ok();

}

  


// 执行模型训练的函数

pub fn train<B: AutodiffBackend>(artifact_dir: &str, config: TrainingConfig, device: B::Device) {

    // 创建用于存储训练工件的目录

    create_artifact_dir(artifact_dir);

    // 保存训练配置到指定路径

    config

        .save(format!("{artifact_dir}/config.json"))

        .expect("配置应成功保存");

  


    // 设置随机种子

    B::seed(config.seed);

  


    // 创建训练数据的批处理对象

    let batcher_train = MnistBatcher::<B>::new(device.clone());

    // 创建验证数据的批处理对象

    let batcher_valid = MnistBatcher::<B::InnerBackend>::new(device.clone());

  


    // 构建训练数据加载器

    let dataloader_train = DataLoaderBuilder::new(batcher_train)

        .batch_size(config.batch_size)

        .shuffle(config.seed)

        .num_workers(config.num_workers)

        .build(MnistDataset::train());

  


    // 构建测试数据加载器

    let dataloader_test = DataLoaderBuilder::new(batcher_valid)

        .batch_size(config.batch_size)

        .shuffle(config.seed)

        .num_workers(config.num_workers)

        .build(MnistDataset::test());

  


    // 构建训练器对象,设置训练参数和指标

    let learner = LearnerBuilder::new(artifact_dir)

        .metric_train_numeric(AccuracyMetric::new())

        .metric_valid_numeric(AccuracyMetric::new())

        .metric_train_numeric(LossMetric::new())

        .metric_valid_numeric(LossMetric::new())

        .with_file_checkpointer(CompactRecorder::new())

        .devices(vec![device.clone()])

        .num_epochs(config.num_epochs)

        .summary()

        .build(

            config.model.init::<B>(&device),

            config.optimizer.init(),

            config.learning_rate,

        );

  


    // 开始训练模型

    let model_trained = learner.fit(dataloader_train, dataloader_test);

  


    // 保存训练后的模型到指定路径

    model_trained

        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())

        .expect("训练后的模型应成功保存");

}

// 执行模型推理的函数

pub fn infer<B: Backend>(artifact_dir: &str, device: B::Device, item: MnistItem) {

    // 加载训练配置文件,确保配置存在

    let config = TrainingConfig::load(format!("{artifact_dir}/config.json"))

        .expect("模型应有对应的配置文件");

  


    // 加载训练后的模型记录,确保模型已存在

    let record = CompactRecorder::new()

        .load(format!("{artifact_dir}/model").into(), &device)

        .expect("训练后的模型应存在");

  


    // 初始化并加载模型参数

    let model = config.model.init::<B>(&device).load_record(record);

  


    // 获取输入项的标签

    let label = item.label;

  


    // 创建批处理对象

    let batcher = MnistBatcher::new(device);

    // 将单个输入项转换为批次

    let batch = batcher.batch(vec![item]);

  


    // 使用模型进行前向传播,获取输出

    let output = model.forward(batch.images);

  


    // 获取预测结果

    let predicted = output.argmax(1).flatten::<1>(0, 1).into_scalar();

  


    // 输出预测结果和期望值

    println!("预测值: {} 期望值: {}", predicted, label);

}

  


// 定义模型配置结构体

#[derive(Config, Debug)]

pub struct ModelConfig {

    // 类别数量

    num_classes: usize,

    // 隐藏层大小

    hidden_size: usize,

    // 默认dropout概率为0.5

    #[config(default = "0.5")]

    dropout: f64,

}

  


impl ModelConfig {

    /// 返回初始化的模型

    pub fn init<B: Backend>(&self, device: &B::Device) -> Model<B> {

        Model {

            // 第一层卷积层配置

            conv1: Conv2dConfig::new([1, 8], [3, 3]).init(device),

            // 第二层卷积层配置

            conv2: Conv2dConfig::new([8, 16], [3, 3]).init(device),

            // 自适应平均池化层配置

            pool: AdaptiveAvgPool2dConfig::new([8, 8]).init(),

            // 激活函数

            activation: Relu::new(),

            // 第一个全连接层配置

            linear1: LinearConfig::new(16 * 8 * 8, self.hidden_size).init(device),

            // 第二个全连接层配置

            linear2: LinearConfig::new(self.hidden_size, self.num_classes).init(device),

            // Dropout层配置

            dropout: DropoutConfig::new(self.dropout).init(),

        }

    }

}