开启掘金成长之旅!这是我参与「掘金日新计划 · 12 月更文挑战」的第2天,点击查看活动详情
本文是 pytorch 官方教程学习笔记,涉及模型保存和加载,文末还简述了模型转换相关内容
模型保存和加载
模型保存
import torch
import torchvision.models as models
model = models.mobilenet_v2(pretrained=True)
torch.save(model.state_dict(), 'E:\Downloads\model_weights.pth')
pretrained=True 表示下载并加载模型参数,官方文档中以 VGG16 为例,但是 VGG 太大了,在此以轻量化网络 MobileNet 为例
pytorch 模型以字典形式保存参数,被称为 state_dict
torch.save 方法用来保存模型,第一个参数表示保存模型的参数(但不保存模型的架构),第二个参数为文件路径
模型加载
model = models.mobilenet_v2()
model.load_state_dict(torch.load('E:\Downloads\model_weights.pth'))
model.eval()
创建一个不加载参数的空壳模型,所谓空壳是指只有下图所示的模型架构,但是参数值是初始的
使用
load_state_dict 加载参数值
请确保在推理前调用 model.eval() 方法,将 dropout 和 batch normalization layers 设置为评估模式。如果不这样做,将产生不一致的推理结果。
更简便的方式
model = models.mobilenet_v2(pretrained=True)
torch.save(model, 'E:\Downloads\model_weights2.pth')
model = torch.load('E:\Downloads\model_weights2.pth')
model.eval()
前面的方式只保存参数值而不保存网络结构
可以直接保存模型的架构和参数值,同样,加载时直接读出完整的模型
这种方法在序列化模型时使用 Python pickle 模块,因此它依赖于实际的类定义在加载模型时可用。
pytorch 模型转 ONNX
由多种深度学习框架训练出的模型,想要统一部署到移动端等设备上进行推理,可以将模型转换为一种中间格式,而 ONNX 就是这样一种过度的桥梁。这里只介绍 torch.onnx.export 函数
model = models.mobilenet_v2(pretrained=True)
model.eval()
inputData = torch.rand(1, 3, 224, 224)
torch.onnx.export(model, inputData, 'E:\Downloads\model_weights.onnx')
在将模型转为 ONNX 的过程中,除了提供模型和输出文件路径外,必须提供一个输入数据(与模型输入维度一致,我提供的是 MobileNetV2 所需输入维度的数据)
原因:需要一个输入帮助把 pytorch 的动态计算图转化为 ONNX 需要的静态计算图
其他内容可以参考从 pytorch 转换到 onnx