当你用
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["state_dict"])
这样load一个 pretrained model 的时候,torch.load() 会默认把load进来的数据放到0卡上,这样4个进程全部会在0卡占用一部分显存。
解决的方法也很简单,就是把load进来的数据map到cpu上
checkpoint = torch.load("checkpoint.pth", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])