# PyTorch example: CIFAR-10 color-image classification (copy-and-run)
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchinfo import summary
import warnings
from datetime import datetime
# Select the compute device: prefer the GPU when CUDA is available,
# otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

print(device, 'device')

# Shared transform: PIL image -> float tensor scaled to [0, 1].
to_tensor = torchvision.transforms.ToTensor()

# Download (if not cached) and load the CIFAR-10 training split.
train_ds = torchvision.datasets.CIFAR10('data',
                                        train=True,
                                        transform=to_tensor,
                                        download=True)

# Download (if not cached) and load the CIFAR-10 test split.
test_ds = torchvision.datasets.CIFAR10('data',
                                       train=False,
                                       transform=to_tensor,
                                       download=True)

# Samples per mini-batch.
batch_size = 32

# Batch iterators: reshuffle the training data every epoch,
# keep the test order fixed for reproducible evaluation.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

# Peek at one batch; images are (batch, channel, height, width).
imgs, labels = next(iter(train_dl))
print(imgs.shape)
print(labels.shape)

# (image placeholder from the original article: printed tensor shapes)

# Visualize the first 20 images of the batch fetched above.
plt.figure(figsize=(20, 5))
# NOTE: the original loop reused the name `imgs` for the loop variable,
# shadowing the batch tensor; use a distinct name instead.
for idx, img in enumerate(imgs[:20]):
    # CHW tensor -> HWC numpy array, the layout matplotlib expects.
    np_img = img.numpy().transpose((1, 2, 0))
    plt.subplot(2, 10, idx + 1)
    # No cmap: a colormap is ignored for RGB input anyway.
    plt.imshow(np_img)
    plt.axis('off')
plt.show()

# Number of target classes in CIFAR-10.
number_classes = 10


class Model(nn.Module):
    """Small CNN for CIFAR-10: three conv+max-pool stages and a 2-layer classifier.

    Feature-map trace for a 32x32 RGB input (kernel 3, no padding, pool 2):
        conv1: 3->64,  30x30 -> pool1: 15x15
        conv2: 64->64, 13x13 -> pool2: 6x6
        conv3: 64->128, 4x4  -> pool3: 2x2  => 128*2*2 = 512 flat features
    """

    def __init__(self, num_classes: int = 10):
        """Build the layers.

        Args:
            num_classes: size of the output layer. Defaults to 10, the
                CIFAR-10 class count (matches the module-level
                ``number_classes``).
        """
        super().__init__()
        # Convolution + pooling stage 1
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Convolution + pooling stage 2
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        # Convolution + pooling stage 3
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3)
        self.pool3 = nn.MaxPool2d(kernel_size=2)

        # Classifier head: 512 flat features -> hidden 256 -> class logits.
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Forward pass; returns raw class logits (no softmax)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))

        # Flatten everything except the batch dimension.
        x = torch.flatten(x, start_dim=1)

        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x


# Instantiate the model and move it to the selected device.
model = Model().to(device)
# Print a layer-by-layer summary of the model.
summary(model)

# Multi-class classification loss; expects raw logits and integer labels.
loss_fn = nn.CrossEntropyLoss()

# Learning rate for SGD.
lr = 1e-2

# Plain stochastic gradient descent over all model parameters.
# NOTE: the original article left markdown code fences (```) around this
# section, which made the "runnable" script a syntax error; removed.
opt = torch.optim.SGD(model.parameters(), lr=lr)

# (image placeholder from the original article: model summary screenshot)

# Training function: one full pass over the training data.
def train(dataloader, model, loss_fn, optimizer, device=None):
    """Run one training epoch and return (accuracy, mean batch loss).

    Args:
        dataloader: yields (inputs, labels) batches.
        model: network to optimize; assumed to already live on its target device.
        loss_fn: criterion comparing logits to integer class labels.
        optimizer: optimizer stepping the model's parameters.
        device: device to move batches to. Defaults to the device of the
            model's own parameters, so the function no longer depends on a
            module-level global.

    Returns:
        tuple (train_acc, train_loss): fraction of correctly classified
        samples over the whole dataset, and total loss averaged per batch.
    """
    if device is None:
        device = next(model.parameters()).device

    # Total number of samples and number of batches.
    size = len(dataloader.dataset)
    batch_num = len(dataloader)

    # Accumulated loss over all batches; count of correct predictions.
    total_loss, pred_right_count = 0, 0

    for X, y in dataloader:
        # Move the batch to the same device as the model.
        X, y = X.to(device), y.to(device)

        # Forward pass -> logits.
        pred = model(X)
        loss = loss_fn(pred, y)

        # Standard update: clear old gradients, backprop, step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # argmax over dim=1 (the class dimension) gives predicted labels.
        pred_right_count += (pred.argmax(1) == y).type(torch.float).sum().item()

    # Epoch accuracy and mean per-batch loss.
    train_acc = pred_right_count / size
    train_loss = total_loss / batch_num

    return train_acc, train_loss
# Evaluation function: one full pass over the test data, no weight updates.
def test(dataloader, model, loss_fn, device=None):
    """Evaluate the model and return (accuracy, mean batch loss).

    Args:
        dataloader: yields (inputs, labels) batches.
        model: network to evaluate; callers should put it in eval mode first.
        loss_fn: criterion comparing logits to integer class labels.
        device: device to move batches to. Defaults to the device of the
            model's own parameters, so the function no longer depends on a
            module-level global.

    Returns:
        tuple (test_acc, test_loss): fraction of correctly classified
        samples over the whole dataset, and total loss averaged per batch.
    """
    if device is None:
        device = next(model.parameters()).device

    # Total number of samples and number of batches (= size / batch_size).
    size = len(dataloader.dataset)
    num_batches = len(dataloader)

    # Accumulated loss over all batches; count of correct predictions.
    total_loss, pred_right_count = 0, 0

    # Disable gradient tracking: pure inference.
    with torch.no_grad():
        for imgs, target in dataloader:
            # Move the batch to the same device as the model.
            imgs, target = imgs.to(device), target.to(device)

            # Forward pass -> logits.
            pred = model(imgs)
            loss = loss_fn(pred, target)

            # argmax over dim=1 (the class dimension) gives predicted labels.
            pred_right_count += (pred.argmax(1) == target).type(torch.float).sum().item()
            total_loss += loss.item()

    # Mean per-batch loss and overall accuracy.
    test_loss = total_loss / num_batches
    test_acc = pred_right_count / size

    return test_acc, test_loss
# ---- Training driver ----
# Number of epochs to run.
epochs = 10

# Per-epoch history, consumed by the plotting section below.
train_loss, train_acc = [], []
test_loss, test_acc = [], []

for epoch in range(epochs):
    # Training mode: layers such as Dropout behave stochastically.
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)
    print(f'epoch:{epoch + 1} Train Complete')

    # Evaluation mode: Dropout is disabled, all units participate.
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)

    # Record this epoch's metrics.
    train_acc.append(epoch_train_acc)
    train_loss.append(epoch_train_loss)
    test_acc.append(epoch_test_acc)
    test_loss.append(epoch_test_loss)

    print('Epoch:{:2d},Train_acc:{:.1f}%,Train_loss:{:.3f} Test_acc:{:.1f}% Test_loss:{:.3f}'.format(
        epoch + 1,
        epoch_train_acc * 100,
        epoch_train_loss,
        epoch_test_acc * 100,
        epoch_test_loss))

print('Train Done')

# (image placeholder from the original article: training-log screenshot)

# ---- Visualization of the training results ----

warnings.filterwarnings('ignore')

plt.rcParams['font.sans-serif'] = ['SimHei']  # render CJK characters in labels
plt.rcParams['axes.unicode_minus'] = False    # render the minus sign correctly
plt.rcParams['figure.dpi'] = 100              # figure resolution

# Timestamp shown under the accuracy panel.
current_time = datetime.now()

epochs_range = range(epochs)

plt.figure(figsize=(12, 3))

# Left panel: accuracy curves.
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time)

# Right panel: loss curves.
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

# (image placeholder from the original article: result plots screenshot)