在Rust中使用torch------day2使用tch训练并导出模型

237 阅读3分钟

前言

在上一篇中,我讲了一下tch环境安装配置以及一个简单的yolov8推理,但是作为深度学习框架,仅仅实现推理并不够更重要的是实现模型训练,所以今天来看看利用tch训练模型并导出训练好的权重.

模型训练与导出

1. 模型训练

利用tch训练模型和pytorch中的思想基本一致,都是先定义好网络结构,然后写好前向传播的过程,定义好训练集,优化器以及损失函数等等.这里定义网络的方法和pytorch类似,也有两种

  • 利用struct定义网络参数,然后实现new方法以及重写nn::ModuleT中的forward_t()方法.(对标常规的class(nn.Module))
  • 利用nn::SequentialT构建模型序列.(对标nn.Sequential())

在这里我直接拿官网中examples的fast_resnet作为例子

fn conv_bn(vs: &nn::Path, c_in: i64, c_out: i64) -> SequentialT {
    let conv2d_cfg = nn::ConvConfig { padding: 1, bias: false, ..Default::default() };
    nn::seq_t()
        .add(nn::conv2d(vs, c_in, c_out, 3, conv2d_cfg))
        .add(nn::batch_norm2d(vs, c_out, Default::default()))
        .add_fn(|x| x.relu())
}
​
fn layer<'a>(vs: &nn::Path, c_in: i64, c_out: i64) -> FuncT<'a> {
    let pre = conv_bn(&vs.sub("pre"), c_in, c_out);
    let block1 = conv_bn(&vs.sub("b1"), c_out, c_out);
    let block2 = conv_bn(&vs.sub("b2"), c_out, c_out);
    nn::func_t(move |xs, train| {
        let pre = xs.apply_t(&pre, train).max_pool2d_default(2);
        let ys = pre.apply_t(&block1, train).apply_t(&block2, train);
        pre + ys
    })
}
​
​
fn fast_resnet(vs: &nn::Path) -> SequentialT {
    nn::seq_t()
        .add(conv_bn(&vs.sub("pre"), 3, 64))
        .add(layer(&vs.sub("layer1"), 64, 128))
        .add(conv_bn(&vs.sub("inter"), 128, 256))
        .add_fn(|x| x.max_pool2d_default(2))
        .add(layer(&vs.sub("layer2"), 256, 512))
        .add_fn(|x| x.max_pool2d_default(4).flat_view())
        .add_fn(|x|x.relu())
        .add(nn::linear(vs.sub("linear"), 512, 10, Default::default()))
}
​

然后在tch训练模型时,需要定义一个VarStore去存储所有的变量值.需要注意的是,我们在训练的过程中需要参数是可以更新的,而保存模型之前得冻结所有参数.

pub fn train(epochs:i64){
    tch::Cuda::set_user_enabled_cudnn(true);
    tch::Cuda::cudnn_set_benchmark(true);
    let data=tch::vision::cifar::load_dir("./data").unwrap();
    let mut vs = nn::VarStore::new(Device::cuda_if_available());
    let net = fast_resnet(&vs.root());
​
    let mut opt=nn::Adam::default().build(&vs,1e-3).unwrap();
​
    let mut best_acc=0.0;
    for epoch in 1..epochs+1 {
        vs.unfreeze();
        for (imgs,labels) in data.train_iter(4096).shuffle().to_device(vs.device()){
            let imgs=tch::vision::dataset::augmentation(&imgs,true,4,8);
            let loss=net.forward_t(&imgs,true).cross_entropy_for_logits(&labels);
            opt.backward_step(&loss);
        }
​
        let test_accuracy=net.batch_accuracy_for_logits(&data.test_images,&data.test_labels,vs.device(),1024);
        println!("epoch:{:4} test accuracy={:5.2}%",epoch,test_accuracy*100.);
    }
}

image-20230707162144049

2. 模型保存并导出

if best_acc<test_accuracy{
            best_acc=test_accuracy;
            vs.freeze();
            let mut closure = |input: &[Tensor]| vec![net.forward_t(&input[0], false)];
            let model = CModule::create_by_tracing(
                "MyModule",
                "forward",
                &[Tensor::zeros([1,3,32,32], FLOAT_CUDA)],
                &mut closure,
            ).unwrap();
            model.save("model.pt").unwrap();
}

我们每训练完一个epoch对模型进行测试,当模型准确率提升的时候保存模型权重.需要注意的是,这里create_by_tracing中需要传入输入的大小,这里的大小应该与后面验证时候输入的大小相同.

加载训练后的模型并验证

上面的过程中,我们利用tch实现了模型在cifar10数据集上的训练并且保存了模型的权重.这个部分,我们将在python中加载保存的权重进行推理验证.

import torch
from torchvision import datasets, transforms
import torch.nn.functional as F
​
​
def test(model, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction="sum").item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print(
        f"Test result: Average loss:{test_loss} Accuracy:{correct}/{len(test_loader.dataset)} {100.0 * correct / len(test_loader.dataset)}%")
​
​
def main():
    model = torch.jit.load("./model.pt")
    trans = transforms.ToTensor()
    dataset = datasets.CIFAR10("./data1", train=False, download=True, transform=trans)
    test_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
    test(model, test_loader)
​
​
if __name__ == '__main__':
    main()

image-20230707162117578

总结

今天内容比较简单,主要是模型训练以及权重保存.对于框架的使用只需要多看看官方文档,对着api慢慢熟悉就好了.