shared_memory in Python 3.8


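Python 3.8 introduced the multiprocessing.shared_memory module, which can be used for inter-process communication. The example below uses its ShareableList to pass data between processes while computing π.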

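ShareableList comes with some restrictions. It only stores a few basic types: int, float, bool, str, bytes, and None. Each str or bytes item must be smaller than 10M bytes. The list is also fixed in size: its length cannot change, and the storage allocated to each element is set when the list is created, so an element cannot later grow beyond that allotment. It needs careful handling in practice, because it is easy to trigger an error.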
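The restrictions above can be seen in a short sketch. It assumes a platform where shared memory is available; the outcomes dictionary is only there for illustration and is not part of the library.

```python
from multiprocessing import shared_memory

# A list holding several of the supported basic types.
sl = shared_memory.ShareableList(["abc", 1, 2.5, None])
outcomes = {}
try:
    # Replacing a str with another str of the same size works.
    sl[0] = "xyz"
    outcomes["same_size"] = sl[0]

    # A much longer str does not fit in the storage reserved for "abc".
    try:
        sl[0] = "x" * 100
        outcomes["too_long"] = "accepted"
    except ValueError:
        outcomes["too_long"] = "ValueError"

    # The length is fixed: there is no append() to grow the list.
    outcomes["has_append"] = hasattr(sl, "append")
finally:
    # Release the shared memory block explicitly.
    sl.shm.close()
    sl.shm.unlink()

print(outcomes)
```

Closing and unlinking in a finally block keeps the segment from leaking if one of the assignments fails unexpectedly.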

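For example, the following code raises an error. Both values are of type int, but ShareableList stores an int as an 8-byte signed integer, so the second value does not fit in the storage reserved for it: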

In [17]: from multiprocessing import shared_memory                                       

In [18]: a = shared_memory.ShareableList([1])                                            

In [19]: a[0] = 1000000000                                                               

In [20]: a                                                                               
Out[20]: ShareableList([1000000000], name='psm_5b31ac45')

In [21]: a[0] = 10000000000000000000000000000000000                           
---------------------------------------------------------------------------
error                                     Traceback (most recent call last)
<ipython-input-21-648c79360b7b> in <module>
----> 1 a[0] = 10000000000000000000000000000000000

/usr/local/Cellar/python@3.8/3.8.2/Frameworks/Python.framework/Versions/3.8/lib/python3.8/multiprocessing/shared_memory.py in __setitem__(self, position, value)
    450         )
    451         value = value.encode(_encoding) if isinstance(value, str) else value
--> 452         struct.pack_into(new_format, self.shm.buf, offset, value)
    453 
    454     def __reduce__(self):

error: argument out of range
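The traceback ends in struct.pack_into, which suggests the limit comes from the struct format rather than from ShareableList's own checks. A minimal sketch of that limit, assuming ints are packed with the 8-byte signed format "q" (the helper name fits_in_int64 is made up for illustration):

```python
import struct

INT64_MAX = 2**63 - 1  # largest value an 8-byte signed integer can hold


def fits_in_int64(value):
    """Return True if value can be packed as an 8-byte signed integer."""
    try:
        struct.pack("q", value)
    except struct.error:
        return False
    return True


print(fits_in_int64(1000000000))  # the value that was accepted above
print(fits_in_int64(INT64_MAX))   # the largest value that still fits
print(fits_in_int64(10**34))      # the value that triggered the error
```

A check like this could be run before assigning to a ShareableList element when the size of the integers is not known in advance.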

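To keep things simple and flexible, the code below represents int and float values as str. Because the storage for each element is fixed, the strings have to be padded to a fixed length by hand.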

"""
python3.8+
"""


import os
import math
from multiprocessing import Process
from multiprocessing.managers import SharedMemoryManager


def slicing(mink, maxk):
    return sum(1.0 / ((2 * k + 1) ** 2) for k in range(mink, maxk))


def slicing_wrapper(shared_list, i):
    item = shared_list[i]
    length = len(item)
    a, b, _ = item.split(',')
    mink, maxk = int(a), int(b)
    result = slicing(mink, maxk)
    s = str(result)
    shared_list[i] = s + ',' + ' ' * (length - len(s) - 1)


def pi(n, process_num):
    unit = n // process_num
    payload = []
    for i in range(process_num):
        mink = unit * i
        maxk = mink + unit
        payload.append((mink, maxk))

    sized = 100
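    # ShareableList requires every element to be 'fixed-size'
    # (its storage cannot grow at runtime), so pad each string here.
    # Without padding, writing back a longer result raises:
    # ValueError: bytes/str item exceeds available storage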
    payload_str = []
    for mink, maxk in payload:
        t = f'{mink},{maxk},'
        s = ' ' * (sized - len(t))
        payload_str.append(t + s)

    with SharedMemoryManager() as manager:
        nums = manager.ShareableList(payload_str)
        process = []

        for i in range(len(nums)):
            p = Process(target=slicing_wrapper, args=(nums, i))
            process.append(p)
            p.start()

        for p in process:
            p.join()

        # Read the results before the manager shuts down and releases the shared memory
        pai = math.sqrt(sum(float(item.split(',')[0]) for item in nums) * 8)
    return pai


if __name__ == '__main__':
    pai = pi(10000000, os.cpu_count())
    print(pai)
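The program relies on the identity 1/1² + 1/3² + 1/5² + ... = π²/8, which is why pi() multiplies the combined sum by 8 and takes the square root. A single-process sketch can sanity-check the formula independently of the shared-memory plumbing; the names partial_sum and approx_pi are hypothetical and not part of the code above.

```python
import math


def partial_sum(n):
    """Sum 1 / (2k + 1)^2 for k = 0 .. n - 1 (same series as slicing())."""
    return sum(1.0 / ((2 * k + 1) ** 2) for k in range(n))


def approx_pi(n):
    """Approximate pi from the first n terms, using sum = pi^2 / 8."""
    return math.sqrt(8 * partial_sum(n))


approx = approx_pi(100000)
print(approx, abs(approx - math.pi))
```

The series converges slowly (the error shrinks roughly in proportion to 1/n), which is part of why splitting the range across processes is worthwhile for large n.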