[bug report]pickle存储tensor时可能出现精度损失问题

719 阅读3分钟

问题简述

pytorch模型的参数以tensor形式存放在内存中,我们经常需要将训练好的模型存储进磁盘中,而pickle就是一个非常常用的模型存储工具。但是,在使用pickle来保存或加载PyTorch模型时,可能会导致精度丢失。本文给出了对应bug的复现代码以及可能的解决方案。

pickle存储tensor时的bug说明

在Tensor,Storage,或者module上调用 pickle.dump 函数时,两次运行后生成的字节并不相同。导致使用pickle.dump 时生成的字节不可重现。

问题代码

f = io.BytesIO()
torch.save(torch.IntTensor([1]), f)
print(f.getvalue())

运行两次后的不同结果

-b'\x80\x03ctorch._utils\n_rebuild_tensor_v2\nq\x00(ctorch.storage\n_load_from_bytes\nq\x01C\xff\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00longq\x07K\x04uu.\x80\x02(X\x07\x00\x00\x00storageq\x00ctorch\nIntStorage\nq\x01X\x0f\x00\x00\x00140655645110640q\x02X\x03\x00\x00\x00cpuq\x03K\x01Ntq\x04Q.\x80\x02]q\x00X\x0f\x00\x00\x00140655645110640q\x01a.\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00q\x02\x85q\x03Rq\x04K\x00K\x01\x85q\x05K\x01\x85q\x06\x89ccollections\nOrderedDict\nq\x07)Rq\x08tq\tRq\n.' 
+b'\x80\x03ctorch._utils\n_rebuild_tensor_v2\nq\x00(ctorch.storage\n_load_from_bytes\nq\x01C\xff\x80\x02\x8a\nl\xfc\x9cF\xf9 j\xa8P\x19.\x80\x02M\xe9\x03.\x80\x02}q\x00(X\x10\x00\x00\x00protocol_versionq\x01M\xe9\x03X\r\x00\x00\x00little_endianq\x02\x88X\n\x00\x00\x00type_sizesq\x03}q\x04(X\x05\x00\x00\x00shortq\x05K\x02X\x03\x00\x00\x00intq\x06K\x04X\x04\x00\x00\x00longq\x07K\x04uu.\x80\x02(X\x07\x00\x00\x00storageq\x00ctorch\nIntStorage\nq\x01X\x0f\x00\x00\x00140365944682560q\x02X\x03\x00\x00\x00cpuq\x03K\x01Ntq\x04Q.\x80\x02]q\x00X\x0f\x00\x00\x00140365944682560q\x01a.\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00q\x02\x85q\x03Rq\x04K\x00K\x01\x85q\x05K\x01\x85q\x06\x89ccollections\nOrderedDict\nq\x07)Rq\x08tq\tRq\n.'

运行环境

Collecting environment information...
PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: Could not collect
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 418.87.00
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.18.1
[pip] torch==1.5.0
[pip] torchvision==0.6.0a0+82fd1c8
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] mkl                       2020.0                      166  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.0.15           py37ha843d7b_0  
[conda] mkl_random                1.1.0            py37hd6b4f25_0  
[conda] numpy                     1.18.1           py37h4f9e942_0  
[conda] numpy-base                1.18.1           py37hde5b4d6_1  
[conda] pytorch                   1.5.0           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchvision               0.6.0                py37_cu101    pytorch

不太完美的解决方案

pick.dump生成的字节序列不相同就可能会导致模型存储/载入后tensor中的数值发生变化,并且,由于torch.save()/torch.load()仅仅时一个pickle的包装器,因此该问题在torch.save()/torch.load()上也可能出现!一种可能的方法是将需要被存储的tensor转换成numpy数组,当使用pickle存储numpy数组是不会出现这种精度损失问题。
笔者猜测原因可能是pickle有时不能很好地处理PyTorch中的张量(Tensor)数据类型。Pickle将模型的参数转换为Python中的标准数据类型(如列表和字典),并在加载模型时再转换回PyTorch的张量。这个过程中可能会丢失精度。但是具体原因还是不知 。