本文已参与「新人创作礼」活动,一起开启掘金创作之路。
有时候我们会发现,load一个保存好的模型dict到一个新的模型对象的时候,并没有加载dict里面的数据。
一个简单例子
比如下图中,模型TestM里面定义了一个参数D,如果我不加:.cuda(),那么下面的例子是可以成功加载的,也就是说,print出来的 Loaded和Origin是相等的。
但是,如果对参数D加了 .cuda(),那么,Loaded就和Init 是一样的,并不等于Origin。
结论
模型参数定义不要加.cuda(),应该统一在外面调用model.cuda().
大家可以复制下面代码进行简单测试即可。
import torch
from torch import nn
class TestM(nn.Module):
def __init__(self):
super(TestM, self).__init__()
self.D = nn.Parameter(torch.randn(2, 2).float()) # 成功
# self.D = nn.Parameter(torch.randn(2, 2).float()).cuda() # 失败
def forward(self, x):
return x
tm1 = TestM()
print("Origin: ", tm1.D)
torch.save(tm.state_dict(), 'tm_test.pth')
tm2 = TestM()
print("Init: ", tm2.D)
tm2.load_state_dict(torch.load('tm_test.pth'))
print("Loaded: ", tm2.D)
知识点
后续还有些问题可能需要注意,比如,如果一个模型太大了,上百G那种,存储又不够,或者加载的时间太长,那怎么办?同学们可以自己思考和实践一下,如果你懂了这个state_dict的原理,自己实现起来,应该是可以解决这个问题的。
在 PyTorch 中,模型的可学习参数(即权重和偏差) torch.nn.Module包含在模型的参数中 (使用 访问model.parameters())。state_dict只是一个 Python 字典对象,它将每一层映射到其参数张量。请注意,只有具有可学习参数的层(卷积层、线性层等)和注册缓冲区(batchnorm 的 running_mean)在模型的state_dict中有条目。优化器对象 ( torch.optim) 也有一个state_dict,其中包含有关优化器状态的信息,以及使用的超参数。
因为state_dict对象是 Python 字典,所以它们可以很容易地保存、更新、更改和恢复,从而为 PyTorch 模型和优化器增加了大量的模块化。