Pytorch——搭建Cifar10推理测试脚本与分类模型优化思路

399 阅读3分钟

携手创作,共同成长!这是我参与「掘金日新计划 · 8 月更文挑战」的第28天,点击查看活动详情


前言

在之前的文章中,我们介绍了如何去定义不同的网络结构,包括ResNet,VGGNet,MobileNet,InceptionNet,并且介绍了如何使用Pytorch中已经定义好的标准网络。

今天,我们介绍如何使用已经定义好的模型来编写测试脚本并且介绍分类模型优化思路。


  • 1.1 搭建推理测试脚本

import torch
import glob
import cv2
from PIL import Image
from torchvision import transforms
import numpy as np
from resnet import resnet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = resnet()

net.load_state_dict(torch.load("models/resnet/200.pth"))

im_list = glob.glob("cifar-10-batches-py/test/*/*")

np.random.shuffle(im_list)

net.to(device)

label_name = ["airplane", "automobile", "bird",
              "cat", "deer", "dog",
              "frog", "horse", "ship", "truck"]

# 预处理
test_transform = transforms.Compose([
    transforms.CenterCrop((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010)),
])

# 遍历图片数据
for im_path in im_list:
    net.eval()
    im_data = Image.open(im_path)

    inputs = test_transform(im_data)
    inputs = torch.unsqueeze(inputs, dim=0)

    inputs = inputs.to(device)
    outputs = net.forward(inputs)

    _, pred = torch.max(outputs.data, dim=1)

    print(label_name[pred.cpu().numpy()[0]])

    img = np.asarray(im_data)
    img = img[:, :, [1, 2, 0]]

    img = cv2.resize(img, (300, 300))
    cv2.imshow("im", img)
    cv2.waitKey()
  • 2.1 分类模型优化思路

分类模型优化思路其实就是调参技巧,调参技巧说起来是比较苍白的,更多的还是去实操,实打实的去调网络模型的参数,观察这些参数,以及利用这些参数去分析其中产生的问题。

本篇文章主要还是介绍想要去进行分类模型的优化主要从哪方面入手,如何去思考这些问题。

  • 2.1.1 backbone:
    • 也就是所谓的主干网络,到底是使用轻量型的卷积神经网络,还是使用resnet这种经典的比较重的网络结构。
    • 网络结构在兼顾模型性能,也就是计算时间,以及模型准确率的平衡的前提下,在保证尽可能计算量小的情况下去实现一个比较好的效果
    • 尤其是在解决一些实际落地场景的问题的时候,如果在GPU上跑这个模型的话,模型太大会非常耗费成本
    • 如果想把模型跑在终端的时候,一定要考虑轻量型的卷进神经网络或者去做模型的裁剪等相关的工作;
    • 在关于卷积神经网络网络结构的设计,在进展这个方向上,其实大多数成果主要还是关注在如何去压缩模型的计算量,我们可以采用一些设计技巧,也可以采用一些工程化的手段,比如模型的裁剪。
    • 在解决问题的时候,经常会尝试不同的网络结构,通过修改它的主干网络来比较模型的最终性能,进而去权衡计算量和精度的情况下,去找到合适的网络模型。在最终使用resnet还是mobilenet还是其他的,具体在使用的时候还是看具体的效果。
  • 2.1.2 过拟合问题
    • 如果不做任何的数据增强,直接将训练集拿过来训练,测试集直接测试,这时会发现loss是先降,降完之后是直接升上去的,这个时候就是产生了过拟合的问题。
    • 解决这个问题可以通过补充样本、进行数据增强、添加dropout层、添加L2正则向、把模型改的更简单一点。
    • 使用最多的就是进行数据增强。
  • 2.1.3 学习率调整
    • 初始化的参数是多少比较合适
    • 在训练的时候如何去衰减比较合适
  • 2.1.5 优化函数
  • 2.1.4 数据增强
    • 亮度、对比度、饱和度
    • 裁剪、旋转

在解决模型优化的时候,或者在解决之前没有碰到的问题的时候,在拿到这些问题的时候,如何去建模?在建好模型之后,如何去调参?

  • 通常要做的事就是观察LOSS的变化,是否产生过拟合
  • 观察数据
  • 了解最新进展baseline
  • 分析错误case

9JQ4ZCQY3M({Q$KEN%9BFQX.png