了解完缓冲区协议之后,我们就来手动实现一下。而实现方式可以使用原生的 Python/C API,也可以使用 Cython,但前者实现起来会非常麻烦,牵扯的知识也非常多,而 Cython 则简化了这一点。
我们来分别介绍一下这两种方式。
使用 Cython 实现缓冲区协议
Cython 对缓冲区协议也有着强大的支持,我们只需要定义一个魔法方法即可实现缓冲区协议。
from cpython cimport Py_buffer
from cpython.mem cimport PyMem_Malloc, PyMem_Free
cdef class Matrix:
cdef Py_ssize_t shape[2] # 数组的形状
cdef Py_ssize_t strides[2] # 数组的 stride
cdef float *array
def __cinit__(self, row, col):
self.shape[0] = <Py_ssize_t> row
self.shape[1] = <Py_ssize_t> col
self.strides[1] = sizeof(float)
self.strides[0] = self.strides[1] * self.shape[1]
self.array = <float *> PyMem_Malloc(
self.shape[0] * self.shape[1] * sizeof(float))
def set_item_by_index(self, int index, float value):
"""留一个接口,用来设置元素"""
if index >= self.shape[0] * self.shape[1] or index < 0:
raise ValueError("索引无效")
self.array[index] = value
def __getbuffer__(self, Py_buffer *buffer, int flags):
"""自定义缓冲区需要实现 __getbuffer__ 方法"""
cdef int i;
for i in range(self.shape[0] * self.shape[1]):
self.array[i] = float(i)
# 缓冲区,这里是就是 array 本身,但是需要转成 void *
buffer.buf = <void *> self.array
# 实现缓冲区协议的对象,显然是 selfself.shape[0] * self.shape[1]
buffer.obj = self
# 缓冲区的总大小
buffer.len = self.shape[0] * self.shape[1] * sizeof(float)
# 读写权限,这里让缓冲区可读写
buffer.readonly = 0
# 缓冲区每个元素的大小
buffer.itemsize = sizeof(float)
# 元素类型,"f" 表示 float
buffer.format = "f"
# 该对象的维度
buffer.ndim = 2
# shape
buffer.shape = self.shape
# strides
buffer.strides = self.strides
# 直接设置为 NULL 即可
buffer.suboffsets = NULL
def dealloc(self):
if self.array != NULL:
PyMem_Free(<void *> self.array)
在 Cython 中我们只需要实现一个相应的魔法方法即可,真的是非常方便,当然我们为了验证是否共享内存,专门定义了一个方法。
import pyximport
pyximport.install(language_level=3)
import cython_test
import numpy as np
m = cython_test.Matrix(5, 4)
# 基于 m 创建 Numpy 数组
np_m = np.asarray(m)
# m 和 np_m 是共享内存的
print(m)
"""
<cython_test.Matrix object at 0x7f96ba55a3f0>
"""
print(np_m)
"""
[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]
[ 8. 9. 10. 11.]
[12. 13. 14. 15.]
[16. 17. 18. 19.]]
"""
# 通过 m 修改元素,然后打印 np_m
m.set_item_by_index(13, 666.666)
print(np_m)
"""
[[ 0. 1. 2. 3. ]
[ 4. 5. 6. 7. ]
[ 8. 9. 10. 11. ]
[ 12. 666.666 14. 15. ]
[ 16. 17. 18. 19. ]]
"""
结果没有任何问题,以上就是 Cython 实现缓冲区协议,其实在日常工作中我们不需要直接面对它,但了解一下总是好的。
使用 Python/C API 实现缓冲区协议
注:通过原生的 Python/C API 实现缓冲区协议,这个过程非常麻烦,因为这需要你熟悉解释器源代码,以及这些 API 本身。我们的重点是 Cython,只要知道 Cython 如何实现缓冲区协议即可。至于原生的 Python/C API,感兴趣的话可以看一看,不感兴趣的话跳过即可。\
下面编写 C 源文件,文件名为 py_array.c。
#include <stdio.h>
#include <stdlib.h>
#include <Python.h>
// 定义一个一维的数组
typedef struct {
int *arr;
int length;
} Array;
// 初始化函数
void initial_Array(Array *array, int length) {
array->length = length;
if (length == 0) {
array->arr = NULL;
} else {
array->arr = (int *) malloc(sizeof(int) * length);
for (int i = 0; i < length; i++) {
array->arr[i] = i;
}
}
}
// 释放内存
void dealloc_Array(Array *array) {
if (array->arr != NULL) free(array->arr);
array->arr = NULL;
}
// Python 的对象在 C 中都嵌套了 PyObject
typedef struct {
PyObject_HEAD
Array array;
} PyArray;
// 初始化 __init__ 函数
static int
PyArray_init(PyArray *self, PyObject *args, PyObject *kwargs) {
if (self->array.arr != NULL) {
dealloc_Array(&self->array);
}
int length = 0;
static char *kwlist[] = {"length", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|i", kwlist, &length)) {
return -1;
}
// 因为给 Python 调用,所以这里额外对 length 进行一下检测
if (length < 0) {
PyErr_SetString(PyExc_ValueError, "argument 'length' can not be negative");
return -1;
}
initial_Array(&self->array, length);
return 0;
}
// 析构函数
static void
PyArray_dealloc(PyArray *self) {
dealloc_Array(&self->array);
Py_TYPE(self)->tp_free((PyObject *) self);
}
static PyObject *
PyArray_repr(PyArray *self) {
//转成列表打印
PyObject *list = PyList_New(self->array.length);
Py_ssize_t i;
for (i=0; i<self->array.length; i++){
PyList_SetItem(list, i, PyLong_FromLong(*(self->array.arr + i)));
}
PyObject *ret = PyObject_Str(list);Py_DECREF(list);
return ret;
}
// 实现缓冲区协议
static int
PyArray_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
if (view == NULL) {
PyErr_SetString(PyExc_ValueError,
"NULL view in getbuffer");
return -1;
}
PyArray* self = (PyArray *)obj;
view->obj = (PyObject*)self;
view->buf = (void*)self->array.arr;
view->len = self->array.length * sizeof(int);
view->readonly = 0;
view->itemsize = sizeof(int);
view->format = "i";
view->ndim = 1;
view->shape = (Py_ssize_t *) &self->array.length;
view->strides = &view->itemsize;
view->suboffsets = NULL;
view->internal = NULL;
Py_INCREF(self);
return 0;
}
// 将上面的函数放入到 PyBufferProcs 结构体中
static PyBufferProcs PyArray_as_buffer = {
(getbufferproc)PyArray_getbuffer,
(releasebufferproc)0
};
static PyTypeObject PyArrayType = {
PyVarObject_HEAD_INIT(NULL, 0)
"py_my_array.PyArray",
sizeof(PyArray),
0,
(destructor) PyArray_dealloc,
0,
0, /* tp_getattr */
0, /* tp_setattr */
0, /* tp_reserved */
(reprfunc)PyArray_repr, /* tp_repr */
0, /* tp_as_number */
0, /* tp_as_sequence */
0, /* tp_as_mapping */
0, /* tp_hash */
0, /* tp_call */
0, /* tp_str */
0, /* tp_getattro */
0, /* tp_setattro */
// 指定 tp_as_buffer
&PyArray_as_buffer, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
"PyArray object", /* tp_doc */
0, /* tp_traverse */
0, /* tp_clear */
0, /* tp_richcompare */
0, /* tp_weaklistoffset */
0, /* tp_iter */
0, /* tp_iternext */
0, /* tp_methods */
0, /* tp_members */
0, /* tp_getset */
0, /* tp_base */
0, /* tp_dict */
0, /* tp_descr_get */
0, /* tp_descr_set */
0, /* tp_dictoffset */
(initproc) PyArray_init, /* tp_init */
};
static PyModuleDef py_array_module = {
PyModuleDef_HEAD_INIT,
"py_array",
"this is a module named py_array",
-1,
0,
NULL,
NULL,
NULL,
NULL
};
PyMODINIT_FUNC
PyInit_py_array(void) {
PyObject *m;
PyArrayType.tp_new = PyType_GenericNew;
if (PyType_Ready(&PyArrayType) < 0) return NULL;
m = PyModule_Create(&py_array_module);
if (m == NULL) return NULL;
Py_XINCREF(&PyArrayType);
PyModule_AddObject(m, "PyArray",
(PyObject *) &PyArrayType);
return m;
}
现在相信你一定能体会到为什么要有 Cython 存在,因为写原生的 Python/C API 太痛苦了,而且为了简便我们这里使用的是一维数组,但即便如此,也已经很麻烦了。
我们编译成扩展,首先编写 setup.py:
from distutils.core import setup, Extension
from Cython.Build import cythonize
ext = [Extension("py_array",
sources=["py_array.c"])]
setup(ext_modules=cythonize(ext, language_level=3))
执行 python setup.py build 生成扩展模块,然后我们来导入它。
import numpy as np
import py_array
print(py_array)
"""
<module 'py_array' from ..\py_array.cp38-win_amd64.pyd'>
"""
arr = py_array.PyArray(5)
print(arr)
"""
[0, 1, 2, 3, 4]
"""
np_arr = np.asarray(arr)
print(np_arr)
"""
[0 1 2 3 4]
"""
# 两者也是共享内存
np_arr[0] = 123
print(arr)
print(np_arr)
"""
[123, 1, 2, 3, 4]
[123 1 2 3 4]
"""
显然此时万事大吉了,因为实现了缓冲区协议,Numpy 知道了缓冲区数据,因此会在此基础之上建一个 view,并且 array 和 np_arr 是共享内存的。
因此核心就在于对缓冲区协议的理解,它本质上就是一个结构体,内部的成员描述了缓冲区数据的所有信息。而我们只需要定义一个函数,然后根据数组初始化这些信息即可,最后构建 PyBufferProcs 实例作为 tp_as_buffer 成员的值。
以上就是本次分享的所有内容,如果你觉得文章还不错,欢迎关注公众号:Python编程学习圈,每日干货分享,内容覆盖Python电子书、教程、数据库编程、Django,爬虫,云计算等等。或是前往编程学习网,了解更多编程技术知识。