Pytorch——调用Pytorch提供的标准网络

347 阅读2分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第28天,点击查看活动详情


前言

在之前的文章中,我们介绍了如何去自定义去完成关于ResNet这样的网络结构,VGGNet这样的网络结构,MobileNet这样的网络结构,以及Inception这样不同的四大类结构。实际上,在Pytorch中提供了非常多的已经定义好的模型,这些模型也是目前来说比较标准的网络结构,我们经常会利用这些标准的网络结构去作为我们的预训练的模型,这样就可以节省很多的工作,就不需要自己去自定义模型结构。

今天,我们通过调用Pytorch提供的标准网络ResNet18来完成Cifar10模型的训练。


  • 1.1 调用Pytorch提供的标准网络

相比于之前自定义的网络结构,使用Pytorch提供的标准网络的代码量是比较少的,如果不需要对网络结构进行自己定义或者进行模型压缩裁剪等操作的时候,推荐大家使用Pytorch提供的标准网络结构

import torch.nn as nn
from torchvision import models

class resnet18(nn.Module):
    def __init__(self):
        super(resnet18, self).__init__()
        self.model = models.resnet18(pretrained=True)
        # 这里主要用来解决cifar10,需要修改类别数
        self.num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(self.num_features, 10)

    def forward(self, x):
        out = self.model(x)

        return out


def pytorch_resnet18():
    return resnet18()

注:cifar10数据训练的代码参考我之前的文章Pytorch——Cifar10图像分类中的训练模型的代码,只需要修改一下net即可。

在进行cifar10的数据训练,可以看到在第一个epoch之后,准确率到了26%,并且整个网络是处于收敛过程中的,如果需要使用其它的网络结构的时候,也可以利用这个模板来调用其他的模型。

9JQ4ZCQY3M({Q$KEN%9BFQX.png