如何将 Tensor 保存下来并加载 ?

295 阅读1分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

将 Tensor 保存为 pth 文件再加载即可。这种方法对于不同 device 的服务器之间的算子测试,以及保存 size 很大的 Tensor 非常方便。

a=torch.randn([3,4],device='cpu')
torch.save(a,"a.pth")    # 保存Tensor为pth文件
b=torch.load("a.pth",map_location="cuda:0")   # 指定加载的device
print(b)

输出:

tensor([[-0.2660,  0.5144,  0.3616,  1.5060],
        [-1.0543,  2.4541, -0.3537, -0.1083],
        [ 0.0604,  2.6085, -2.9730, -1.0179]], device='cuda:0')