Example 2 - ResNet


Reproducing ResNet

I. Model Module

import torch
import torch.nn as nn

1. Basic Residual Block

This is the block used by the 18- and 34-layer ResNets. The difference from the deeper variants lies in the number of convolution kernels (output channels) used in each block.

class BasicBlock(nn.Module):
    # expansion marks whether the number of kernels changes within the block (1 = unchanged)
    expansion = 1
    # downsample implements the shortcut: None for the identity (direct) path,
    # a projection module for the indirect path
    def __init__(self, in_c, out_c, stride = 1, downsample = None, **kwargs):
        super(BasicBlock, self).__init__()
        # BatchNorm follows each convolution, so the bias flag is set to False
        # these are "same" convolutions: with stride 1, input and output sizes match
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size = 3, stride = stride, padding = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size = 3, stride = 1, padding = 1, bias = False)
        self.bn2 = nn.BatchNorm2d(out_c)
        
        self.downsample = downsample
        
    def forward(self, x):
        # identity holds the value carried along the shortcut
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        # the main path still consumes x; identity is what gets added back at the end
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        # the hallmark of a residual block: add the (possibly projected) input back
        out += identity
        
        y = self.relu(out)
        # don't forget to return the result
        return y
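
As a quick shape check (an illustrative sketch, not part of the original notebook), the two shortcut cases behave as follows: with stride 1 the identity shortcut preserves the shape, while a strided block needs a 1x1 projection so the addition lines up:

# Illustrative: the identity shortcut keeps the shape unchanged
block = BasicBlock(64, 64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)    # torch.Size([1, 64, 56, 56])

# With stride 2 the shortcut needs a 1x1 projection so the shapes match
down = nn.Sequential(nn.Conv2d(64, 128, kernel_size = 1, stride = 2, bias = False),
                     nn.BatchNorm2d(128))
print(BasicBlock(64, 128, stride = 2, downsample = down)(x).shape)    # torch.Size([1, 128, 28, 28])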

2. Bottleneck Block

The building block used by the deeper variants (ResNet-50/101/152).

class Bottleneck(nn.Module):
    # ratio between the channel count of each stage's last conv and its first conv
    expansion = 4
    # note the groups argument below (used for ResNeXt-style grouped convolutions)
    def __init__(self, in_c, out_c, stride = 1, downsample = None, groups = 1, width_per_group = 64):
        super(Bottleneck, self).__init__()
        width = int(out_c * (width_per_group / 64.)) * groups
        
        # 1x1 convolution to compress the channel count
        self.conv1 = nn.Conv2d(in_c, width, stride = 1, kernel_size = 1, bias = False)
        self.bn1 = nn.BatchNorm2d(width)
        
        # conv2's stride may change along with the shortcut, so the block's own stride is passed in
        self.conv2 = nn.Conv2d(width, width, kernel_size = 3, stride = stride, padding = 1, bias = False, groups = groups)
        self.bn2 = nn.BatchNorm2d(width)
        
        self.conv3 = nn.Conv2d(width, out_c*self.expansion, kernel_size = 1, stride = 1, bias = False)
        self.bn3 = nn.BatchNorm2d(out_c*self.expansion)
        self.relu = nn.ReLU(inplace = True)
        
        self.downsample = downsample
    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        
        out = self.conv3(out)
        out = self.bn3(out)
        # the bottleneck also needs the residual connection
        out += identity
        y = self.relu(out)
        # don't forget to return the result!
        return y
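
As another illustrative check (assuming the class above), the block's real output width is expansion times the nominal out_c:

# Illustrative: out_c = 64, but the block emits 64 * 4 = 256 channels
bott = Bottleneck(256, 64)
x = torch.randn(1, 256, 56, 56)
print(bott(x).shape)    # torch.Size([1, 256, 56, 56])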

3. The ResNet Class

class ResNet(nn.Module):
    # block: the residual block class (BasicBlock or Bottleneck) matching the target depth
    # block_num: how many blocks to stack in each stage (conv2_x - conv5_x), typically [3, 4, 6, 3]
    # include_top: whether to keep the classification head (average pooling + fully connected layer)
    def __init__(self, block, block_num, num_classes = 1000, include_top = True, groups = 1, width_per_group = 64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.groups = groups
        self.width_per_group = width_per_group
        # every variant's stem outputs 64 channels
        self.in_c = 64
        # note that in_c is an attribute we update ourselves as stages are built,
        # not a constructor argument
        # 3 input channels correspond to RGB images
        self.conv1 = nn.Conv2d(3, self.in_c, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(self.in_c)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)
        
        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride = 2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride = 2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride = 2)
        
        if self.include_top:
            # whatever the input size, the output is always (1, 1)
            self.avgpool = nn.AdaptiveAvgPool2d((1,1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Kaiming-initialize every convolutional layer in the network
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    # stride defaults to 1 here
    # channels is the first conv's channel count in each block; for ResNet-18/34 it equals
    # the block's output, while the deeper variants output 4x as many
    def _make_layer(self, block, channels, num_blocks, stride = 1):
        # declare downsample up front; it stays None for identity shortcuts
        downsample = None
        if stride != 1 or self.in_c != block.expansion*channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_c, channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )
        layers = []
        layers.append(block(self.in_c,
                            channels,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        # update the input channel count for the remaining blocks
        self.in_c = channels * block.expansion
        # start from 1 because the first block was already appended above
        for _ in range(1, num_blocks):
            layers.append(block(self.in_c,
                                channels,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
        return nn.Sequential(*layers)
    def forward(self, x):
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        
        if self.include_top:
            out = self.avgpool(out)
            # flatten before the fully connected layer
            out = torch.flatten(out, 1)
            out = self.fc(out)
        
        return out
    
# download this network's pretrained weights from:
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
def resnet34(num_classes = 1000, include_top = True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes = num_classes, include_top = include_top)
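
The deeper variants differ only in the block type and stack counts. For reference (a sketch, not used in this post), a resnet50 factory would be:

def resnet50(num_classes = 1000, include_top = True):
    # ResNet-50 stacks Bottleneck blocks in the same [3, 4, 6, 3] layout
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes = num_classes, include_top = include_top)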

II. Loading Pretrained Weights

import os

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_path = './resnet34-333f7ec4.pth'
# instantiate the original 1000-class model
model = resnet34()
model.load_state_dict(torch.load(weight_path))
in_c = model.fc.in_features
# rebuild the model's classification head for our 5 flower classes
model.fc = nn.Linear(in_c, 5)
# optional: inspect the architecture with summary(model, (3, 224, 224))
from torchsummary import summary
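
An alternative that never loads the 1000-class head (a sketch with the same net effect) is to drop the fc weights from the checkpoint and load non-strictly:

# Sketch: build the 5-class model first, then load only the backbone weights
state = torch.load(weight_path, map_location = 'cpu')
state = {k: v for k, v in state.items() if not k.startswith('fc.')}
model5 = resnet34(num_classes = 5)
missing, unexpected = model5.load_state_dict(state, strict = False)
print(missing)    # ['fc.weight', 'fc.bias'] -- only the new head stays randomly initialized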

III. Training Module

import sys
import json
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm

def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # standard ImageNet-style preprocessing
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        # Resize keeps the aspect ratio and scales the shorter side to 256
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    img_path = './flower_data'
    batch_size = 32
    nw = 8
    trainset = datasets.ImageFolder(root = img_path+'/train',transform = data_transform['train'])
    num_train = len(trainset)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size = batch_size, num_workers = nw, shuffle = True)
    validset = datasets.ImageFolder(root = img_path+'/val',transform = data_transform['val'])
    num_val = len(validset)
    validloader = torch.utils.data.DataLoader(validset, batch_size = batch_size, num_workers = nw,shuffle = True)
    print('Data loaded: {} training images and {} validation images.'.format(num_train, num_val))
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = trainset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    
    model = resnet34()
    model.load_state_dict(torch.load(weight_path))
    in_c = model.fc.in_features
    # rebuild the model's classification head for 5 classes
    model.fc = nn.Linear(in_c, 5)
    model.to(device)
    # the loss function comes from the nn module
    Loss = nn.CrossEntropyLoss()
    
    # separate out the parameters that require training
    parameters_required = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters_required, lr = 0.0001)
    Epoch = 10
    best_acc = 0
    model_path = './model.pth'
    
    num_batch = len(trainloader)
    
    
    for epoch in range(Epoch):
        model.train()
        running_loss = 0
        train_bar = tqdm(trainloader,file = sys.stdout)
        for step,data in enumerate(train_bar):
            train_imgs,train_labels = data
            optimizer.zero_grad()
            y = model(train_imgs.to(device))
            loss = Loss(y, train_labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc= "train epoch[{}/{}] loss:{:.3f}".format(epoch+1, Epoch, loss)
        model.eval()
        acc = 0
        with torch.no_grad():
            valid_bar = tqdm(validloader, file = sys.stdout)
            for data in valid_bar:
                valid_imgs,valid_labels = data
                y = model(valid_imgs.to(device))
                predict = torch.max(y,dim = 1)[1]
                acc += torch.eq(predict, valid_labels.to(device)).sum().item()
                valid_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,Epoch)
        # accuracy must be normalized by the size of the validation set
        acc = acc/num_val
        
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / num_batch, acc))

        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), model_path)
train()
Data loaded: 3306 training images and 364 validation images.
train epoch[1/10] loss:0.156: 100%|███████████| 104/104 [00:07<00:00, 14.75it/s]
valid epoch[1/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 16.60it/s]
[epoch 1] train_loss: 0.452  val_accuracy: 0.951
train epoch[2/10] loss:0.684: 100%|███████████| 104/104 [00:05<00:00, 17.91it/s]
valid epoch[2/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.29it/s]
[epoch 2] train_loss: 0.268  val_accuracy: 0.945
train epoch[3/10] loss:0.407: 100%|███████████| 104/104 [00:05<00:00, 17.99it/s]
valid epoch[3/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.66it/s]
[epoch 3] train_loss: 0.235  val_accuracy: 0.951
train epoch[4/10] loss:0.380: 100%|███████████| 104/104 [00:05<00:00, 17.97it/s]
valid epoch[4/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.90it/s]
[epoch 4] train_loss: 0.197  val_accuracy: 0.953
train epoch[5/10] loss:0.047: 100%|███████████| 104/104 [00:05<00:00, 17.92it/s]
valid epoch[5/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.25it/s]
[epoch 5] train_loss: 0.189  val_accuracy: 0.940
train epoch[6/10] loss:0.259: 100%|███████████| 104/104 [00:05<00:00, 17.78it/s]
valid epoch[6/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 16.82it/s]
[epoch 6] train_loss: 0.166  val_accuracy: 0.967
train epoch[7/10] loss:0.249: 100%|███████████| 104/104 [00:05<00:00, 17.83it/s]
valid epoch[7/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.45it/s]
[epoch 7] train_loss: 0.163  val_accuracy: 0.967
train epoch[8/10] loss:0.242: 100%|███████████| 104/104 [00:05<00:00, 17.93it/s]
valid epoch[8/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.63it/s]
[epoch 8] train_loss: 0.142  val_accuracy: 0.967
train epoch[9/10] loss:0.171: 100%|███████████| 104/104 [00:05<00:00, 17.69it/s]
valid epoch[9/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.54it/s]
[epoch 9] train_loss: 0.143  val_accuracy: 0.953
train epoch[10/10] loss:0.141: 100%|██████████| 104/104 [00:05<00:00, 17.90it/s]
valid epoch[10/10]: 100%|███████████████████████| 12/12 [00:00<00:00, 17.79it/s]
[epoch 10] train_loss: 0.156  val_accuracy: 0.967
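
A note on the requires_grad filter in train(): since no parameters are frozen, it simply collects everything. To fine-tune only the new head instead (an optional variant, not what was run above), the backbone could be frozen first:

# Optional variant: freeze the pretrained backbone and train only the classifier
for p in model.parameters():
    p.requires_grad = False
model.fc = nn.Linear(model.fc.in_features, 5)    # fresh layers default to requires_grad=True
parameters_required = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(parameters_required, lr = 0.0001)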

IV. Testing Module

from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
test_transform = transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

# the image must be transformed right after loading;
# convert('RGB') guards against PNGs that carry an alpha channel
test_img = Image.open('./my_test/pvzsunflower.png').convert('RGB')
plt.imshow(test_img)
test_img = test_transform(test_img)

# add a batch dimension before feeding the model
test_img = torch.unsqueeze(test_img, dim = 0)
# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

with open(json_path, "r") as f:
    class_indict = json.load(f)


model = resnet34(num_classes= 5).to(device)
model_path = './model.pth'
# map_location is an argument of torch.load
model.load_state_dict(torch.load(model_path,map_location = device))
model.eval()


with torch.no_grad():
    out = torch.squeeze(model(test_img.to(device))).cpu()
    predict = torch.softmax(out,dim = 0)
    predict_cla = torch.argmax(predict).numpy()

print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
    print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
# reopen the original image for display, since test_img was transformed above
test_img = Image.open('./my_test/pvzsunflower.png')
plt.imshow(test_img)
plt.show()
class: daisy        prob: 0.127
class: dandelion    prob: 0.0162
class: roses        prob: 0.716
class: sunflowers   prob: 0.0246
class: tulips       prob: 0.116
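
For repeated experiments, the steps above can be wrapped into a small helper (a sketch assuming the same test_transform, model, class_indict, and device defined earlier):

def predict_image(img_path, model, transform, class_indict, device):
    # load, preprocess, and classify one image; returns (class name, probability)
    img = Image.open(img_path).convert('RGB')
    x = torch.unsqueeze(transform(img), dim = 0).to(device)
    with torch.no_grad():
        probs = torch.softmax(torch.squeeze(model(x)).cpu(), dim = 0)
    idx = int(torch.argmax(probs))
    return class_indict[str(idx)], probs[idx].item()

# e.g. predict_image('./my_test/pvzsunflower.png', model, test_transform, class_indict, device)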


