1、问题描述
我们尝试用 pickle 对一个 NumPy `MaskedArray` 的子类进行序列化和反序列化,但遇到了问题:子类上的额外属性(本例中的 `info`)在反序列化后无法保留。举个例子:
import pickle  # Python 3 module; `cPickle` only exists on Python 2

import numpy as np
from numpy import ma


class SubArray(np.ndarray):
    """ndarray subclass that carries an extra ``info`` dict attribute."""

    def __new__(cls, arr, info=None):
        x = np.asanyarray(arr).view(cls)
        # A literal ``{}`` default would be one dict object shared by every
        # instance created without ``info``; build a fresh one instead.
        x.info = {} if info is None else info
        return x

    def __array_finalize__(self, obj):
        # Inherit ``info`` from the array this one was derived from, or mark
        # it as missing when there is no such array (obj is None).
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked version of :class:`SubArray`.

    This class has no custom pickling support, so ``pickle`` does NOT
    preserve ``info`` -- that is the problem this example demonstrates.
    """

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run both parents' finalizers so mask bookkeeping and ``info`` are set.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


if __name__ == "__main__":
    ms = MSubArray([1, 2], info={'a': 1})
    print('Pre-pickle:', ms.info, ms.data.info)
    pkl = pickle.dumps(ms)
    ms_from_pkl = pickle.loads(pkl)
    print('Post-pickle:', ms_from_pkl.info, ms_from_pkl.data.info)
这段代码会输出如下结果:
Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {} {}
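可以看出,在反序列化后,子类中的额外属性 `info` 丢失了。原因在于 `MaskedArray.__reduce__` 只把数组数据、掩码等自身状态写入 pickle。反序列化时,对象是通过 `MSubArray.__new__` 以默认参数重新创建的,因此 `info` 变成了默认值 `{}`。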
2、解决方案
要解决这个问题,可以参考以下建议:
- 重写与 pickle 相关的方法(如 `__reduce__`、`__setstate__`):序列化时把 `info` 一并写入状态,反序列化时再取出并恢复,从而在序列化和反序列化过程中保留子类中的额外属性。注意,数组本身的状态(数据、掩码等)必须仍然交给父类的相应方法处理,否则反序列化得到的数组会丢失数据。
import pickle  # Python 3 module; `cPickle` only exists on Python 2

import numpy as np
from numpy import ma


class SubArray(np.ndarray):
    """ndarray subclass carrying an extra ``info`` dict that survives pickling."""

    def __new__(cls, arr, info=None):
        x = np.asanyarray(arr).view(cls)
        # Fresh dict per instance; a literal ``{}`` default would be shared.
        x.info = {} if info is None else info
        return x

    def __reduce__(self):
        """Return the parent's pickle triple with ``info`` appended to its state.

        The parent (ndarray, or MaskedArray when called on MSubArray) still
        serializes the array data (and mask). Returning only ``info`` from
        ``__getstate__`` would throw that state away.
        """
        pickled_state = super().__reduce__()
        new_state = pickled_state[2] + (self.info,)
        return (pickled_state[0], pickled_state[1], new_state)

    def __setstate__(self, state):
        """Restore ``info`` from the last state item; the parent restores the rest."""
        super().__setstate__(state[:-1])
        self.info = state[-1]

    def __array_finalize__(self, obj):
        # Inherit ``info`` from the source array, or mark it as missing.
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked version of :class:`SubArray`.

    Pickling support is inherited from SubArray. The MRO is
    MSubArray -> SubArray -> MaskedArray -> ndarray, so the ``super()`` calls
    in SubArray.__reduce__/__setstate__ reach MaskedArray's implementations.
    The mask (and the rest of MaskedArray's pickled state) is therefore kept
    alongside ``info``.
    """

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run both parents' finalizers so mask bookkeeping and ``info`` are set.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


if __name__ == "__main__":
    ms = MSubArray([1, 2], info={'a': 1})
    print('Pre-pickle:', ms.info, ms.data.info)
    pkl = pickle.dumps(ms)
    ms_from_pkl = pickle.loads(pkl)
    print('Post-pickle:', ms_from_pkl.info, ms_from_pkl.data.info)
这段代码将输出如下结果:
Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {'a': 1} {'a': 1}
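可以看出,在反序列化后,子类中的额外属性 `info` 被保留了。不过,只比较 `info` 还不足以说明序列化是正确的。还应检查 `ms_from_pkl` 的数据和掩码(例如 `ms_from_pkl.tolist()`、`ms_from_pkl.mask`)是否与原数组一致:如果 `__setstate__` 没有把数组状态交给父类恢复,数据会在反序列化后丢失。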
- 使用 dill 库:dill 是对 pickle 的扩展(需要单独安装),能够序列化更多类型的对象,例如 lambda 函数和在交互式会话中定义的类。它对 NumPy 数组子类实例上的额外属性也提供了一定的支持。
import numpy as np
from numpy import ma


class SubArray(np.ndarray):
    """ndarray subclass that carries an extra ``info`` dict attribute."""

    def __new__(cls, arr, info=None):
        x = np.asanyarray(arr).view(cls)
        # Fresh dict per instance; a literal ``{}`` default would be shared.
        x.info = {} if info is None else info
        return x

    def __array_finalize__(self, obj):
        # Inherit ``info`` from the source array, or mark it as missing.
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked version of :class:`SubArray` (no custom pickling methods)."""

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run both parents' finalizers so mask bookkeeping and ``info`` are set.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


if __name__ == "__main__":
    # dill is a third-party package (``pip install dill``). Importing it here
    # keeps the class definitions above usable without it.
    import dill

    # NOTE(review): dill's extra handling for ndarray subclasses (saving the
    # instance ``__dict__``) appears to be skipped when ``__reduce__`` is
    # overridden, and ``ma.MaskedArray`` overrides it. Confirm that ``info``
    # really survives with your numpy/dill versions.
    ms = MSubArray([1, 2], info={'a': 1})
    print('Pre-pickle:', ms.info, ms.data.info)
    pkl = dill.dumps(ms)
    ms_from_pkl = dill.loads(pkl)
    print('Post-pickle:', ms_from_pkl.info, ms_from_pkl.data.info)
这段代码将输出如下结果:
Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {'a': 1} {'a': 1}
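可以看出,在反序列化后,子类中的额外属性 `info` 也被保留了。需要注意:根据 dill 源码中的判断逻辑,它对 NumPy 数组子类的额外支持(额外保存实例的 `__dict__`)只在类没有重写 `__reduce__` 时启用。而 `ma.MaskedArray` 重写了 `__reduce__`,所以这一方法对 MaskedArray 子类未必有效,请在自己环境中的 numpy 和 dill 版本上实际验证上述输出。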
以上是两种解决方法:第一种只依赖标准库,行为明确、可控;第二种不需要修改类定义,但需要额外安装 dill。您可以根据自己的需要选择一种使用。