一起养成写作习惯!这是我参与「掘金日新计划 · 4 月更文挑战」的第6天,点击查看活动详情。
一.概念
模型=网络结构+网络参数 网络结构:VGG ,RASNet…… 网络参数:网络结构中kernel,weight之类的数据
二.怎么得到模型
1.下载别人的网络结构
我们有时候会借鉴别人的网络,大多数情况下他们会在GitHub中放上已经训练好的模型,这个时候,你就可以下载下来直接用。 ps:大多都在README的文件中会放一个云盘的链接,点链接下载。
2.保存自己的网络模型
if ite_num % save_frq == 0:
print(model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
torch.save(net, model_dir + model_name+"_bce_itr_%d_train_%3f_tar_%3f.pth" % (ite_num, running_loss / ite_num4val, running_tar_loss / ite_num4val))
我们可以设置一个阈值,规定都是个跑多少个ite的时候保存模型。 保存模型的时候最好把模型的相关数据一起加上去,便于后期的查看。
三.保存与加载
1.保存的是模型还是参数
保存模型:torch.save(net,path) 保存参数:torch.save(net.state_dict(),path)
个人觉的保存模型比保存参数好 因为模型包含参数,可以从模型中读取参数(下面第3点介绍)。 而且保存的pth大小差距不大。
2.怎么加载模型
下载模型:net=torch.load("D:/1.pth",map_location='cpu')
加载参数:
net=VGG() #首先你要得到相关网络的模型,里面的参数会自动随机初始化`
net.load_state_dict(torch.load("D:/2.pth",map_location='cpu) #更新得到更优的网络
3.怎么由模型得到数据
shu_ju=net.state_dict()
四.小问题
如果自己的网络与模型的网络不一样怎么办/怎么训练改进后的网络?
先冻结参数,然后修改优化器。
冻结参数有两种方法: 1.在原来的参数基础上直接与改进的参数一起训练。 2.冻结原来的数据,训练改机的网络一段时间后再解冻一起训练。
粗暴训练法
这个时候就要用三.2中的加载参数了,不同的是,我们在load_state_dict的时候添加了False的关键字,表示不严格的加载,即只加载关键字相同的参数(这里的False的F记得大写。。。踩过坑)。
net=VGG() #首先你要得到相关网络的模型,里面的参数会自动随机初始化`
net.load_state_dict(torch.load("D:/2.pth",map_location='cpu,False) #更新得到更优的网络
这样就可以把net扔进循环开始训练了
冻结训练法
怎么冻结?——> 神经网络如何训练?——> 依靠方向传播 ——> 梯度
网络中有一个requires_grad的参数的可以更改,如果为False的话,就可以达到冻结效果。
for p in self.parameters():
p.requires_grad = False
比如加载了resnet预训练模型之后,在resenet的基础上连接了新的模快,resenet模块那部分可以先暂时冻结不更新,只更新其他部分的参数,那么可以在下面加入上面那句话。
class RESNET_MF(nn.Module):
def __init__(self, model, pretrained):
super(RESNET_MF, self).__init__()
self.resnet = model(pretrained)
for p in self.parameters():
p.requires_grad = False
self.f = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.g = SpectralNorm(nn.Conv2d(2048, 512, 1))
self.h = SpectralNorm(nn.Conv2d(2048, 2048, 1))
...
我个人认为也可以这样: 法一:
net=load("D:/2.pth",map_loaction='cpu')
for key, value in net.named_parameters():
value.requires_grad = False
先直接把原来的模型中的requires_grad 全改为False,然后再更新到新的网络中。。
法二: 直接更新网络
net.load_state_dirc(torch.load(D:/1.pth),False)
#跳过改进的地方
for key, value in net.named_parameters():
if not ("beta" in key or "gamma" in key or 'alpha' in key):
value.requires_grad = True
修改优化器
在优化器中添加:
filter(lambda p: p.requires_grad, model.parameters())
用于过滤冻结的参数
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001, betas=(0.9, 0.999),
eps=1e-08, weight_decay=1e-5)
训练一段时间后,我们就可以解冻,然后进行微调网络。