如何对NumPy MaskedArray子类进行序列化

69 阅读2分钟

我们试图将一个NumPy MaskedArray子类通过pickle进行序列化和反序列化,但是遇到了问题,子类中的额外属性无法被保留。举个例子:

import numpy as np
import cPickle as pickle
from numpy import ma


class SubArray(np.ndarray):
    """ndarray subclass that carries an extra ``info`` attribute.

    ``info`` is propagated to views and slices by ``__array_finalize__``.
    This class defines no pickle hooks of its own.
    """

    def __new__(cls, arr, info=None):
        # Use a fresh dict per instance: a literal ``{}`` default is a single
        # shared object, so mutating one instance's ``info`` would leak into
        # every other instance created without an explicit ``info``.
        x = np.asanyarray(arr).view(cls)
        x.info = {} if info is None else info
        return x

    def __array_finalize__(self, obj):
        # NumPy calls this whenever a SubArray is created from another array
        # (views, slices, ...); copy ``info`` from the source array, or fall
        # back to a marker dict when the source has none.
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked array whose underlying data is a ``SubArray``.

    Both the masked array itself and its ``.data`` view expose ``info``.
    """

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        # Resolve the default here so every instance gets its own dict; a
        # literal ``{}`` default would be shared by all instances.
        if info is None:
            info = {}
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run MaskedArray's own finalization first, then SubArray's, which
        # copies ``info`` from the source array.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


# Build a masked instance carrying an extra ``info`` attribute.
ms = MSubArray([1, 2], info={'a': 1})
print('Pre-pickle:', ms.info, ms.data.info)

# Round-trip through pickle, then inspect ``info`` on the restored masked
# array and on its ``.data`` view.
pkl = pickle.dumps(ms)
ms_from_pkl = pickle.loads(pkl)
restored_info = ms_from_pkl.info
restored_data_info = ms_from_pkl.data.info
print('Post-pickle:', restored_info, restored_data_info)

这段代码会输出如下结果:

Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {} {}

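可以看出,在反序列化后,子类中的额外属性 info 丢失了。原因是 MaskedArray 的 `__reduce__` 只保存数据、掩码和填充值;反序列化时会调用 `MSubArray.__new__` 重新构造对象,而此时并没有传入 info,于是得到的是默认值 `{}`。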

2、解决方案

要解决这个问题,可以参考以下建议:

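  1. 重写与 pickle 相关的方法:在子类中重写序列化/反序列化钩子,以便在序列化和反序列化过程中保留子类中的额外属性。需要注意,ndarray 子类的 pickle 状态由 `__reduce__` 生成(MaskedArray 的 `__reduce__` 内部还会调用 `__getstate__`),并由 `__setstate__` 恢复。额外属性应当追加到父类返回的状态之后,而不能替换整个状态,否则数组数据和掩码会在反序列化后丢失。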
import numpy as np
import cPickle as pickle
from numpy import ma


class SubArray(np.ndarray):
    """ndarray subclass that carries an extra ``info`` attribute.

    ``info`` survives views/slices via ``__array_finalize__`` and survives
    pickling via ``__reduce__``/``__setstate__``. Those two hooks only
    *extend* the parent's pickle state. Replacing that state with ``info``
    alone would leave the unpickled object without its data (and, for
    masked arrays, without its mask).
    """

    def __new__(cls, arr, info=None):
        # Fresh dict per call: a literal ``{}`` default would be shared by
        # every instance created without an explicit ``info``.
        x = np.asanyarray(arr).view(cls)
        x.info = {} if info is None else info
        return x

    def __reduce__(self):
        """Return the parent's pickle recipe with ``info`` appended to its state.

        For a plain ``SubArray`` the parent is ``ndarray``. For
        ``MSubArray`` the next class in the MRO is ``ma.MaskedArray``, whose
        state also includes the mask and fill value.
        """
        reconstruct, args, state = super(SubArray, self).__reduce__()
        return reconstruct, args, state + (self.info,)

    def __setstate__(self, state):
        """Restore the parent state, then ``info`` (the last state item)."""
        super(SubArray, self).__setstate__(state[:-1])
        self.info = state[-1]

    def __array_finalize__(self, obj):
        # NumPy calls this whenever a SubArray is created from another array
        # (views, slices, ...); inherit ``info`` from the source array, or
        # use a marker dict when it has none.
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked array whose underlying data is a ``SubArray``.

    Pickling support is inherited from ``SubArray.__reduce__`` and
    ``SubArray.__setstate__``. Because ``SubArray`` precedes
    ``ma.MaskedArray`` in the MRO, their ``super()`` calls reach
    MaskedArray's implementations, so data, mask, fill value and ``info``
    are all saved and restored.
    """

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run MaskedArray's own finalization first, then SubArray's, which
        # copies ``info`` from the source array.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


ms = MSubArray([1, 2], info={'a': 1})
pre_info, pre_data_info = ms.info, ms.data.info
print('Pre-pickle:', pre_info, pre_data_info)

# Serialize to bytes and load them straight back.
pkl = pickle.dumps(ms)
ms_from_pkl = pickle.loads(pkl)
print('Post-pickle:', ms_from_pkl.info, ms_from_pkl.data.info)

这段代码将输出如下结果:

Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {'a': 1} {'a': 1}

可以看出,在反序列化后,子类中的额外属性info被保留了。

  2. 使用 dill 库:dill 是一个比 pickle 更强大的序列化库,它可以序列化更多的对象类型,包括子类。下面的代码沿用了第一个示例中的 `import numpy as np` 和 `from numpy import ma`。
import dill

class SubArray(np.ndarray):
    """ndarray subclass that carries an extra ``info`` attribute.

    ``info`` is propagated to views and slices by ``__array_finalize__``.
    This class defines no pickle hooks of its own.
    """

    def __new__(cls, arr, info=None):
        # Use a fresh dict per instance: a literal ``{}`` default is a single
        # shared object, so mutating one instance's ``info`` would leak into
        # every other instance created without an explicit ``info``.
        x = np.asanyarray(arr).view(cls)
        x.info = {} if info is None else info
        return x

    def __array_finalize__(self, obj):
        # NumPy calls this whenever a SubArray is created from another array
        # (views, slices, ...); copy ``info`` from the source array, or fall
        # back to a marker dict when the source has none.
        self.info = getattr(obj, 'info', {'ATTR': 'MISSING'})


class MSubArray(SubArray, ma.MaskedArray):
    """Masked array whose underlying data is a ``SubArray``.

    Both the masked array itself and its ``.data`` view expose ``info``.
    """

    def __new__(cls, data, info=None, mask=ma.nomask, dtype=None):
        # Resolve the default here so every instance gets its own dict; a
        # literal ``{}`` default would be shared by all instances.
        if info is None:
            info = {}
        subarr = SubArray(data, info)
        _data = ma.MaskedArray.__new__(cls, data=subarr, mask=mask, dtype=dtype)
        _data.info = subarr.info
        return _data

    def __array_finalize__(self, obj):
        # Run MaskedArray's own finalization first, then SubArray's, which
        # copies ``info`` from the source array.
        ma.MaskedArray.__array_finalize__(self, obj)
        SubArray.__array_finalize__(self, obj)


ms = MSubArray([1, 2], info={'a': 1})
print('Pre-pickle:', ms.info, ms.data.info)

# Same round trip as in the pickle examples, but serialized with dill.
pkl = dill.dumps(ms)
ms_from_pkl = dill.loads(pkl)
post_info, post_data_info = ms_from_pkl.info, ms_from_pkl.data.info
print('Post-pickle:', post_info, post_data_info)

这段代码将输出如下结果:

Pre-pickle: {'a': 1} {'a': 1}
Post-pickle: {'a': 1} {'a': 1}

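按上面给出的输出,在反序列化后,子类中的额外属性 info 也被保留了。不过请注意:dill 只对没有重写 `__reduce__` 的 ndarray 子类启用专门的处理(会连同实例的 `__dict__` 一起保存),而 MaskedArray 重写了 `__reduce__`。因此对于 MaskedArray 子类,这一结果在你所用的 dill/NumPy 版本上不一定能复现,使用前请先实际验证。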

以上是两种解决方法,您可以根据自己的需要选择一种使用。