Kaggle纯模型参数保存与读取(Pytorch)

1,023 阅读1分钟

保存

  • 训练完模型后,插入一行代码 torch.save(model.state_dict(),'./model(文件名).pth') 文件格式是pth或者pt都是一样的,不过现在主流可能是pt

  • 我下载的时候用了

torch.save(model.state_dict(),'./2.pt') #只保存模型权重参数

然后就能在Kaggle笔记本的Output文件目录下看到了,点击右边三个点,下载下来

读取

  • 构建好模型后,插入读取纯模型参数
model=modelName(模型名字)()
model.load_state_dict(torch.load(model(文件名).pth))
model.eval()
  • 这里我用的是
PATH = r'data\2.pt' #参数文件的绝对路径
model = Classifier() #加载模型结构
model.load_state_dict(torch.load(PATH)) #根据模型结构,调用储存的模型参数
model.eval()
  • 当然也可以直接把路径放进去model.load_state_dict(torch.load(r'data\2.pt'))

  • 如果报错找不到pth/pt文件,检查路径是否正确,格式是否正确。建议使用绝对路径r'folder\file'或者r'file',简洁方便。

参考

[1] kaggle上模型的保存与读取----Pytorch框架

[2] pytorch模型保存方式(.pt,.pth,.pkl)