torch中模型保存与加载

917 阅读3分钟

参考链接

保存和加载模型 - PyTorch官方教程中文版 (pytorch123.com)

state_dict

state_dict基础内容

pytorch中的torch.nn.Module模型中的参数可以通过model.parameters()函数进行访问,此函数返回一个生成器,通过for循环查看model.parameters()结果是每层的参数数据。

for x in model.parameters():
    print(x)

#output:
Parameter containing:
tensor([[[[ 0.0908, -0.1033, -0.0175, -0.0567, -0.0543],
          [-0.0666,  0.0097,  0.0017, -0.1141,  0.1127],
          [ 0.0113, -0.1119,  0.0696, -0.0383, -0.0037],
          [-0.0199,  0.1131,  0.0044,  0.0083, -0.0113],
          [-0.0267, -0.0676,  0.0433, -0.0297,  0.0166]],
          ......

model.state_dict()返回的是一个字典(key-value)形式,key为模型每层的名称,value为对应层的参数数据。

调用state_dict()返回的形式为字典形式。除了model外,optimizer中也有参数,因此通过optimizer.state_dict()来获取参数用于保存,只有有参数的层才会在字典中显示出来。

import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torch


class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

if __name__=="__main__":
        model = TheModelClass()
        optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
        
        print("model's state_dict: ")
        for param_tensor in model.state_dict():
            print(param_tensor, "\t", model.state_dict()[param_tensor].size())

        print(optimizer.state_dict())

#output:
model's state_dict: 
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
{'state': {}, 'param_groups': [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]}

上述代码中需要注意:

  1. model.state_dict()返回字典形式,key为每层名称+权重或者偏置,value为对应的参数数据;
  2. optimizer.state_dict()返回为字典形式,第一个为'state':{},第二个为'param_groups':[]其中[]中是一个字典形式。因此在需要获取某个参数时需要注意,第二个参数字典保存在列表中。

保存与加载模型

保存与加载模型有三种方式

  • 第一种是只保存参数,保存为字典形式;
  • 第二种是将模型以及参数整体保存;
  • 第三种是使用字典的形式将模型的参数、优化器参数以及运行的环境保存为checkpoint.tar格式。

第二种代码少,但是保存的文件较大,且pickle无法保存模型,在其他项目使用或者重构后,代码可能会以各种方式中断。因此只学习其他两种。


  • 保存:torch.save(state_dict, path)将序列化对象保存在磁盘。其可以将模型或者字典使用pickle序列化而后保存。常常保存为.pth类型的文件。

  • 反序列化:torch.load(path)将文件反序列化变成字典,而后将字典形式的参数使用load_state_dict(torch.load(path))导入模型中。

  • 加载:torch.nn.Module.load_state_dict:将反序列化后的数据加载到模型。在load_state_dict()参数中有一个strict参数,当strict=False时表示导入参数中的key不需要完全与模型中的每层匹配,可以用于缺少某些键的state_dict加载或者从键的数目多于加载模型的state_dict

只保存模型的参数

如果在测试的情况下,要固定BN层以及dropout的参数,因此需要使用model.eval()

  • 保存
torch.save(model.state_dict(), "dict.pth")
  • 加载
model1 = TheModelClass()
model1.load_state_dict(torch.load("dict.pth"))
model1.eval()

保存模型、优化器参数以及环境

由于在训练时,有可能在训练的过程中出现停电或者其他意外情况,此时就需要不断的保存当前的训练结果,此训练结果包含有模型的参数、优化器参数、模型运行的环境等信息。因此可以考虑将其保存为一个字典形式,在加载后也是以字典形式访问。

  • 保存
epoch = 60
loss = 0.12
torch.save({
    'epoch': 60,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss
}, "checkpoint1.tar")
  • 加载
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load("checkpoint1.tar")

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

模型加载硬件

保存在GPU,加载在GPU

torch.save(model.state_dict(), PATH)

device = torch.device("cuda") 
model = TheModelClass(*args, **kwargs) 
model.load_state_dict(torch.load(PATH)) 
model.to(device) # 确保在你提供给模型的任何输入张量上调用input = input.to(device)

当在GPU上训练并把模型保存在GPU,只需要使用model.to(torch.device('cuda')),将初始化的 model 转换为 CUDA 优化模型。

另外,请 务必在所有模型输入上使用.to(torch.device('cuda'))函数来为模型准备数据。

请注意,调用my_tensor.to(device)会在GPU上返回my_tensor的副本。 因此,请记住手动覆盖张量:my_tensor= my_tensor.to(torch.device('cuda'))