Reproducing ResNet
I. Model Module
import torch
import torch.nn as nn
1. Basic Residual Block
This is the building block for the 18- and 34-layer ResNets. It differs from the blocks used in the other residual networks in the number of convolution kernels (the deeper variants expand the channel count, as the Bottleneck class below shows).
class BasicBlock(nn.Module):
    # Marks whether the number of kernels changes inside the block (1 = no change)
    expansion = 1

    # downsample flags the projection shortcut: the basic block needs both the
    # direct (identity) shortcut and the indirect (projection) shortcut
    def __init__(self, in_c, out_c, stride=1, downsample=None, **kwargs):
        super(BasicBlock, self).__init__()
        # Bias is disabled because BN follows each convolution
        # These are "same" convolutions: spatial size is preserved (when stride=1)
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.downsample = downsample

    def forward(self, x):
        # identity holds the value carried by the shortcut
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        # The main path still consumes x; identity is what gets added back
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # The residual trick: add the original input back
        out += identity
        y = self.relu(out)
        # Don't forget to return the result
        return y
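A quick forward-pass check makes the two shortcut paths concrete. This is a minimal sketch with illustrative tensor sizes (not from the original post):

# Stride-1 block: channels and spatial size are preserved
block = BasicBlock(64, 64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])

# Stride-2 block: the projection shortcut must match the new shape
down = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128))
block = BasicBlock(64, 128, stride=2, downsample=down)
print(block(x).shape)  # torch.Size([1, 128, 28, 28])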
2. Bottleneck Class
The basic residual block for the other (deeper) network architectures, i.e. ResNet-50/101/152.
class Bottleneck(nn.Module):
    # Channel-count ratio between the block's last and first conv layers
    expansion = 4

    # Note the groups argument below (used by the grouped-convolution variants)
    def __init__(self, in_c, out_c, stride=1, downsample=None, groups=1, width_per_group=64):
        super(Bottleneck, self).__init__()
        width = int(out_c * (width_per_group / 64.)) * groups
        # 1x1 convolution that rescales the channel count
        self.conv1 = nn.Conv2d(in_c, width, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width)
        # conv2's stride may change along with the shortcut, so it takes the block's stride
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, padding=1, bias=False, groups=groups)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv3 = nn.Conv2d(width, out_c * self.expansion, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_c * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        # The main path consumes x, not identity
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = self.bn3(out)
        # The bottleneck block also needs the residual connection
        out += identity
        y = self.relu(out)
        # Don't forget to return the result!
        return y
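Again, a minimal shape-check sketch (sizes are illustrative) shows the 4x channel expansion at work; with in_c equal to out_c * expansion and stride 1, no projection shortcut is needed:

# out_c=64 means the block squeezes to 64 channels, then expands to 64 * 4 = 256
block = Bottleneck(256, 64)
x = torch.randn(1, 256, 56, 56)
print(block(x).shape)  # torch.Size([1, 256, 56, 56])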
3. ResNet Class
class ResNet(nn.Module):
    # block is the residual block type for the network we want to build
    # block_num is the number of blocks stacked in each stage (conv2_x to conv5_x),
    # typically [3, 4, 6, 3]
    # include_top controls whether the classification head (avgpool + FC) is built
    def __init__(self, block, block_num, num_classes=1000, include_top=True, groups=1, width_per_group=64):
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.groups = groups
        self.width_per_group = width_per_group
        # Every network variant starts with 64 input channels
        self.in_c = 64
        # Note that in_c is our own instance attribute, not a constructor argument
        # The 3 corresponds to RGB images
        self.conv1 = nn.Conv2d(3, self.in_c, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_c)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, block_num[0])
        self.layer2 = self._make_layer(block, 128, block_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, block_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, block_num[3], stride=2)
        if self.include_top:
            # Whatever the input size, the output spatial size is always (1, 1)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        # Kaiming-initialize every convolution layer in the network
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    # stride defaults to 1 here
    # channels is the first conv layer's channel count in each block: the block
    # output equals it for ResNet-18/34 and is 4x larger for the other variants
    def _make_layer(self, block, channels, num_blocks, stride=1):
        # Declare downsample first so it is defined even when no projection is needed
        downsample = None
        if stride != 1 or self.in_c != block.expansion * channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_c, channels * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channels * block.expansion)
            )
        layers = []
        layers.append(block(self.in_c,
                            channels,
                            downsample=downsample,
                            stride=stride,
                            groups=self.groups,
                            width_per_group=self.width_per_group))
        # Update the input channel count for the blocks that follow
        self.in_c = channels * block.expansion
        # Start at 1 because the first block was already appended above
        for _ in range(1, num_blocks):
            layers.append(block(self.in_c,
                                channels,
                                groups=self.groups,
                                width_per_group=self.width_per_group))
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        if self.include_top:
            out = self.avgpool(out)
            # Flatten before the fully connected layer
            out = torch.flatten(out, 1)
            out = self.fc(out)
        return out
# The pretrained weights for this network need to be downloaded:
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
def resnet34(num_classes=1000, include_top=True):
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
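The same ResNet class covers the deeper variants simply by swapping in Bottleneck. The factory below is a sketch following the same pattern (resnet50 is not defined in the original post), followed by a quick forward-pass shape check:

def resnet50(num_classes=1000, include_top=True):
    # Same stage depths as ResNet-34, but built from Bottleneck blocks (4x expansion)
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)

net = resnet34()
print(net(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])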
II. Loading the Pretrained Weights
import os
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_path = './resnet34-333f7ec4.pth'
# Instantiate the original model
model = resnet34()
model.load_state_dict(torch.load(weight_path))
in_c = model.fc.in_features
# Rebuild the original model's classification head
model.fc = nn.Linear(in_c, 5)
from torchsummary import summary
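torchsummary is imported above but never called; the line below is a sketch of how one might use it to inspect the modified model (assuming a CPU inspection pass, since summary defaults to device='cuda'):

# Print a per-layer summary of the modified 5-class model on CPU
summary(model.to('cpu'), (3, 224, 224), device='cpu')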
III. Training Module
import sys
import json
import torch.optim as optim
from torchvision import transforms,datasets
from tqdm import tqdm
def train():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Standard preprocessing
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        # Resize keeps the aspect ratio and scales the shorter side to 256
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}
    img_path = './flower_data'
    batch_size = 32
    nw = 8
    trainset = datasets.ImageFolder(root=img_path + '/train', transform=data_transform['train'])
    num_train = len(trainset)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=nw, shuffle=True)
    validset = datasets.ImageFolder(root=img_path + '/val', transform=data_transform['val'])
    num_val = len(validset)
    validloader = torch.utils.data.DataLoader(validset, batch_size=batch_size, num_workers=nw, shuffle=True)
    print('Data loaded: training on {} images, validating on {} images.'.format(num_train, num_val))
    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = trainset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)
    model = resnet34()
    model.load_state_dict(torch.load(weight_path))
    in_c = model.fc.in_features
    # Rebuild the original model's classification head
    model.fc = nn.Linear(in_c, 5)
    model.to(device)
    # The loss function comes from the nn module
    Loss = nn.CrossEntropyLoss()
    # Collect only the parameters that require gradients
    parameters_required = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(parameters_required, lr=0.0001)
    Epoch = 10
    best_acc = 0
    model_path = './model.pth'
    num_batch = len(trainloader)
    for epoch in range(Epoch):
        model.train()
        running_loss = 0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for step, data in enumerate(train_bar):
            train_imgs, train_labels = data
            optimizer.zero_grad()
            y = model(train_imgs.to(device))
            loss = Loss(y, train_labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, Epoch, loss)
        model.eval()
        acc = 0
        with torch.no_grad():
            valid_bar = tqdm(validloader, file=sys.stdout)
            for data in valid_bar:
                valid_imgs, valid_labels = data
                y = model(valid_imgs.to(device))
                predict = torch.max(y, dim=1)[1]
                acc += torch.eq(predict, valid_labels.to(device)).sum().item()
                valid_bar.desc = "valid epoch[{}/{}]".format(epoch + 1, Epoch)
        # Accuracy should be averaged over the validation set size
        acc = acc / num_val
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / num_batch, acc))
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), model_path)
train()
Data loaded: training on 3306 images, validating on 364 images.
train epoch[1/10] loss:0.156: 100%|███████████| 104/104 [00:07<00:00, 14.75it/s]
valid epoch[1/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 16.60it/s]
[epoch 1] train_loss: 0.452 val_accuracy: 0.951
train epoch[2/10] loss:0.684: 100%|███████████| 104/104 [00:05<00:00, 17.91it/s]
valid epoch[2/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.29it/s]
[epoch 2] train_loss: 0.268 val_accuracy: 0.945
train epoch[3/10] loss:0.407: 100%|███████████| 104/104 [00:05<00:00, 17.99it/s]
valid epoch[3/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.66it/s]
[epoch 3] train_loss: 0.235 val_accuracy: 0.951
train epoch[4/10] loss:0.380: 100%|███████████| 104/104 [00:05<00:00, 17.97it/s]
valid epoch[4/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.90it/s]
[epoch 4] train_loss: 0.197 val_accuracy: 0.953
train epoch[5/10] loss:0.047: 100%|███████████| 104/104 [00:05<00:00, 17.92it/s]
valid epoch[5/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.25it/s]
[epoch 5] train_loss: 0.189 val_accuracy: 0.940
train epoch[6/10] loss:0.259: 100%|███████████| 104/104 [00:05<00:00, 17.78it/s]
valid epoch[6/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 16.82it/s]
[epoch 6] train_loss: 0.166 val_accuracy: 0.967
train epoch[7/10] loss:0.249: 100%|███████████| 104/104 [00:05<00:00, 17.83it/s]
valid epoch[7/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.45it/s]
[epoch 7] train_loss: 0.163 val_accuracy: 0.967
train epoch[8/10] loss:0.242: 100%|███████████| 104/104 [00:05<00:00, 17.93it/s]
valid epoch[8/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.63it/s]
[epoch 8] train_loss: 0.142 val_accuracy: 0.967
train epoch[9/10] loss:0.171: 100%|███████████| 104/104 [00:05<00:00, 17.69it/s]
valid epoch[9/10]: 100%|████████████████████████| 12/12 [00:00<00:00, 17.54it/s]
[epoch 9] train_loss: 0.143 val_accuracy: 0.953
train epoch[10/10] loss:0.141: 100%|██████████| 104/104 [00:05<00:00, 17.90it/s]
valid epoch[10/10]: 100%|███████████████████████| 12/12 [00:00<00:00, 17.79it/s]
[epoch 10] train_loss: 0.156 val_accuracy: 0.967
IV. Testing Module
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
%matplotlib inline
test_transform = transforms.Compose([transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
# The image needs to be transformed when opened; converting to RGB guards
# against PNGs that carry an alpha channel
test_img = Image.open('./my_test/pvzsunflower.png').convert('RGB')
plt.imshow(test_img)
test_img = test_transform(test_img)
# Add a batch dimension before feeding the model
test_img = torch.unsqueeze(test_img, dim=0)
# Read the class index mapping
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
with open(json_path, "r") as f:
    class_indict = json.load(f)
model = resnet34(num_classes= 5).to(device)
model_path = './model.pth'
# map_location is an argument of torch.load
model.load_state_dict(torch.load(model_path,map_location = device))
model.eval()
with torch.no_grad():
    out = torch.squeeze(model(test_img.to(device))).cpu()
    predict = torch.softmax(out, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
                                           predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
    print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
                                            predict[i].numpy()))
# Reopen the original image for display (test_img now holds the transformed tensor)
test_img = Image.open('./my_test/pvzsunflower.png')
plt.imshow(test_img)
plt.show()
class: daisy prob: 0.127
class: dandelion prob: 0.0162
class: roses prob: 0.716
class: sunflowers prob: 0.0246
class: tulips prob: 0.116