问题简述
pytorch模型的参数以tensor形式存放在内存中,我们经常需要将训练好的模型存储进磁盘中,而pickle就是一个非常常用的模型存储工具。但是,在使用pickle来保存或加载PyTorch模型时,可能会导致精度丢失。本文给出了对应bug的复现代码以及可能的解决方案。
pickle存储tensor时的bug说明
在Tensor,Storage,或者module上调用 pickle.dump 函数时,两次运行后生成的字节并不相同。导致使用pickle.dump 时生成的字节不可重现。
问题代码
f = io.BytesIO()
torch.save(torch.IntTensor([1]), f)
print(f.getvalue())
运行两次后的不同结果
-b'\x80\x03ctorch._utils\n_rebuild_tensor_v2\nq\x00(ctorch.storage\n_load_from_bytes\nq\x01C\xff\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00longq\x07K\x04uu.\x80\x02(X\x07\x00\x00\x00storageq\x00ctorch\nIntStorage\nq\x01X\x0f\x00\x00\x00140655645110640q\x02X\x03\x00\x00\x00cpuq\x03K\x01Ntq\x04Q.\x80\x02]q\x00X\x0f\x00\x00\x00140655645110640q\x01a.\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00q\x02\x85q\x03Rq\x04K\x00K\x01\x85q\x05K\x01\x85q\x06\x89ccollections\nOrderedDict\nq\x07)Rq\x08tq\tRq\n.'
+b'\x80\x03ctorch._utils\n_rebuild_tensor_v2\nq\x00(ctorch.storage\n_load_from_bytes\nq\x01C\xff\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00longq\x07K\x04uu.\x80\x02(X\x07\x00\x00\x00storageq\x00ctorch\nIntStorage\nq\x01X\x0f\x00\x00\x00140365944682560q\x02X\x03\x00\x00\x00cpuq\x03K\x01Ntq\x04Q.\x80\x02]q\x00X\x0f\x00\x00\x00140365944682560q\x01a.\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00q\x02\x85q\x03Rq\x04K\x00K\x01\x85q\x05K\x01\x85q\x06\x89ccollections\nOrderedDict\nq\x07)Rq\x08tq\tRq\n.'
运行环境
Collecting environment information...
PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.1
OS: Ubuntu 18.04.4 LTS
GCC version: Could not collect
CMake version: Could not collect
Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 418.87.00
cuDNN version: Could not collect
Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.5.0
[pip] torchvision==0.6.0a0+82fd1c8
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.1.243 h6bb024c_0
[conda] mkl 2020.0 166
[conda] mkl-service 2.3.0 py37he904b0f_0
[conda] mkl_fft 1.0.15 py37ha843d7b_0
[conda] mkl_random 1.1.0 py37hd6b4f25_0
[conda] numpy 1.18.1 py37h4f9e942_0
[conda] numpy-base 1.18.1 py37hde5b4d6_1
[conda] pytorch 1.5.0 py3.7_cuda10.1.243_cudnn7.6.3_0 pytorch
[conda] torchvision 0.6.0 py37_cu101 pytorch
不太完美的解决方案
pick.dump生成的字节序列不相同就可能会导致模型存储/载入后tensor中的数值发生变化,并且,由于torch.save()/torch.load()仅仅时一个pickle的包装器,因此该问题在torch.save()/torch.load()上也可能出现!一种可能的方法是将需要被存储的tensor转换成numpy数组,当使用pickle存储numpy数组是不会出现这种精度损失问题。
笔者猜测原因可能是pickle有时不能很好地处理PyTorch中的张量(Tensor)数据类型。Pickle将模型的参数转换为Python中的标准数据类型(如列表和字典),并在加载模型时再转换回PyTorch的张量。这个过程中可能会丢失精度。但是具体原因还是不知
。