保存
-
训练完模型后,插入一行代码
torch.save(model.state_dict(),'./model(文件名).pth')
文件格式是pth
或者pt
都是一样的,不过现在主流可能是pt
-
我下载的时候用了
torch.save(model.state_dict(),'./2.pt') #只保存模型权重参数
然后就能在Kaggle笔记本的Output文件目录下看到了,点击右边三个点,下载下来
读取
- 构建好模型后,插入读取
纯模型参数
model=modelName(模型名字)()
model.load_state_dict(torch.load(model(文件名).pth))
model.eval()
- 这里我用的是
PATH = r'data\2.pt' #参数文件的绝对路径
model = Classifier() #加载模型结构
model.load_state_dict(torch.load(PATH)) #根据模型结构,调用储存的模型参数
model.eval()
-
当然也可以直接把路径放进去
model.load_state_dict(torch.load(r'data\2.pt'))
-
如果报错找不到pth/pt文件,检查路径是否正确,格式是否正确。建议使用绝对路径
r'folder\file'
或者r'file'
,简洁方便。