libtorch教程:训练一个简单CNN

23 阅读2分钟

libtorch教程:训练一个简单CNN

Libtorch是PyTorch的cpp前端库,它提供了与PyTorch类似的API,允许您在cpp程序中使用PyTorch的功能。
Libtorch非常适合于将PyTorch集成到现有cpp项目中,或在需要cpp性能和效率的环境中使用。在此教程中,
我们将介绍如何使用Libtorch构建和训练一个简单的卷积神经网络(CNN)来对MNIST手写数字数据集进行分类。

开始之前,请确保您的系统上已安装Libtorch库。您可以从PyTorch官方网站(pytorch.org/)下载预编译的二进制文…)。

现在,让我们开始使用Libtorch构建和训练CNN吧!

1. 包含头文件:在cpp代码中,首先需要包含Libtorch的头文件

#include <torch/torch.h>

2. 加载数据集:我们将使用Libtorch提供的工具来加载MNIST数据集

    // 加载训练集和测试集
    torch::data::datasets::MNIST mnist_train("path/to/mnist", torch::data::datasets::MNIST::Mode::Train);
    torch::data::datasets::MNIST mnist_test("path/to/mnist", torch::data::datasets::MNIST::Mode::Test);
    // 请将"path/to/mnist"替换为您本地MNIST数据集的路径。

3 定义神经网络:我们将定义一个简单的卷积神经网络来对MNIST数据集进行分类

    struct Net : torch::nn::Module {
        Net() : conv1(torch::nn::Conv2d(1, 32, 3, 1)),
                conv2(torch::nn::Conv2d(32, 64, 3, 1)),
                fc1(torch::nn::Linear(9216, 128)),
                fc2(torch::nn::Linear(128, 10)) {}

        torch::Tensor forward(torch::Tensor x) {
            x = torch::relu(conv1->forward(x));
            x = torch::relu(conv2->forward(x));
            x = x.view({-1, 9216});
            x = torch::relu(fc1->forward(x));
            x = fc2->forward(x);
            return x;
        }

        torch::nn::Conv2d conv1, conv2;
        torch::nn::Linear fc1, fc2;
    };

4. 定义损失函数和优化器:我们将使用交叉熵损失函数和随机梯度下降优化器

    torch::optim::SGD optimizer(net->parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));
    torch::nn::CrossEntropyLoss loss;

5. 训练神经网络:我们将使用数据加载器(DataLoader)来迭代训练数据集,并对神经网络进行训练

    for (size_t epoch = 1; epoch <= 10; ++epoch) {
        size_t batch_index = 0;
        for (auto& batch : torch::data::make_data_loader(mnist_train, 64)) {
            auto data = batch.data.view({-1, 1, 28, 28});
            auto target = batch.target;

            optimizer.zero_grad();
            auto output = net->forward(data);
            auto loss_value = loss(output, target);
            loss_value.backward();
            optimizer.step();

            if (batch_index % 100 == 0) {
                std::cout << "Epoch: " << epoch << " | Batch: " << batch_index << " | Loss: " << loss_value.item<float>() << std::endl;
            }

            batch_index++;
        }
    }

6. 评估神经网络:我们将使用测试集来评估训练后的神经网络

    size_t correct = 0;
    for (auto& batch : torch::data::make_data_loader(mnist_test, 64)) {
        auto data = batch.data.view({-1, 1, 28, 28});
        auto target = batch.target;

        auto output = net->forward(data);
        auto predictions = torch::argmax(output, 1);
        correct += torch::sum(predictions == target).item<int64_t>();
    }

    double accuracy = static_cast<double>(correct) / mnist_test.size().value();
    std::cout << "Test accuracy: " << accuracy << std::endl;

以上就是使用Libtorch构建和训练一个简单CNN的完整教程。 您可以根据需要修改代码,例如添加更多的卷积层或全连接层,调整超参数等。 希望本教程能帮助您使用Libtorch进行深度学习项目开发!