PyTorch: Hands-On CIFAR-10 Image Classification



In the previous article, we walked through the architecture of an image classification network.

Today, we move on to hands-on CIFAR-10 image classification.


  • 1.1 Reading and preprocessing the data

  • 1.1.1 CIFAR-10/100 dataset: introduction & download
  • CIFAR-10/100

    • A subset of the 80 Million Tiny Images dataset
    • Collected by Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton
    • CIFAR-10/100 is a standard benchmark for evaluating basic convolutional networks such as VGGNet, AlexNet, and ResNet
    • The dataset consists of 60,000 32×32 color images in 10 classes, 6,000 images per class: 50,000 for training and 10,000 for testing. It is split into 5 training batches and 1 test batch, each holding 10,000 images.
    • Official site: www.cs.toronto.edu/~kriz/cifar…
    • The site documents how to decode the downloaded files, and we use that method directly

Code to parse the data:

import pickle
def unpickle(file):
    # official decoding helper from the CIFAR-10 site
    with open(file, 'rb') as fo:
        data_dict = pickle.load(fo, encoding='bytes')
    return data_dict
label_name = ["airplane",
              "automobile",
              "bird",
              "cat",
              "deer",
              "dog",
              "frog",
              "horse",
              "ship",
              "truck"]

import glob
import numpy as np
import cv2
import os

train_list = glob.glob("cifar-10-batches-py/data_batch_*")
print(train_list)
save_path = "cifar-10-batches-py/train"

for l in train_list:
    print(l)
    l_dict = unpickle(l)
    # print(l_dict)
    print(l_dict.keys())

    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label = l_dict[b'labels'][im_idx]
        im_name = l_dict[b'filenames'][im_idx]
        print(im_label, im_name, im_data)
        im_label_name = label_name[im_label]
        # each row is 3072 bytes: the R, G, B planes of a 32x32 image
        im_data = np.reshape(im_data, [3, 32, 32])
        im_data = np.transpose(im_data, (1, 2, 0))

        # cv2.imshow("im_data", cv2.resize(im_data, (200, 200)))
        # cv2.waitKey(0)

        if not os.path.exists("{}/{}".format(save_path,
                                             im_label_name)):
            os.makedirs("{}/{}".format(save_path, im_label_name))

        # cv2.imwrite expects BGR, but the CIFAR planes are RGB, so swap channels
        cv2.imwrite("{}/{}/{}".format(save_path,
                                      im_label_name,
                                      im_name.decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))
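
The archive also contains a test_batch file, and the loader below expects a cifar-10-batches-py/test folder with the same layout. A minimal sketch that reuses the loop above for the test split (the folder name simply matches the glob used later):

test_list = glob.glob("cifar-10-batches-py/test_batch*")
save_path = "cifar-10-batches-py/test"

for l in test_list:
    l_dict = unpickle(l)
    for im_idx, im_data in enumerate(l_dict[b'data']):
        im_label_name = label_name[l_dict[b'labels'][im_idx]]
        # same reshape: 3 RGB planes of a 32x32 image, then to HWC
        im_data = np.transpose(np.reshape(im_data, [3, 32, 32]), (1, 2, 0))
        out_dir = "{}/{}".format(save_path, im_label_name)
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        cv2.imwrite("{}/{}".format(out_dir,
                                   l_dict[b'filenames'][im_idx].decode("utf-8")),
                    cv2.cvtColor(im_data, cv2.COLOR_RGB2BGR))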

Loading the CIFAR-10 data:

The loaded CIFAR-10 data will be used to train the model.

from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image
import numpy as np
import glob

label_name = ["airplane", "automobile", "bird",
              "cat", "deer", "dog",
              "frog", "horse", "ship", "truck"]

# store the classes in a dict, mapping each class name to an integer label
label_dict = {}
for idx, name in enumerate(label_name):
    label_dict[name] = idx



def default_loader(path):
    return Image.open(path).convert("RGB")

# training-time augmentation; RandomResizedCrop also resizes the images to 28x28
train_transform = transforms.Compose([
    transforms.RandomResizedCrop((28, 28)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(90),
    transforms.RandomGrayscale(0.1),
    transforms.ColorJitter(0.3, 0.3, 0.3, 0.3),
    transforms.ToTensor()
])



class MyDataset(Dataset):
    def __init__(self, im_list,
                 transform=None,
                 loader=default_loader):
        super(MyDataset, self).__init__()
        imgs = []

        for im_item in im_list:
            # e.g. "cifar-10-batches-py/train/airplane/aeroplane_s_000021.png"
            # normalize Windows path separators so the class name can be split out
            im_label_name = im_item.replace("\\", "/").split("/")[-2]
            imgs.append([im_item, label_dict[im_label_name]])

        self.imgs = imgs
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        im_path, im_label = self.imgs[index]
        im_data = self.loader(im_path)
        if self.transform is not None:
            im_data = self.transform(im_data)

        return im_data, im_label

    def __len__(self):
        return len(self.imgs)

im_train_list = glob.glob("cifar-10-batches-py/train/*/*.png")

im_test_list = glob.glob("cifar-10-batches-py/test/*/*.png")


train_dataset = MyDataset(im_train_list,
                          transform=train_transform)
test_dataset = MyDataset(im_test_list,
                         transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=128,
                          shuffle=True,
                          num_workers=4)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=128,
                         shuffle=False,
                         num_workers=4)

print("num_of_train", len(train_dataset))
print("num_of_test", len(test_dataset))

Output:

num_of_train 50000
num_of_test 10000
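
As a quick sanity check, pulling one batch from the loaders should give 28x28 training tensors (a minimal sketch, assuming the loaders above; on Windows, guard it with if __name__ == "__main__" because of num_workers=4):

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 28, 28]) after RandomResizedCrop
print(labels[:8])    # integer class indices produced via label_dict

Test images keep their native 32x32 size, since only ToTensor is applied to them.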
  • 1.2 Building the model

We build a VGG-style sequential network, modeled on VGGNet, for the CIFAR-10 classification task.

  • Standard VGGNet takes 224×224 inputs; after 32× downsampling it yields a 7×7 feature map, and fully-connected layers then perform the 1000-class classification
  • Without any resizing the inputs could stay at 32×32, but since our augmentation pipeline crops the images to 28×28, the network below is sized for 28×28 inputs (the per-stage sizes are traced in the sketch after this list)
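
Each conv below keeps the spatial size (kernel 3, stride 1, padding 1), and each pool halves it via floor((n + 2·padding − kernel)/stride) + 1. A tiny illustrative helper shows why max_pooling3 needs padding=1 to turn 7×7 into 4×4 rather than 3×3:

def out_size(n, kernel, stride, padding=0):
    # standard conv/pool output-size formula
    return (n + 2 * padding - kernel) // stride + 1

print(out_size(28, 2, 2))            # 14 after max_pooling1
print(out_size(14, 2, 2))            # 7  after max_pooling2
print(out_size(7, 2, 2, padding=1))  # 4  after max_pooling3
print(out_size(4, 2, 2))             # 2  after max_pooling4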
import torch
import torch.nn as nn
import torch.nn.functional as F

class VGGbase(nn.Module):

    def __init__(self):
        super(VGGbase, self).__init__()

        # input: 3 * 28 * 28 (32x32 images cropped to 28x28)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.max_pooling1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 14 * 14
        self.conv2_1 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.conv2_2 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU()
        )
        self.max_pooling2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 7 * 7
        self.conv3_1 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.conv3_2 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU()
        )
        self.max_pooling3 = nn.MaxPool2d(kernel_size=2,
                                         stride=2,
                                         padding=1)

        # 4 * 4
        self.conv4_1 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.conv4_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        self.max_pooling4 = nn.MaxPool2d(kernel_size=2,
                                         stride=2)

        # 2 * 2

        # batchsize * 512 * 2 * 2 --> batchsize * (512 * 2 * 2)
        self.fc = nn.Linear(512 * 4, 10)

    def forward(self, x):
        batchsize = x.size(0)
        out = self.conv1(x)
        out = self.max_pooling1(out)

        out = self.conv2_1(out)
        out = self.conv2_2(out)
        out = self.max_pooling2(out)

        out = self.conv3_1(out)
        out = self.conv3_2(out)
        out = self.max_pooling3(out)

        out = self.conv4_1(out)
        out = self.conv4_2(out)
        out = self.max_pooling4(out)

        out = out.view(batchsize, -1)

        out = self.fc(out)
        # log_softmax is idempotent, so nn.CrossEntropyLoss in the training
        # script (which applies log-softmax internally) still computes the intended loss
        out = F.log_softmax(out, dim=1)

        return out
    
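A quick shape check before training (a minimal sketch; any dummy batch works):

if __name__ == "__main__":
    net = VGGbase()
    x = torch.randn(2, 3, 28, 28)  # dummy batch: two 28x28 RGB images
    print(net(x).shape)            # expected: torch.Size([2, 10])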
  • 1.3 Training the model

import torch
import torch.nn as nn
import torchvision
from vggnet import VGGbase
from load_cifar10 import train_loader, test_loader
import os
import tensorboardX


# use the GPU when available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epoch_num = 200
lr = 0.1
batch_size = 128
net = VGGbase().to(device)

# loss
loss_func = nn.CrossEntropyLoss()

# optimizer (note: lr=0.1 is unusually high for Adam; around 1e-3 is more common)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# optimizer = torch.optim.SGD(net.parameters(), lr = lr,
#                 momentum=0.9, weight_decay=5e-4)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=10,
                                            gamma=0.9)

model_path = "models/VGGNet"
log_path = "logs/VGGNet"
if not os.path.exists(log_path):
    os.makedirs(log_path)
if not os.path.exists(model_path):
    os.makedirs(model_path)
writer = tensorboardX.SummaryWriter(log_path)

step_n = 0
for epoch in range(epoch_num):
    print(" epoch is ", epoch)


    for i, data in enumerate(train_loader):
        net.train()
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = net(inputs)
        loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, pred = torch.max(outputs.data, dim=1)

        correct = pred.eq(labels.data).cpu().sum()
        # print("epoch is ", epoch)
        # print("train lr is ", optimizer.state_dict()["param_groups"][0]["lr"])
        # print("train step", i, "loss is:", loss.item(),
        #       "mini-batch correct is:", 100.0 * correct / batch_size)

        writer.add_scalar("train loss", loss.item(), global_step=step_n)
        writer.add_scalar("train correct",
                          100.0 * correct.item() / batch_size, global_step=step_n)

        im = torchvision.utils.make_grid(inputs)
        writer.add_image("train im", im, global_step=step_n)

        step_n += 1

    torch.save(net.state_dict(), "{}/{}.pth".format(model_path,
                                                     epoch + 1))
    scheduler.step()

    sum_loss = 0
    sum_correct = 0
    net.eval()
    with torch.no_grad():  # no gradients needed during evaluation
        for i, data in enumerate(test_loader):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = loss_func(outputs, labels)
            _, pred = torch.max(outputs.data, dim=1)
            correct = pred.eq(labels.data).cpu().sum()

            sum_loss += loss.item()
            sum_correct += correct.item()
            im = torchvision.utils.make_grid(inputs)
            writer.add_image("test im", im, global_step=step_n)

    test_loss = sum_loss * 1.0 / len(test_loader)
    # divide by the true sample count; the last batch is smaller than batch_size
    test_correct = sum_correct * 100.0 / len(test_loader.dataset)

    writer.add_scalar("test loss", test_loss, global_step=epoch + 1)
    writer.add_scalar("test correct",
                      test_correct, global_step=epoch + 1)



    print("epoch is", epoch + 1, "loss is:", test_loss,
          "test correct is:", test_correct)

writer.close()
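
Training curves can be inspected with tensorboard --logdir logs/VGGNet. To reuse a saved checkpoint afterwards, a minimal inference sketch (the epoch number in the path is just an example, and importing load_cifar10 re-runs its top-level code):

import torch
from vggnet import VGGbase
from load_cifar10 import label_name  # the class-name list defined earlier

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = VGGbase().to(device)
net.load_state_dict(torch.load("models/VGGNet/200.pth", map_location=device))
net.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 28, 28).to(device)  # stand-in for a preprocessed test image
    pred = net(x).argmax(dim=1).item()
    print("predicted class:", label_name[pred])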
