pytorch-load_state_dict占用额外显存

597 阅读1分钟

当你用

checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])

这样load一个 pretrained model 的时候,torch.load() 会默认把load进来的数据放到0卡上,这样4个进程全部会在0卡占用一部分显存。
解决的方法也很简单,就是把load进来的数据map到cpu上

checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])

在这里插入图片描述
在这里插入图片描述