pytorch load_state_dict不生效解决

554 阅读2分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。


有时候我们会发现,load一个保存好的模型dict到一个新的模型对象的时候,并没有加载dict里面的数据。

一个简单例子

比如下图中,模型TestM里面定义了一个参数D,如果我不加:.cuda(),那么下面的例子是可以成功加载的,也就是说,print出来的 LoadedOrigin是相等的。

但是,如果对参数D加了 .cuda(),那么,Loaded就和Init 是一样的,并不等于Origin

结论

模型参数定义不要加.cuda(),应该统一在外面调用model.cuda().

大家可以复制下面代码进行简单测试即可。

import torch
from torch import nn

class TestM(nn.Module):
    def __init__(self):
        super(TestM, self).__init__()
        self.D = nn.Parameter(torch.randn(2, 2).float())  # 成功
        # self.D = nn.Parameter(torch.randn(2, 2).float()).cuda() # 失败
        
    def forward(self, x):
        return x

tm1 = TestM()
print("Origin: ", tm1.D)
torch.save(tm.state_dict(), 'tm_test.pth')

tm2 = TestM()
print("Init: ", tm2.D)
tm2.load_state_dict(torch.load('tm_test.pth'))
print("Loaded: ", tm2.D)

知识点

后续还有些问题可能需要注意,比如,如果一个模型太大了,上百G那种,存储又不够,或者加载的时间太长,那怎么办?同学们可以自己思考和实践一下,如果你懂了这个state_dict的原理,自己实现起来,应该是可以解决这个问题的。

在 PyTorch 中,模型的可学习参数(即权重和偏差) torch.nn.Module包含在模型的参数中 (使用 访问model.parameters())。state_dict只是一个 Python 字典对象,它将每一层映射到其参数张量。请注意,只有具有可学习参数的层(卷积层、线性层等)和注册缓冲区(batchnorm 的 running_mean)在模型的state_dict中有条目。优化器对象 ( torch.optim) 也有一个state_dict,其中包含有关优化器状态的信息,以及使用的超参数。

因为state_dict对象是 Python 字典,所以它们可以很容易地保存、更新、更改和恢复,从而为 PyTorch 模型和优化器增加了大量的模块化。

references

pytorch.org/tutorials/b…