NumPy 源码解析(六十六)
.\numpy\numpy\_core\src\multiarray\dtype_traversal.h
/* NumPy DType clear (object DECREF + NULLing) implementations */
// 获取清除对象类型的循环函数,用于逐个元素清除对象类型数据
NPY_NO_EXPORT int
npy_get_clear_object_strided_loop(
void *traverse_context, const PyArray_Descr *descr, int aligned,
npy_intp fixed_stride,
PyArrayMethod_TraverseLoop **out_loop, NpyAuxData **out_traversedata,
NPY_ARRAYMETHOD_FLAGS *flags);
// 获取清除 void 和遗留用户数据类型的循环函数,用于逐个元素清除 void 类型和遗留用户数据类型数据
NPY_NO_EXPORT int
npy_get_clear_void_and_legacy_user_dtype_loop(
void *traverse_context, const _PyArray_LegacyDescr *descr, int aligned,
npy_intp fixed_stride,
PyArrayMethod_TraverseLoop **out_loop, NpyAuxData **out_traversedata,
NPY_ARRAYMETHOD_FLAGS *flags);
/* NumPy DType zero-filling implementations */
// 获取填充对象类型数据为零的循环函数,但实际未使用
NPY_NO_EXPORT int
npy_object_get_fill_zero_loop(
void *NPY_UNUSED(traverse_context), const PyArray_Descr *NPY_UNUSED(descr),
int NPY_UNUSED(aligned), npy_intp NPY_UNUSED(fixed_stride),
PyArrayMethod_TraverseLoop **out_loop, NpyAuxData **NPY_UNUSED(out_auxdata),
NPY_ARRAYMETHOD_FLAGS *flags);
// 获取填充 void 和遗留用户数据类型数据为零的循环函数
NPY_NO_EXPORT int
npy_get_zerofill_void_and_legacy_user_dtype_loop(
void *traverse_context, const _PyArray_LegacyDescr *dtype, int aligned,
npy_intp stride, PyArrayMethod_TraverseLoop **out_func,
NpyAuxData **out_auxdata, NPY_ARRAYMETHOD_FLAGS *flags);
/* Helper to deal with calling or nesting simple strided loops */
// 辅助结构体,用于处理简单步进循环的调用或嵌套
typedef struct {
PyArrayMethod_TraverseLoop *func; // 循环函数指针
NpyAuxData *auxdata; // 辅助数据指针
const PyArray_Descr *descr; // 数据类型描述符指针
} NPY_traverse_info;
// 初始化 NPY_traverse_info 结构体
static inline void
NPY_traverse_info_init(NPY_traverse_info *cast_info)
{
cast_info->func = NULL; // 将循环函数指针置为 NULL,表示未初始化
cast_info->auxdata = NULL; // 允许保持辅助数据指针为 NULL
cast_info->descr = NULL; // 将数据类型描述符指针置为 NULL,表示未初始化
}
// 释放 NPY_traverse_info 结构体的资源
static inline void
NPY_traverse_info_xfree(NPY_traverse_info *traverse_info)
{
if (traverse_info->func == NULL) { // 如果循环函数指针为 NULL,直接返回
return;
}
traverse_info->func = NULL; // 将循环函数指针置为 NULL
NPY_AUXDATA_FREE(traverse_info->auxdata); // 释放辅助数据
Py_XDECREF(traverse_info->descr); // 释放数据类型描述符
}
// 复制 NPY_traverse_info 结构体内容
static inline int
NPY_traverse_info_copy(
NPY_traverse_info *traverse_info, NPY_traverse_info *original)
{
/* Note that original may be identical to traverse_info! */
if (original->func == NULL) {
/* Allow copying also of unused clear info */
traverse_info->func = NULL; // 允许复制未使用的清除信息
return 0;
}
if (original->auxdata != NULL) {
traverse_info->auxdata = NPY_AUXDATA_CLONE(original->auxdata); // 复制辅助数据
if (traverse_info->auxdata == NULL) {
traverse_info->func = NULL; // 复制失败时将循环函数指针置为 NULL
return -1;
}
}
else {
traverse_info->auxdata = NULL; // 原辅助数据为 NULL,则置为 NULL
}
Py_INCREF(original->descr); // 增加数据类型描述符的引用计数
traverse_info->descr = original->descr; // 复制数据类型描述符指针
traverse_info->func = original->func; // 复制循环函数指针
return 0;
}
NPY_NO_EXPORT int
PyArray_GetClearFunction(
int aligned, npy_intp stride, PyArray_Descr *dtype,
NPY_traverse_info *clear_info, NPY_ARRAYMETHOD_FLAGS *flags);
.\numpy\numpy\_core\src\multiarray\einsum_debug.h
/*
* This file provides debug macros used by the other einsum files.
*
* Copyright (c) 2011 by Mark Wiebe (mwwiebe@gmail.com)
* The University of British Columbia
*
* See LICENSE.txt for the license.
*/
/********** PRINTF DEBUG TRACING **************/
// 定义调试输出级别,0 表示关闭调试输出
// 如果开启了调试输出
// 包含标准输出头文件
// 定义输出宏,打印字符串
// 定义输出宏,打印带有一个参数的格式化字符串
// 定义输出宏,打印带有两个参数的格式化字符串
// 定义输出宏,打印带有三个参数的格式化字符串
// 如果未开启调试输出,则定义这些宏为空
.\numpy\numpy\_core\src\multiarray\einsum_sumprod.h
// 定义函数指针类型 sum_of_products_fn,用于表示一个函数指针,该函数接受四个参数:
// - int,参数个数
// - char **,参数数组
// - npy_intp const*,整数数组,常量指针
// - npy_intp,整数
typedef void (*sum_of_products_fn)(int, char **, npy_intp const*, npy_intp);
// 声明一个隐藏(visibility hidden)的函数 get_sum_of_products_function,返回类型为 sum_of_products_fn,
// 接受如下参数:
// - int,操作数个数
// - int,类型编号
// - npy_intp,项大小
// - npy_intp const*,固定步长的整数数组,常量指针
NPY_VISIBILITY_HIDDEN sum_of_products_fn
get_sum_of_products_function(int nop, int type_num,
npy_intp itemsize, npy_intp const *fixed_strides);
.\numpy\numpy\_core\src\multiarray\flagsobject.c
/*
* Array Flags Object
* 定义了一些与数组标志相关的宏和函数
*/
/*
* 清除 PY_SSIZE_T 未定义的宏
*/
/*
* 引入必要的头文件和库文件
*/
/*
* 静态函数声明:更新连续性标志
*/
static void
_UpdateContiguousFlags(PyArrayObject *ap);
/*
* 获取新的 ArrayFlagsObject 对象
*/
NPY_NO_EXPORT PyObject *
PyArray_NewFlagsObject(PyObject *obj)
{
PyObject *flagobj;
int flags;
// 如果传入的对象是空指针,则设置默认的标志位
if (obj == NULL) {
flags = NPY_ARRAY_C_CONTIGUOUS |
NPY_ARRAY_OWNDATA |
NPY_ARRAY_F_CONTIGUOUS |
NPY_ARRAY_ALIGNED;
}
else {
// 检查传入的对象是否为 NumPy 数组
if (!PyArray_Check(obj)) {
PyErr_SetString(PyExc_ValueError,
"Need a NumPy array to create a flags object");
return NULL;
}
// 获取传入数组的标志位
flags = PyArray_FLAGS((PyArrayObject *)obj);
}
// 分配并初始化新的 ArrayFlagsObject 对象
flagobj = PyArrayFlags_Type.tp_alloc(&PyArrayFlags_Type, 0);
if (flagobj == NULL) {
return NULL;
}
// 增加传入对象的引用计数,并将其赋值给 ArrayFlagsObject 对象的 arr 成员
Py_XINCREF(obj);
((PyArrayFlagsObject *)flagobj)->arr = obj;
// 将计算得到的标志位赋值给 ArrayFlagsObject 对象的 flags 成员
((PyArrayFlagsObject *)flagobj)->flags = flags;
return flagobj;
}
/*NUMPY_API
* 同时更新多个标志位
*/
NPY_NO_EXPORT void
PyArray_UpdateFlags(PyArrayObject *ret, int flagmask)
{
// 总是同时更新连续性标志位,因为从一个标志位推断另一个并不容易
if (flagmask & (NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_C_CONTIGUOUS)) {
_UpdateContiguousFlags(ret);
}
// 更新对齐标志位
if (flagmask & NPY_ARRAY_ALIGNED) {
if (IsAligned(ret)) {
PyArray_ENABLEFLAGS(ret, NPY_ARRAY_ALIGNED);
}
else {
PyArray_CLEARFLAGS(ret, NPY_ARRAY_ALIGNED);
}
}
/*
* 默认情况下,WRITEABLE 不在 UPDATE_ALL 中,所以需要额外检查
* 更新可写标志位
*/
if (flagmask & NPY_ARRAY_WRITEABLE) {
if (_IsWriteable(ret)) {
PyArray_ENABLEFLAGS(ret, NPY_ARRAY_WRITEABLE);
}
else {
PyArray_CLEARFLAGS(ret, NPY_ARRAY_WRITEABLE);
}
}
return;
}
/*
* 更新数组对象的连续性标志位。
* 根据数组维度和步长,确定数组是否以 C 或 F 连续存储。
*/
static void
_UpdateContiguousFlags(PyArrayObject *ap)
{
npy_intp sd; // 步长的值
npy_intp dim; // 数组的当前维度长度
int i; // 迭代器
npy_bool is_c_contig = 1; // 是否是 C 连续的标志位,默认为真
sd = PyArray_ITEMSIZE(ap); // 获取数组元素的大小
for (i = PyArray_NDIM(ap) - 1; i >= 0; --i) {
dim = PyArray_DIMS(ap)[i]; // 获取当前维度的长度
/* contiguous by definition */
if (dim == 0) { // 如果当前维度长度为0,数组被定义为连续的
PyArray_ENABLEFLAGS(ap, NPY_ARRAY_C_CONTIGUOUS); // 启用 C 连续标志
PyArray_ENABLEFLAGS(ap, NPY_ARRAY_F_CONTIGUOUS); // 启用 F 连续标志
return; // 返回
}
if (dim != 1) { // 如果当前维度长度不为1
if (PyArray_STRIDES(ap)[i] != sd) { // 检查当前维度的步长是否与期望的步长相同
is_c_contig = 0; // 如果步长不符合要求,则不是 C 连续
}
sd *= dim; // 更新步长值
}
}
if (is_c_contig) {
PyArray_ENABLEFLAGS(ap, NPY_ARRAY_C_CONTIGUOUS); // 启用 C 连续标志
}
else {
PyArray_CLEARFLAGS(ap, NPY_ARRAY_C_CONTIGUOUS); // 清除 C 连续标志
}
/* 检查是否是 Fortran 连续 */
sd = PyArray_ITEMSIZE(ap); // 重新获取数组元素的大小
for (i = 0; i < PyArray_NDIM(ap); ++i) {
dim = PyArray_DIMS(ap)[i]; // 获取当前维度的长度
if (dim != 1) { // 如果当前维度长度不为1
if (PyArray_STRIDES(ap)[i] != sd) { // 检查当前维度的步长是否与期望的步长相同
PyArray_CLEARFLAGS(ap, NPY_ARRAY_F_CONTIGUOUS); // 清除 F 连续标志
return; // 返回
}
sd *= dim; // 更新步长值
}
}
PyArray_ENABLEFLAGS(ap, NPY_ARRAY_F_CONTIGUOUS); // 启用 F 连续标志
return; // 返回
}
/*
* 释放数组标志对象的资源
*/
static void
arrayflags_dealloc(PyArrayFlagsObject *self)
{
Py_XDECREF(self->arr); // 释放数组对象的引用
Py_TYPE(self)->tp_free((PyObject *)self); // 释放对象内存
}
/*
* 定义获取标志位的宏函数,用于不同标志位的获取操作
*/
static PyObject * \
arrayflags_
PyArrayFlagsObject *self, void *NPY_UNUSED(ignored)) \
{ \
return PyBool_FromLong((self->flags & (UPPER)) == (UPPER)); \
}
/*
* 定义获取标志位并生成警告信息的宏函数
*/
static char *msg = "future versions will not create a writeable "
"array from broadcast_array. Set the writable flag explicitly to "
"avoid this warning.";
static PyObject * \
arrayflags_
PyArrayFlagsObject *self, void *NPY_UNUSED(ignored)) \
{ \
// 检查标志位中是否包含 NPY_ARRAY_WARN_ON_WRITE 标志
if (self->flags & NPY_ARRAY_WARN_ON_WRITE) { \
// 如果包含,则发出 FutureWarning 警告并检查是否出错
if (PyErr_Warn(PyExc_FutureWarning, msg) < 0) {\
// 如果发生错误,返回 NULL
return NULL; \
} \
}\
// 返回一个 PyBool 对象,表示 self->flags 中的 UPPER 标志是否全部设置
return PyBool_FromLong((self->flags & (UPPER)) == (UPPER)); \
}
/* 定义宏以获取相应标志 */
_define_get(NPY_ARRAY_C_CONTIGUOUS, contiguous)
_define_get(NPY_ARRAY_F_CONTIGUOUS, fortran)
_define_get(NPY_ARRAY_WRITEBACKIFCOPY, writebackifcopy)
_define_get(NPY_ARRAY_OWNDATA, owndata)
_define_get(NPY_ARRAY_ALIGNED, aligned)
_define_get(NPY_ARRAY_WRITEABLE, writeable_no_warn)
_define_get_warn(NPY_ARRAY_WRITEABLE, writeable)
_define_get_warn(NPY_ARRAY_ALIGNED|
NPY_ARRAY_WRITEABLE, behaved)
_define_get_warn(NPY_ARRAY_ALIGNED|
NPY_ARRAY_WRITEABLE|
NPY_ARRAY_C_CONTIGUOUS, carray)
/* 定义静态函数:获取数组标志中的 C 连续和 F 连续 */
static PyObject *
arrayflags_forc_get(PyArrayFlagsObject *self, void *NPY_UNUSED(ignored))
{
PyObject *item;
// 如果数组是 F 连续或者 C 连续,返回 Py_True,否则返回 Py_False
if (((self->flags & NPY_ARRAY_F_CONTIGUOUS) == NPY_ARRAY_F_CONTIGUOUS) ||
((self->flags & NPY_ARRAY_C_CONTIGUOUS) == NPY_ARRAY_C_CONTIGUOUS)) {
item = Py_True;
}
else {
item = Py_False;
}
Py_INCREF(item);
return item;
}
/* 定义静态函数:获取数组标志中的 F 连续但非 C 连续 */
static PyObject *
arrayflags_fnc_get(PyArrayFlagsObject *self, void *NPY_UNUSED(ignored))
{
PyObject *item;
// 如果数组是 F 连续且不是 C 连续,返回 Py_True,否则返回 Py_False
if (((self->flags & NPY_ARRAY_F_CONTIGUOUS) == NPY_ARRAY_F_CONTIGUOUS) &&
!((self->flags & NPY_ARRAY_C_CONTIGUOUS) == NPY_ARRAY_C_CONTIGUOUS)) {
item = Py_True;
}
else {
item = Py_False;
}
Py_INCREF(item);
return item;
}
/* 定义静态函数:获取数组标志中的 A (对齐)、W (可写)、F (F 连续) 的组合 */
static PyObject *
arrayflags_farray_get(PyArrayFlagsObject *self, void *NPY_UNUSED(ignored))
{
PyObject *item;
// 如果数组同时满足对齐、可写、F 连续的条件,且不是 C 连续,返回 Py_True,否则返回 Py_False
if (((self->flags & (NPY_ARRAY_ALIGNED|
NPY_ARRAY_WRITEABLE|
NPY_ARRAY_F_CONTIGUOUS)) != 0) &&
!((self->flags & NPY_ARRAY_C_CONTIGUOUS) != 0)) {
item = Py_True;
}
else {
item = Py_False;
}
Py_INCREF(item);
return item;
}
/* 定义静态函数:获取数组标志的整数表示 */
static PyObject *
arrayflags_num_get(PyArrayFlagsObject *self, void *NPY_UNUSED(ignored))
{
// 返回数组标志的整数表示
return PyLong_FromLong(self->flags);
}
/* 定义静态函数:设置 writebackifcopy 标志 */
/* 假定 setflags 的顺序是 write、align、uic */
static int
arrayflags_writebackifcopy_set(
PyArrayFlagsObject *self, PyObject *obj, void *NPY_UNUSED(ignored))
{
PyObject *res;
// 如果传入的 obj 是 NULL,则不能删除 writebackifcopy 属性,返回错误
if (obj == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete flags writebackifcopy attribute");
return -1;
}
// 如果 self->arr 是 NULL,则不能在数组标量上设置标志,返回错误
if (self->arr == NULL) {
PyErr_SetString(PyExc_ValueError,
"Cannot set flags on array scalars.");
return -1;
}
// 判断传入的 obj 是否为真值
int istrue = PyObject_IsTrue(obj);
if (istrue == -1) {
return -1;
}
// 调用 self->arr 的 setflags 方法,传递三个参数:Py_None、Py_None 和 istrue 的真假值
res = PyObject_CallMethod(self->arr, "setflags", "OOO", Py_None, Py_None,
(istrue ? Py_True : Py_False));
if (res == NULL) {
return -1;
}
Py_DECREF(res);
return 0;
}
/* 定义静态函数:设置 aligned 标志 */
static int
arrayflags_aligned_set(
PyArrayFlagsObject *self, PyObject *obj, void *NPY_UNUSED(ignored))
{
PyObject *res;
// 如果传入的 obj 是 NULL,则不能删除 aligned 属性,返回错误
if (obj == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete flags aligned attribute");
return -1;
}
// 检查数组指针是否为 NULL,如果是,则抛出异常并返回错误码
if (self->arr == NULL) {
PyErr_SetString(PyExc_ValueError,
"Cannot set flags on array scalars.");
return -1;
}
// 检查对象是否为真值,并返回对应的整数结果,如果出错则返回错误码
int istrue = PyObject_IsTrue(obj);
if (istrue == -1) {
return -1;
}
// 调用数组对象的 setflags 方法,设置其标志位
// 参数依次为 Py_None(空对象)、(istrue 为真时为 Py_True 否则为 Py_False)、Py_None
res = PyObject_CallMethod(self->arr, "setflags", "OOO", Py_None,
(istrue ? Py_True : Py_False),
Py_None);
// 如果调用出错(返回结果为 NULL),则返回错误码
if (res == NULL) {
return -1;
}
// 释放调用结果对象的引用计数,避免内存泄漏
Py_DECREF(res);
// 返回成功状态码
return 0;
static int
arrayflags_writeable_set(
PyArrayFlagsObject *self, PyObject *obj, void *NPY_UNUSED(ignored))
{
PyObject *res;
// 如果传入的对象是空,则设置错误并返回
if (obj == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete flags writeable attribute");
return -1;
}
// 如果数组对象为 NULL,则设置错误并返回
if (self->arr == NULL) {
PyErr_SetString(PyExc_ValueError,
"Cannot set flags on array scalars.");
return -1;
}
// 检查传入的对象是否为真值
int istrue = PyObject_IsTrue(obj);
if (istrue == -1) {
return -1;
}
// 调用数组对象的 setflags 方法来设置可写标志
res = PyObject_CallMethod(self->arr, "setflags", "OOO",
(istrue ? Py_True : Py_False),
Py_None, Py_None);
if (res == NULL) {
return -1;
}
// 减少对结果的引用计数,避免内存泄漏
Py_DECREF(res);
// 返回成功
return 0;
}
static int
arrayflags_warn_on_write_set(
PyArrayFlagsObject *self, PyObject *obj, void *NPY_UNUSED(ignored))
{
/*
* This code should go away in a future release, so do not mangle the
* array_setflags function with an extra kwarg
*/
int ret;
// 如果传入的对象是空,则设置错误并返回
if (obj == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete flags _warn_on_write attribute");
return -1;
}
// 检查传入的对象是否为真值
ret = PyObject_IsTrue(obj);
if (ret > 0) {
// 如果数组对象不可写,则设置错误并返回
if (!(PyArray_FLAGS((PyArrayObject*)self->arr) & NPY_ARRAY_WRITEABLE)) {
PyErr_SetString(PyExc_ValueError,
"cannot set '_warn_on_write' flag when 'writable' is "
"False");
return -1;
}
// 启用数组对象的 WARN_ON_WRITE 标志
PyArray_ENABLEFLAGS((PyArrayObject*)self->arr, NPY_ARRAY_WARN_ON_WRITE);
}
else if (ret < 0) {
return -1;
}
else {
// 如果传入的对象不是真值,则设置错误并返回
PyErr_SetString(PyExc_ValueError,
"cannot clear '_warn_on_write', set "
"writeable True to clear this private flag");
return -1;
}
// 返回成功
return 0;
}
static PyGetSetDef arrayflags_getsets[] = {
{"contiguous",
(getter)arrayflags_contiguous_get,
NULL,
NULL, NULL},
{"c_contiguous",
(getter)arrayflags_contiguous_get,
NULL,
NULL, NULL},
{"f_contiguous",
(getter)arrayflags_fortran_get,
NULL,
NULL, NULL},
{"fortran",
(getter)arrayflags_fortran_get,
NULL,
NULL, NULL},
{"writebackifcopy",
(getter)arrayflags_writebackifcopy_get,
(setter)arrayflags_writebackifcopy_set,
NULL, NULL},
{"owndata",
(getter)arrayflags_owndata_get,
NULL,
NULL, NULL},
{"aligned",
(getter)arrayflags_aligned_get,
(setter)arrayflags_aligned_set,
NULL, NULL},
{"writeable",
(getter)arrayflags_writeable_get,
(setter)arrayflags_writeable_set,
NULL, NULL},
{"_writeable_no_warn",
(getter)arrayflags_writeable_no_warn_get,
(setter)NULL,
NULL, NULL},
{"_warn_on_write",
(getter)NULL,
(setter)arrayflags_warn_on_write_set,
NULL, NULL},
设置一个名为 `_warn_on_write` 的属性,它具有以下特性:
- getter 为 NULL,表示没有定义 getter 函数。
- setter 函数为 `arrayflags_warn_on_write_set`,用于设置 `_warn_on_write` 属性的值。
- 没有额外的文档字符串或数据。
{"fnc",
(getter)arrayflags_fnc_get,
NULL,
NULL, NULL},
设置一个名为 `fnc` 的属性,具有以下特性:
- getter 函数为 `arrayflags_fnc_get`,用于获取 `fnc` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{"forc",
(getter)arrayflags_forc_get,
NULL,
NULL, NULL},
设置一个名为 `forc` 的属性,具有以下特性:
- getter 函数为 `arrayflags_forc_get`,用于获取 `forc` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{"behaved",
(getter)arrayflags_behaved_get,
NULL,
NULL, NULL},
设置一个名为 `behaved` 的属性,具有以下特性:
- getter 函数为 `arrayflags_behaved_get`,用于获取 `behaved` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{"carray",
(getter)arrayflags_carray_get,
NULL,
NULL, NULL},
设置一个名为 `carray` 的属性,具有以下特性:
- getter 函数为 `arrayflags_carray_get`,用于获取 `carray` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{"farray",
(getter)arrayflags_farray_get,
NULL,
NULL, NULL},
设置一个名为 `farray` 的属性,具有以下特性:
- getter 函数为 `arrayflags_farray_get`,用于获取 `farray` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{"num",
(getter)arrayflags_num_get,
NULL,
NULL, NULL},
设置一个名为 `num` 的属性,具有以下特性:
- getter 函数为 `arrayflags_num_get`,用于获取 `num` 属性的值。
- 没有 setter 函数。
- 没有额外的文档字符串或数据。
{NULL, NULL, NULL, NULL, NULL},
属性列表结束标记,用于指示没有更多的属性。
};
// 定义 arrayflags_getitem 函数,接收一个 PyArrayFlagsObject 类型的 self 参数和一个 ind 参数
static PyObject *
arrayflags_getitem(PyArrayFlagsObject *self, PyObject *ind)
{
// 声明一个指向字符的指针 key,初始化为 NULL
char *key = NULL;
// 声明一个字符数组 buf,用于临时存储字符串
char buf[16];
// 声明一个整数变量 n,用于存储字符串的长度
int n;
// 如果 ind 是 Unicode 字符串
if (PyUnicode_Check(ind)) {
// 声明一个 PyObject 指针 tmp_str
PyObject *tmp_str;
// 将 Unicode 字符串转换为 ASCII 字符串,并赋给 tmp_str
tmp_str = PyUnicode_AsASCIIString(ind);
// 如果转换失败,返回 NULL
if (tmp_str == NULL) {
return NULL;
}
// 将 tmp_str 转换为 C 风格的字符串,并赋给 key
key = PyBytes_AS_STRING(tmp_str);
// 获取 key 的长度,并赋给 n
n = PyBytes_GET_SIZE(tmp_str);
// 如果字符串长度超过 16,释放 tmp_str 并跳转到 fail 标签处
if (n > 16) {
Py_DECREF(tmp_str);
goto fail;
}
// 将 key 复制到 buf 中
memcpy(buf, key, n);
// 释放 tmp_str
Py_DECREF(tmp_str);
// 将 buf 的地址赋给 key
key = buf;
}
// 如果 ind 是字节字符串
else if (PyBytes_Check(ind)) {
// 将 ind 转换为 C 风格的字符串,并赋给 key
key = PyBytes_AS_STRING(ind);
// 获取 key 的长度,并赋给 n
n = PyBytes_GET_SIZE(ind);
}
// 如果 ind 不是字符串类型,跳转到 fail 标签处
else {
goto fail;
}
// 根据字符串长度 n 执行不同的操作
switch(n) {
// 如果字符串长度为 1
case 1:
// 根据 key 的第一个字符执行不同的操作
switch(key[0]) {
case 'C':
// 返回 arrayflags_contiguous_get 函数的结果
return arrayflags_contiguous_get(self, NULL);
case 'F':
// 返回 arrayflags_fortran_get 函数的结果
return arrayflags_fortran_get(self, NULL);
case 'W':
// 返回 arrayflags_writeable_get 函数的结果
return arrayflags_writeable_get(self, NULL);
case 'B':
// 返回 arrayflags_behaved_get 函数的结果
return arrayflags_behaved_get(self, NULL);
case 'O':
// 返回 arrayflags_owndata_get 函数的结果
return arrayflags_owndata_get(self, NULL);
case 'A':
// 返回 arrayflags_aligned_get 函数的结果
return arrayflags_aligned_get(self, NULL);
case 'X':
// 返回 arrayflags_writebackifcopy_get 函数的结果
return arrayflags_writebackifcopy_get(self, NULL);
default:
// 如果 key 不匹配上述字符,跳转到 fail 标签处
goto fail;
}
break;
// 如果字符串长度为 2
case 2:
// 如果 key 是 "CA",返回 arrayflags_carray_get 函数的结果
if (strncmp(key, "CA", n) == 0) {
return arrayflags_carray_get(self, NULL);
}
// 如果 key 是 "FA",返回 arrayflags_farray_get 函数的结果
if (strncmp(key, "FA", n) == 0) {
return arrayflags_farray_get(self, NULL);
}
break;
// 如果字符串长度为 3
case 3:
// 如果 key 是 "FNC",返回 arrayflags_fnc_get 函数的结果
if (strncmp(key, "FNC", n) == 0) {
return arrayflags_fnc_get(self, NULL);
}
break;
// 如果字符串长度为 4
case 4:
// 如果 key 是 "FORC",返回 arrayflags_forc_get 函数的结果
if (strncmp(key, "FORC", n) == 0) {
return arrayflags_forc_get(self, NULL);
}
break;
// 如果字符串长度为 6
case 6:
// 如果 key 是 "CARRAY",返回 arrayflags_carray_get 函数的结果
if (strncmp(key, "CARRAY", n) == 0) {
return arrayflags_carray_get(self, NULL);
}
// 如果 key 是 "FARRAY",返回 arrayflags_farray_get 函数的结果
if (strncmp(key, "FARRAY", n) == 0) {
return arrayflags_farray_get(self, NULL);
}
break;
// 如果字符串长度为 7
case 7:
// 根据 key 的值返回相应函数的结果
if (strncmp(key,"FORTRAN",n) == 0) {
return arrayflags_fortran_get(self, NULL);
}
if (strncmp(key,"BEHAVED",n) == 0) {
return arrayflags_behaved_get(self, NULL);
}
if (strncmp(key,"OWNDATA",n) == 0) {
return arrayflags_owndata_get(self, NULL);
}
if (strncmp(key,"ALIGNED",n) == 0) {
return arrayflags_aligned_get(self, NULL);
}
break;
// 如果字符串长度为 9
case 9:
// 如果 key 是 "WRITEABLE",返回 arrayflags_writeable_get 函数的结果
if (strncmp(key,"WRITEABLE",n) == 0) {
return arrayflags_writeable_get(self, NULL);
}
break;
// 如果字符串长度为 10
case 10:
// 如果 key 是 "CONTIGUOUS",返回 arrayflags_contiguous_get 函数的结果
if (strncmp(key,"CONTIGUOUS",n) == 0) {
return arrayflags_contiguous_get(self, NULL);
}
break;
// 如果字符串长度不匹配上述任何情况,跳转到 fail 标签处
fail:
// 返回 NULL
return NULL;
}
```
case 12:
if (strncmp(key, "C_CONTIGUOUS", n) == 0) {
return arrayflags_contiguous_get(self, NULL);
}
if (strncmp(key, "F_CONTIGUOUS", n) == 0) {
return arrayflags_fortran_get(self, NULL);
}
break;
case 15:
if (strncmp(key, "WRITEBACKIFCOPY", n) == 0) {
return arrayflags_writebackifcopy_get(self, NULL);
}
break;
}
fail:
PyErr_SetString(PyExc_KeyError, "Unknown flag");
return NULL;
}
static int
arrayflags_setitem(PyArrayFlagsObject *self, PyObject *ind, PyObject *item)
{
char *key; // 声明一个指向字符的指针变量 key
char buf[16]; // 声明一个长度为 16 的字符数组 buf,用于存储字符串
int n; // 声明一个整型变量 n,用于存储字符串长度
if (PyUnicode_Check(ind)) { // 检查 ind 是否为 Unicode 对象
PyObject *tmp_str; // 声明一个 PyObject 类型的指针 tmp_str
tmp_str = PyUnicode_AsASCIIString(ind); // 将 Unicode 对象转换为 ASCII 字符串对象
key = PyBytes_AS_STRING(tmp_str); // 获取转换后的 ASCII 字符串的指针
n = PyBytes_GET_SIZE(tmp_str); // 获取转换后的 ASCII 字符串的长度
if (n > 16) n = 16; // 如果长度超过 16,则截断为 16
memcpy(buf, key, n); // 将 key 指向的内容复制到 buf 中
Py_DECREF(tmp_str); // 释放临时字符串对象的引用
key = buf; // 将 key 指向 buf,此时 key 指向 buf 的内容
}
else if (PyBytes_Check(ind)) { // 检查 ind 是否为字节对象
key = PyBytes_AS_STRING(ind); // 获取字节对象的指针
n = PyBytes_GET_SIZE(ind); // 获取字节对象的长度
}
else {
goto fail; // 如果 ind 既不是 Unicode 对象也不是字节对象,则跳转到 fail 标签处
}
if (((n==9) && (strncmp(key, "WRITEABLE", n) == 0)) || // 检查是否为 "WRITEABLE" 或 "W"
((n==1) && (strncmp(key, "W", n) == 0))) {
return arrayflags_writeable_set(self, item, NULL); // 调用 arrayflags_writeable_set 处理
}
else if (((n==7) && (strncmp(key, "ALIGNED", n) == 0)) || // 检查是否为 "ALIGNED" 或 "A"
((n==1) && (strncmp(key, "A", n) == 0))) {
return arrayflags_aligned_set(self, item, NULL); // 调用 arrayflags_aligned_set 处理
}
else if (((n==15) && (strncmp(key, "WRITEBACKIFCOPY", n) == 0)) || // 检查是否为 "WRITEBACKIFCOPY" 或 "X"
((n==1) && (strncmp(key, "X", n) == 0))) {
return arrayflags_writebackifcopy_set(self, item, NULL); // 调用 arrayflags_writebackifcopy_set 处理
}
fail:
PyErr_SetString(PyExc_KeyError, "Unknown flag"); // 设置 Key 错误异常
return -1; // 返回 -1 表示出错
}
static char *
_torf_(int flags, int val)
{
if ((flags & val) == val) { // 检查 flags 中是否包含 val 的位
return "True"; // 如果包含,返回字符串 "True"
}
else {
return "False"; // 如果不包含,返回字符串 "False"
}
}
static PyObject *
arrayflags_print(PyArrayFlagsObject *self)
{
int fl = self->flags; // 获取 self 对象的 flags 属性值
const char *_warn_on_write = ""; // 声明一个指向常量字符的指针 _warn_on_write,并初始化为空字符串
if (fl & NPY_ARRAY_WARN_ON_WRITE) { // 检查 flags 中是否包含 NPY_ARRAY_WARN_ON_WRITE 标志位
_warn_on_write = " (with WARN_ON_WRITE=True)"; // 如果包含,设置 _warn_on_write
}
return PyUnicode_FromFormat(
" %s : %s\n %s : %s\n"
" %s : %s\n %s : %s%s\n"
" %s : %s\n %s : %s\n",
"C_CONTIGUOUS", _torf_(fl, NPY_ARRAY_C_CONTIGUOUS), // 使用 _torf_ 函数获取对应标志位的值
"F_CONTIGUOUS", _torf_(fl, NPY_ARRAY_F_CONTIGUOUS),
"OWNDATA", _torf_(fl, NPY_ARRAY_OWNDATA),
"WRITEABLE", _torf_(fl, NPY_ARRAY_WRITEABLE),
_warn_on_write, // 输出 WARN_ON_WRITE 的状态信息
"ALIGNED", _torf_(fl, NPY_ARRAY_ALIGNED),
"WRITEBACKIFCOPY", _torf_(fl, NPY_ARRAY_WRITEBACKIFCOPY)
);
}
static PyObject*
arrayflags_richcompare(PyObject *self, PyObject *other, int cmp_op)
{
if (!PyObject_TypeCheck(other, &PyArrayFlags_Type)) { // 检查 other 是否为 PyArrayFlagsObject 类型
Py_RETURN_NOTIMPLEMENTED; // 如果不是,返回未实现错误
}
npy_bool eq = ((PyArrayFlagsObject*) self)->flags == // 比较 self 和 other 的 flags 属性是否相等
((PyArrayFlagsObject*) other)->flags;
if (cmp_op == Py_EQ) { // 如果比较操作是等于
return PyBool_FromLong(eq); // 返回布尔值表示是否相等
}
else if (cmp_op == Py_NE) { // 如果比较操作是不等于
return PyBool_FromLong(!eq); // 返回布尔值表示是否不相等
}
else {
Py_RETURN_NOTIMPLEMENTED; // 其他比较操作返回未实现错误
}
}
static PyMappingMethods arrayflags_as_mapping = {
(lenfunc)NULL, /*mp_length*/ // 长度函数为空,表示不支持长度操作
(binaryfunc)arrayflags_getitem, /*mp_subscript*/ // 子script操作使用 arrayflags_getitem 函数
(objobjargproc)arrayflags_setitem, /*mp_ass_subscript*/
(objobjargproc)arrayflags_setitem,
static PyObject *
arrayflags_new(PyTypeObject *NPY_UNUSED(self), PyObject *args, PyObject *NPY_UNUSED(kwds))
{
PyObject *arg=NULL;
// 解包参数args,获取函数的唯一参数arg
if (!PyArg_UnpackTuple(args, "flagsobj", 0, 1, &arg)) {
// 解包失败,返回NULL
return NULL;
}
// 如果arg非空且为PyArray对象
if ((arg != NULL) && PyArray_Check(arg)) {
// 返回一个新的PyArrayFlagsObject对象,其标记与给定的PyArray对象相关联
return PyArray_NewFlagsObject(arg);
}
else {
// 否则,返回一个新的PyArrayFlagsObject对象,不与任何PyArray对象相关联
return PyArray_NewFlagsObject(NULL);
}
}
// 定义PyArrayFlags_Type类型对象
NPY_NO_EXPORT PyTypeObject PyArrayFlags_Type = {
PyVarObject_HEAD_INIT(NULL, 0)
// 对象类型名称
.tp_name = "numpy._core.multiarray.flagsobj",
// 对象基本大小
.tp_basicsize = sizeof(PyArrayFlagsObject),
// 对象析构函数,用于释放对象占用的内存
.tp_dealloc = (destructor)arrayflags_dealloc,
// 对象的字符串表示函数,用于打印对象信息
.tp_repr = (reprfunc)arrayflags_print,
// 对象作为映射类型的接口
.tp_as_mapping = &arrayflags_as_mapping,
// 对象的字符串表示函数,与tp_repr相同
.tp_str = (reprfunc)arrayflags_print,
// 对象的标志位,默认为Py_TPFLAGS_DEFAULT
.tp_flags = Py_TPFLAGS_DEFAULT,
// 对象的富比较函数
.tp_richcompare = arrayflags_richcompare,
// 对象的属性获取和设置函数
.tp_getset = arrayflags_getsets,
// 对象的构造函数,用于创建新对象实例
.tp_new = arrayflags_new,
};
.\numpy\numpy\_core\src\multiarray\flagsobject.h
/* Array Flags Object */
// 定义了一个结构体 PyArrayFlagsObject,用于表示数组的标志信息
typedef struct PyArrayFlagsObject {
PyObject_HEAD
PyObject *arr; // 指向数组对象的指针
int flags; // 数组的标志位
} PyArrayFlagsObject;
// 导出了 PyArrayFlags_Type 类型对象
extern NPY_NO_EXPORT PyTypeObject PyArrayFlags_Type;
// 创建并返回一个新的数组标志对象
NPY_NO_EXPORT PyObject *
PyArray_NewFlagsObject(PyObject *obj);
// 更新数组对象的标志位
NPY_NO_EXPORT void
PyArray_UpdateFlags(PyArrayObject *ret, int flagmask);
.\numpy\numpy\_core\src\multiarray\getset.c
/* Array Descr Object */
/* Define to prevent deprecated API usage */
/* Define to enable multiarray module */
/* Ensure Python.h uses modern Py_ssize_t definitions */
/* Include Python core header */
/* Include structmember.h for C struct and object member API */
/* Include NumPy's array object header */
#include "numpy/arrayobject.h"
/* Include NumPy configuration */
#include "npy_config.h"
/* Include NumPy import utilities */
#include "npy_import.h"
/* Include common utility functions */
#include "common.h"
/* Include conversion utilities */
#include "conversion_utils.h"
/* Include constructors for arrays */
#include "ctors.h"
/* Include dtype meta information */
#include "dtypemeta.h"
/* Include scalar types definitions */
#include "scalartypes.h"
/* Include array descriptor definitions */
#include "descriptor.h"
/* Include flags object definitions */
#include "flagsobject.h"
/* Include getter/setter definitions */
#include "getset.h"
/* Include main array object definitions */
#include "arrayobject.h"
/* Include memory overlap handling */
#include "mem_overlap.h"
/* Include memory allocation utilities */
#include "alloc.h"
/* Include buffer handling utilities */
#include "npy_buffer.h"
/* Include shape manipulation utilities */
#include "shape.h"
/* Include multiarray module utilities */
#include "multiarraymodule.h"
/******************* array attribute get and set routines ******************/
/* Retrieve the number of dimensions of the array */
static PyObject *
array_ndim_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyLong_FromLong(PyArray_NDIM(self));
}
/* Retrieve array flags as a flags object */
static PyObject *
array_flags_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyArray_NewFlagsObject((PyObject *)self);
}
/* Retrieve array shape as a tuple of integers */
static PyObject *
array_shape_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyArray_IntTupleFromIntp(PyArray_NDIM(self), PyArray_DIMS(self));
}
/* Set array shape from a Python object */
static int
array_shape_set(PyArrayObject *self, PyObject *val, void* NPY_UNUSED(ignored))
{
int nd;
PyArrayObject *ret;
/* Check if val is NULL (deletion not allowed) */
if (val == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array shape");
return -1;
}
/* Attempt to reshape the array using the provided shape */
ret = (PyArrayObject *)PyArray_Reshape(self, val);
if (ret == NULL) {
return -1;
}
/* Check if the reshape operation resulted in a new data pointer */
if (PyArray_DATA(ret) != PyArray_DATA(self)) {
Py_DECREF(ret);
PyErr_SetString(PyExc_AttributeError,
"Incompatible shape for in-place modification. Use "
"`.reshape()` to make a copy with the desired shape.");
return -1;
}
/* Obtain the number of dimensions after reshape */
nd = PyArray_NDIM(ret);
/* Handle array reshaping */
if (nd > 0) {
/* Allocate new dimensions and strides */
npy_intp *_dimensions = npy_alloc_cache_dim(2 * nd);
if (_dimensions == NULL) {
Py_DECREF(ret);
PyErr_NoMemory();
return -1;
}
/* Free old dimensions and strides */
npy_free_cache_dim_array(self);
/* Update array fields with new dimensions and strides */
((PyArrayObject_fields *)self)->nd = nd;
((PyArrayObject_fields *)self)->dimensions = _dimensions;
((PyArrayObject_fields *)self)->strides = _dimensions + nd;
/* Copy new dimensions and strides */
if (nd) {
memcpy(PyArray_DIMS(self), PyArray_DIMS(ret), nd*sizeof(npy_intp));
memcpy(PyArray_STRIDES(self), PyArray_STRIDES(ret), nd*sizeof(npy_intp));
}
}
else {
/* Free old dimensions and strides for zero-dimensional arrays */
npy_free_cache_dim_array(self);
((PyArrayObject_fields *)self)->nd = 0;
((PyArrayObject_fields *)self)->dimensions = NULL;
((PyArrayObject_fields *)self)->strides = NULL;
}
/* Release temporary reshape result */
Py_DECREF(ret);
/* Update array flags */
PyArray_UpdateFlags(self, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS);
/* Successful array shape update */
return 0;
}
/*
* 从数组对象中获取步幅信息,并返回一个包含步幅信息的元组对象
*/
array_strides_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyArray_IntTupleFromIntp(PyArray_NDIM(self), PyArray_STRIDES(self));
}
/*
* 设置数组对象的步幅信息
*/
static int
array_strides_set(PyArrayObject *self, PyObject *obj, void *NPY_UNUSED(ignored))
{
PyArray_Dims newstrides = {NULL, -1};
PyArrayObject *new;
npy_intp numbytes = 0;
npy_intp offset = 0;
npy_intp lower_offset = 0;
npy_intp upper_offset = 0;
Py_buffer view;
if (obj == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array strides");
return -1;
}
if (!PyArray_OptionalIntpConverter(obj, &newstrides) ||
newstrides.len == -1) {
PyErr_SetString(PyExc_TypeError, "invalid strides");
return -1;
}
if (newstrides.len != PyArray_NDIM(self)) {
PyErr_Format(PyExc_ValueError, "strides must be " \
" same length as shape (%d)", PyArray_NDIM(self));
goto fail;
}
new = self;
while(PyArray_BASE(new) && PyArray_Check(PyArray_BASE(new))) {
new = (PyArrayObject *)(PyArray_BASE(new));
}
/*
* 通过缓冲区接口获取PyArray_BASE(new)的可用内存,如果失败则从当前的new获取
*/
if (PyArray_BASE(new) &&
PyObject_GetBuffer(PyArray_BASE(new), &view, PyBUF_SIMPLE) >= 0) {
offset = PyArray_BYTES(self) - (char *)view.buf;
numbytes = view.len + offset;
PyBuffer_Release(&view);
}
else {
PyErr_Clear();
offset_bounds_from_strides(PyArray_ITEMSIZE(new), PyArray_NDIM(new),
PyArray_DIMS(new), PyArray_STRIDES(new),
&lower_offset, &upper_offset);
offset = PyArray_BYTES(self) - (PyArray_BYTES(new) + lower_offset);
numbytes = upper_offset - lower_offset;
}
/* numbytes == 0 is special here, but the 0-size array case always works */
if (!PyArray_CheckStrides(PyArray_ITEMSIZE(self), PyArray_NDIM(self),
numbytes, offset,
PyArray_DIMS(self), newstrides.ptr)) {
PyErr_SetString(PyExc_ValueError, "strides is not "\
"compatible with available memory");
goto fail;
}
if (newstrides.len) {
memcpy(PyArray_STRIDES(self), newstrides.ptr, sizeof(npy_intp)*newstrides.len);
}
PyArray_UpdateFlags(self, NPY_ARRAY_C_CONTIGUOUS | NPY_ARRAY_F_CONTIGUOUS |
NPY_ARRAY_ALIGNED);
npy_free_cache_dim_obj(newstrides);
return 0;
fail:
npy_free_cache_dim_obj(newstrides);
return -1;
}
/*
* 返回数组对象的优先级作为一个Python浮点数对象
*/
static PyObject *
array_priority_get(PyArrayObject *NPY_UNUSED(self), void *NPY_UNUSED(ignored))
{
return PyFloat_FromDouble(NPY_PRIORITY);
}
/*
* 返回数组对象的类型描述符字符串
*/
static PyObject *
array_typestr_get(PyArrayObject *self)
{
return arraydescr_protocol_typestr_get(PyArray_DESCR(self), NULL);
}
/*
* 继续添加函数定义...
*/
static PyObject *
array_interface_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
PyObject *dict; // 创建一个新的 Python 字典对象,用于存储接口信息
PyObject *obj; // 用于临时存储从其他函数返回的 Python 对象
int ret; // 用于存储 PyDict_SetItemString 函数的返回值
dict = PyDict_New(); // 创建一个新的空字典对象
if (dict == NULL) { // 检查字典创建是否成功
return NULL; // 如果创建失败,返回空指针
}
/* dataptr */
obj = array_dataptr_get(self, NULL); // 调用 array_dataptr_get 获取数据指针对象
ret = PyDict_SetItemString(dict, "data", obj); // 将数据指针对象存入字典中
Py_DECREF(obj); // 减少数据指针对象的引用计数
if (ret < 0) { // 检查 PyDict_SetItemString 是否成功
Py_DECREF(dict); // 失败时释放字典对象并返回空指针
return NULL;
}
obj = array_protocol_strides_get(self); // 调用 array_protocol_strides_get 获取步幅对象
ret = PyDict_SetItemString(dict, "strides", obj); // 将步幅对象存入字典中
Py_DECREF(obj); // 减少步幅对象的引用计数
if (ret < 0) { // 检查 PyDict_SetItemString 是否成功
Py_DECREF(dict); // 失败时释放字典对象并返回空指针
return NULL;
}
obj = array_protocol_descr_get(self); // 调用 array_protocol_descr_get 获取描述符对象
ret = PyDict_SetItemString(dict, "descr", obj); // 将描述符对象存入字典中
Py_DECREF(obj); // 减少描述符对象的引用计数
if (ret < 0) { // 检查 PyDict_SetItemString 是否成功
Py_DECREF(dict); // 失败时释放字典对象并返回空指针
return NULL;
}
obj = arraydescr_protocol_typestr_get(PyArray_DESCR(self), NULL); // 调用 arraydescr_protocol_typestr_get 获取类型字符串对象
ret = PyDict_SetItemString(dict, "typestr", obj); // 将类型字符串对象存入字典中
Py_DECREF(obj); // 减少类型字符串对象的引用计数
if (ret < 0) { // 检查 PyDict_SetItemString 是否成功
Py_DECREF(dict); // 失败时释放字典对象并返回空指针
return NULL;
}
obj = array_shape_get(self, NULL); // 调用 array_shape_get 获取形状对象
ret = PyDict_SetItemString(dict, "shape", obj); // 将形状对象存入字典中
Py_DECREF(obj); // 减少形状对象的引用计数
// 返回填充完毕的字典对象,包含了"data", "strides", "descr", "typestr", "shape"等键对应的值
return dict;
}
# 如果 ret 小于 0,则表示在之前的操作中出现了错误,需要释放字典对象并返回空指针
if (ret < 0) {
Py_DECREF(dict);
return NULL;
}
# 创建一个整数对象,表示版本号为 3
obj = PyLong_FromLong(3);
# 将整数对象作为值,键为 "version",添加到字典中
ret = PyDict_SetItemString(dict, "version", obj);
Py_DECREF(obj);
# 如果 ret 小于 0,则表示在设置字典项时出现了错误,需要释放字典对象并返回空指针
if (ret < 0) {
Py_DECREF(dict);
return NULL;
}
# 返回已经填充好的字典对象
return dict;
}
/*
* 返回一个内存视图对象,表示数组的数据
*/
static PyObject *
array_data_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyMemoryView_FromObject((PyObject *)self);
}
/*
* 返回一个长整型对象,表示数组每个元素的字节大小
*/
static PyObject *
array_itemsize_get(PyArrayObject *self, void* NPY_UNUSED(ignored))
{
return PyLong_FromLong((long) PyArray_ITEMSIZE(self));
}
/*
* 返回一个整型对象,表示数组中元素的总数
*/
static PyObject *
array_size_get(PyArrayObject *self, void* NPY_UNUSED(ignored))
{
return PyArray_PyIntFromIntp(PyArray_SIZE(self));
}
/*
* 返回一个整型对象,表示数组所占用的总字节数
*/
static PyObject *
array_nbytes_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
return PyArray_PyIntFromIntp(PyArray_NBYTES(self));
}
/*
* 当数组的数据类型发生改变时调用此函数。
* 若itemsize保持不变或者数组是单段的(连续或Fortran),并且维度兼容,
* 则形状和步长也将被相应调整。
*/
static int
array_descr_set(PyArrayObject *self, PyObject *arg, void *NPY_UNUSED(ignored))
{
PyArray_Descr *newtype = NULL;
// 如果传入的arg为NULL,则不能删除数组的数据类型,抛出异常
if (arg == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array dtype");
return -1;
}
// 尝试将arg转换为PyArray_Descr类型的对象newtype
if (!(PyArray_DescrConverter(arg, &newtype)) ||
newtype == NULL) {
PyErr_SetString(PyExc_TypeError,
"invalid data-type for array");
return -1;
}
/* 检查是否涉及到包含对象的内存重解释 */
if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(newtype)) {
PyObject *safe;
// 导入numpy._core._internal模块中的_view_is_safe函数
npy_cache_import("numpy._core._internal", "_view_is_safe",
&npy_thread_unsafe_state._view_is_safe);
if (npy_thread_unsafe_state._view_is_safe == NULL) {
goto fail;
}
// 调用_view_is_safe函数,检查是否安全
safe = PyObject_CallFunction(npy_thread_unsafe_state._view_is_safe,
"OO", PyArray_DESCR(self), newtype);
if (safe == NULL) {
goto fail;
}
Py_DECREF(safe);
}
/*
* 若新类型是无大小的void类型,则它的大小应该与当前dtype的itemsize相匹配。
* 若不匹配,则将newtype的大小调整为当前dtype的itemsize。
*/
if (newtype->type_num == NPY_VOID &&
PyDataType_ISUNSIZED(newtype) &&
newtype->elsize != PyArray_ITEMSIZE(self)) {
PyArray_DESCR_REPLACE(newtype);
if (newtype == NULL) {
return -1;
}
newtype->elsize = PyArray_ITEMSIZE(self);
}
/* 更改dtype的大小会导致形状发生变化 */
// 检查新数据类型的元素大小是否与数组元素大小相同
if (newtype->elsize != PyArray_ITEMSIZE(self)) {
/* 禁止的情况 */
// 如果数组是0维的,只有当元素大小不变时才支持更改数据类型
if (PyArray_NDIM(self) == 0) {
PyErr_SetString(PyExc_ValueError,
"Changing the dtype of a 0d array is only supported "
"if the itemsize is unchanged");
// 跳转到错误处理标签
goto fail;
}
// 如果新数据类型是子数组类型,只有当总元素大小不变时才支持更改数据类型
else if (PyDataType_HASSUBARRAY(newtype)) {
PyErr_SetString(PyExc_ValueError,
"Changing the dtype to a subarray type is only supported "
"if the total itemsize is unchanged");
// 跳转到错误处理标签
goto fail;
}
/* 只在最后一个轴上调整大小 */
int axis = PyArray_NDIM(self) - 1;
// 如果最后一个轴的维度不为1,并且数组大小不为0,并且最后一个轴上的步长不等于元素大小,则报错
if (PyArray_DIMS(self)[axis] != 1 &&
PyArray_SIZE(self) != 0 &&
PyArray_STRIDES(self)[axis] != PyArray_ITEMSIZE(self)) {
PyErr_SetString(PyExc_ValueError,
"To change to a dtype of a different size, the last axis "
"must be contiguous");
// 跳转到错误处理标签
goto fail;
}
npy_intp newdim;
// 如果新数据类型的元素大小小于数组的元素大小
if (newtype->elsize < PyArray_ITEMSIZE(self)) {
/* 如果兼容,增加最后一个轴的大小 */
// 如果新数据类型的元素大小为0或者原始数据类型的大小不能整除新数据类型的大小,则报错
if (newtype->elsize == 0 ||
PyArray_ITEMSIZE(self) % newtype->elsize != 0) {
PyErr_SetString(PyExc_ValueError,
"When changing to a smaller dtype, its size must be a "
"divisor of the size of original dtype");
// 跳转到错误处理标签
goto fail;
}
// 计算新的维度大小
newdim = PyArray_ITEMSIZE(self) / newtype->elsize;
// 更新最后一个轴的维度
PyArray_DIMS(self)[axis] *= newdim;
// 更新最后一个轴的步长为新数据类型的元素大小
PyArray_STRIDES(self)[axis] = newtype->elsize;
}
else /* newtype->elsize > PyArray_ITEMSIZE(self) */ {
/* 如果兼容,减少相关轴的大小 */
// 计算新的维度大小
newdim = PyArray_DIMS(self)[axis] * PyArray_ITEMSIZE(self);
// 如果不能整除新数据类型的大小,则报错
if ((newdim % newtype->elsize) != 0) {
PyErr_SetString(PyExc_ValueError,
"When changing to a larger dtype, its size must be a "
"divisor of the total size in bytes of the last axis "
"of the array.");
// 跳转到错误处理标签
goto fail;
}
// 更新最后一个轴的维度
PyArray_DIMS(self)[axis] = newdim / newtype->elsize;
// 更新最后一个轴的步长为新数据类型的元素大小
PyArray_STRIDES(self)[axis] = newtype->elsize;
}
}
/* 将视图作为子数组会增加维数 */
if (PyDataType_HASSUBARRAY(newtype)) {
/*
* 如果新类型有子数组,
* 创建新的数组对象,并从中更新维度、步长和描述符
*/
PyArrayObject *temp;
/*
* 在这里我们会减少 newtype 的引用计数。
* temp 将会获取它的引用
*/
temp = (PyArrayObject *)
PyArray_NewFromDescr(&PyArray_Type, newtype, PyArray_NDIM(self),
PyArray_DIMS(self), PyArray_STRIDES(self),
PyArray_DATA(self), PyArray_FLAGS(self), NULL);
if (temp == NULL) {
return -1;
}
npy_free_cache_dim_array(self);
((PyArrayObject_fields *)self)->dimensions = PyArray_DIMS(temp);
((PyArrayObject_fields *)self)->nd = PyArray_NDIM(temp);
((PyArrayObject_fields *)self)->strides = PyArray_STRIDES(temp);
newtype = PyArray_DESCR(temp);
Py_INCREF(PyArray_DESCR(temp));
/* 避免释放器删除这些 */
((PyArrayObject_fields *)temp)->nd = 0;
((PyArrayObject_fields *)temp)->dimensions = NULL;
Py_DECREF(temp);
}
Py_DECREF(PyArray_DESCR(self));
((PyArrayObject_fields *)self)->descr = newtype;
PyArray_UpdateFlags(self, NPY_ARRAY_UPDATE_ALL);
return 0;
fail:
Py_DECREF(newtype);
return -1;
static PyObject *
array_struct_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 分配内存以存储 PyArrayInterface 结构
PyArrayInterface *inter;
inter = (PyArrayInterface *)PyArray_malloc(sizeof(PyArrayInterface));
// 检查内存分配是否成功
if (inter==NULL) {
// 内存分配失败,返回内存错误异常
return PyErr_NoMemory();
}
// 设置 PyArrayInterface 结构的字段
inter->two = 2;
inter->nd = PyArray_NDIM(self);
inter->typekind = PyArray_DESCR(self)->kind;
inter->itemsize = PyArray_ITEMSIZE(self);
inter->flags = PyArray_FLAGS(self);
if (inter->flags & NPY_ARRAY_WARN_ON_WRITE) {
/* Export a warn-on-write array as read-only */
// 如果数组标志包含 NPY_ARRAY_WARN_ON_WRITE,将其设置为只读
inter->flags = inter->flags & ~NPY_ARRAY_WARN_ON_WRITE;
inter->flags = inter->flags & ~NPY_ARRAY_WRITEABLE;
}
/* reset unused flags */
// 重置未使用的标志位
inter->flags &= ~(NPY_ARRAY_WRITEBACKIFCOPY | NPY_ARRAY_OWNDATA);
if (PyArray_ISNOTSWAPPED(self)) inter->flags |= NPY_ARRAY_NOTSWAPPED;
/*
* Copy shape and strides over since these can be reset
* when the array is "reshaped".
*/
// 复制形状和步幅,因为这些在“重塑”数组时可能会被重置
if (PyArray_NDIM(self) > 0) {
// 分配内存以存储形状和步幅
inter->shape = (npy_intp *)PyArray_malloc(2*sizeof(npy_intp)*PyArray_NDIM(self));
if (inter->shape == NULL) {
// 内存分配失败,释放已分配的内存,并返回内存错误异常
PyArray_free(inter);
return PyErr_NoMemory();
}
// 设置步幅为形状数组的末尾
inter->strides = inter->shape + PyArray_NDIM(self);
if (PyArray_NDIM(self)) {
// 复制形状和步幅数据
memcpy(inter->shape, PyArray_DIMS(self), sizeof(npy_intp)*PyArray_NDIM(self));
memcpy(inter->strides, PyArray_STRIDES(self), sizeof(npy_intp)*PyArray_NDIM(self));
}
}
else {
// 数组没有维度,设置形状和步幅为 NULL
inter->shape = NULL;
inter->strides = NULL;
}
// 设置数据指针
inter->data = PyArray_DATA(self);
if (PyDataType_HASFIELDS(PyArray_DESCR(self))) {
// 如果数据类型有字段,获取字段描述符
inter->descr = arraydescr_protocol_descr_get(PyArray_DESCR(self), NULL);
if (inter->descr == NULL) {
// 获取描述符失败,清除错误状态
PyErr_Clear();
}
else {
// 设置数组描述符标志位
inter->flags &= NPY_ARR_HAS_DESCR;
}
}
else {
// 没有字段,描述符设置为 NULL
inter->descr = NULL;
}
// 创建 PyCapsule 对象来封装 inter 结构
PyObject *ret = PyCapsule_New(inter, NULL, gentype_struct_free);
if (ret == NULL) {
// 创建 PyCapsule 对象失败,返回 NULL
return NULL;
}
// 增加数组对象的引用计数
Py_INCREF(self);
// 将数组对象设置为 PyCapsule 对象的上下文
if (PyCapsule_SetContext(ret, self) < 0) {
// 设置上下文失败,返回 NULL
return NULL;
}
// 返回 PyCapsule 对象
return ret;
}
static PyObject *
array_base_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 检查数组的基础对象是否为 NULL
if (PyArray_BASE(self) == NULL) {
// 基础对象为 NULL,返回 None
Py_RETURN_NONE;
}
else {
// 增加基础对象的引用计数,并返回基础对象
Py_INCREF(PyArray_BASE(self));
return PyArray_BASE(self);
}
}
/*
* Create a view of a complex array with an equivalent data-type
* except it is real instead of complex.
*/
static PyArrayObject *
_get_part(PyArrayObject *self, int imag)
{
// 定义浮点类型编号、数据类型、返回的数组对象、偏移量等变量
int float_type_num;
PyArray_Descr *type;
PyArrayObject *ret;
int offset;
# 根据当前数组的描述符中的类型号进行切换
switch (PyArray_DESCR(self)->type_num) {
# 如果是复数浮点数,设置对应的浮点数类型号为NPY_FLOAT
case NPY_CFLOAT:
float_type_num = NPY_FLOAT;
break;
# 如果是双精度复数浮点数,设置浮点数类型号为NPY_DOUBLE
case NPY_CDOUBLE:
float_type_num = NPY_DOUBLE;
break;
# 如果是长双精度复数浮点数,设置浮点数类型号为NPY_LONGDOUBLE
case NPY_CLONGDOUBLE:
float_type_num = NPY_LONGDOUBLE;
break;
# 如果以上情况都不匹配,则抛出异常并返回NULL
default:
PyErr_Format(PyExc_ValueError,
"Cannot convert complex type number %d to float",
PyArray_DESCR(self)->type_num);
return NULL;
}
# 根据浮点数类型号获取对应的描述符
type = PyArray_DescrFromType(float_type_num);
# 如果获取描述符失败,则返回NULL
if (type == NULL) {
return NULL;
}
# 如果imag为真,则偏移量为描述符元素大小,否则偏移量为0
offset = (imag ? type->elsize : 0);
# 如果数组的字节顺序不是本机字节顺序
if (!PyArray_ISNBO(PyArray_DESCR(self)->byteorder)) {
# 复制描述符并检查是否成功
Py_SETREF(type, PyArray_DescrNew(type));
if (type == NULL) {
return NULL;
}
# 设置复制后的描述符的字节顺序与数组的字节顺序相同
type->byteorder = PyArray_DESCR(self)->byteorder;
}
# 使用提供的描述符和数据创建新的数组对象
ret = (PyArrayObject *)PyArray_NewFromDescrAndBase(
Py_TYPE(self),
type,
PyArray_NDIM(self),
PyArray_DIMS(self),
PyArray_STRIDES(self),
PyArray_BYTES(self) + offset,
PyArray_FLAGS(self), (PyObject *)self, (PyObject *)self);
# 如果创建数组对象失败,则返回NULL
if (ret == NULL) {
return NULL;
}
# 返回创建的数组对象
return ret;
/* For Object arrays, we need to get and set the
real part of each element.
*/
static PyObject *
array_real_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
PyArrayObject *ret;
// 如果数组是复数类型的
if (PyArray_ISCOMPLEX(self)) {
// 调用内部函数获取数组的实部
ret = _get_part(self, 0);
return (PyObject *)ret;
}
else {
// 如果不是复数类型,增加引用计数并返回自身
Py_INCREF(self);
return (PyObject *)self;
}
}
static int
array_real_set(PyArrayObject *self, PyObject *val, void *NPY_UNUSED(ignored))
{
PyArrayObject *ret;
PyArrayObject *new;
int retcode;
// 如果传入的值是空,则无法删除数组的实部
if (val == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array real part");
return -1;
}
// 如果数组是复数类型的
if (PyArray_ISCOMPLEX(self)) {
// 调用内部函数获取数组的实部
ret = _get_part(self, 0);
if (ret == NULL) {
return -1;
}
}
else {
// 如果不是复数类型,增加引用计数并返回自身
Py_INCREF(self);
ret = self;
}
// 将传入的值转换为数组对象
new = (PyArrayObject *)PyArray_FROM_O(val);
if (new == NULL) {
Py_DECREF(ret);
return -1;
}
// 将新值复制到实部数组中
retcode = PyArray_CopyInto(ret, new);
Py_DECREF(ret);
Py_DECREF(new);
return retcode;
}
/* For Object arrays we need to get
and set the imaginary part of
each element
*/
static PyObject *
array_imag_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
PyArrayObject *ret;
// 如果数组是复数类型的
if (PyArray_ISCOMPLEX(self)) {
// 调用内部函数获取数组的虚部
ret = _get_part(self, 1);
}
else {
// 如果不是复数类型,增加描述符的引用计数并创建一个新的数组对象
Py_INCREF(PyArray_DESCR(self));
ret = (PyArrayObject *)PyArray_NewFromDescr_int(
Py_TYPE(self),
PyArray_DESCR(self),
PyArray_NDIM(self),
PyArray_DIMS(self),
NULL, NULL,
PyArray_ISFORTRAN(self),
(PyObject *)self, NULL, _NPY_ARRAY_ZEROED);
if (ret == NULL) {
return NULL;
}
// 清除可写标志,使得数组不可写
PyArray_CLEARFLAGS(ret, NPY_ARRAY_WRITEABLE);
}
return (PyObject *) ret;
}
static int
array_imag_set(PyArrayObject *self, PyObject *val, void *NPY_UNUSED(ignored))
{
// 如果传入的值是空,则无法删除数组的虚部
if (val == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array imaginary part");
return -1;
}
// 如果数组是复数类型的
if (PyArray_ISCOMPLEX(self)) {
PyArrayObject *ret;
PyArrayObject *new;
int retcode;
// 调用内部函数获取数组的虚部
ret = _get_part(self, 1);
if (ret == NULL) {
return -1;
}
// 将传入的值转换为数组对象
new = (PyArrayObject *)PyArray_FROM_O(val);
if (new == NULL) {
Py_DECREF(ret);
return -1;
}
// 将新值复制到虚部数组中
retcode = PyArray_CopyInto(ret, new);
Py_DECREF(ret);
Py_DECREF(new);
return retcode;
}
else {
// 如果不是复数类型,抛出类型错误
PyErr_SetString(PyExc_TypeError,
"array does not have imaginary part to set");
return -1;
}
}
static PyObject *
array_flat_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 返回一个数组的迭代器对象
return PyArray_IterNew((PyObject *)self);
}
static int
array_flat_set(PyArrayObject *self, PyObject *val, void *NPY_UNUSED(ignored))
{
// 省略部分代码,不在注释范围内
}
// 声明一个指向 PyArrayObject 结构体的指针,用于存储数组对象
PyArrayObject *arr = NULL;
// 初始化返回值为 -1,表示函数执行失败
int retval = -1;
// 声明两个迭代器对象的指针,分别用于自身数组和传入数组
PyArrayIterObject *selfit = NULL, *arrit = NULL;
// 声明一个指向数组描述符的指针,用于存储数组的类型信息
PyArray_Descr *typecode;
// 用于存储是否需要交换字节顺序的标志
int swap;
// 声明一个函数指针,指向复制和交换数据的函数
PyArray_CopySwapFunc *copyswap;
// 如果传入的值为 NULL,则设置异常并返回 -1
if (val == NULL) {
PyErr_SetString(PyExc_AttributeError,
"Cannot delete array flat iterator");
return -1;
}
// 检查当前数组是否可写,若不可写则返回 -1
if (PyArray_FailUnlessWriteable(self, "array") < 0) return -1;
// 获取当前数组的类型描述符,并增加其引用计数
typecode = PyArray_DESCR(self);
Py_INCREF(typecode);
// 尝试根据传入的值创建一个新的数组对象,并强制类型转换为当前数组的类型
arr = (PyArrayObject *)PyArray_FromAny(val, typecode,
0, 0, NPY_ARRAY_FORCECAST | PyArray_FORTRAN_IF(self), NULL);
// 如果创建数组对象失败,则返回 -1
if (arr == NULL) {
return -1;
}
// 创建传入数组的迭代器对象
arrit = (PyArrayIterObject *)PyArray_IterNew((PyObject *)arr);
// 如果创建迭代器对象失败,则跳转到退出标签
if (arrit == NULL) {
goto exit;
}
// 创建自身数组的迭代器对象
selfit = (PyArrayIterObject *)PyArray_IterNew((PyObject *)self);
// 如果创建迭代器对象失败,则跳转到退出标签
if (selfit == NULL) {
goto exit;
}
// 如果传入数组的大小为 0,则直接将返回值设为 0 并跳转到退出标签
if (arrit->size == 0) {
retval = 0;
goto exit;
}
// 判断是否需要交换字节顺序
swap = PyArray_ISNOTSWAPPED(self) != PyArray_ISNOTSWAPPED(arr);
// 获取当前数组类型描述符的复制和交换函数
copyswap = PyDataType_GetArrFuncs(PyArray_DESCR(self))->copyswap;
// 如果当前数组的数据类型需要引用计数检查,则执行以下循环
if (PyDataType_REFCHK(PyArray_DESCR(self))) {
// 在自身数组的迭代器上进行循环,释放每个元素的引用计数,并增加传入数组对应元素的引用计数
while (selfit->index < selfit->size) {
PyArray_Item_XDECREF(selfit->dataptr, PyArray_DESCR(self));
PyArray_Item_INCREF(arrit->dataptr, PyArray_DESCR(arr));
// 使用 memmove 函数复制传入数组的元素到自身数组中
memmove(selfit->dataptr, arrit->dataptr, sizeof(PyObject **));
// 如果需要交换字节顺序,则调用对应的交换函数
if (swap) {
copyswap(selfit->dataptr, NULL, swap, self);
}
// 更新迭代器的指针位置
PyArray_ITER_NEXT(selfit);
PyArray_ITER_NEXT(arrit);
// 如果传入数组的迭代器达到末尾,则重置迭代器
if (arrit->index == arrit->size) {
PyArray_ITER_RESET(arrit);
}
}
// 设置返回值为 0 并跳转到退出标签
retval = 0;
goto exit;
}
// 若当前数组的数据类型不需要引用计数检查,则执行以下循环
while(selfit->index < selfit->size) {
// 调用复制和可能的交换函数,将传入数组的元素复制到自身数组中
copyswap(selfit->dataptr, arrit->dataptr, swap, self);
// 更新迭代器的指针位置
PyArray_ITER_NEXT(selfit);
PyArray_ITER_NEXT(arrit);
// 如果传入数组的迭代器达到末尾,则重置迭代器
if (arrit->index == arrit->size) {
PyArray_ITER_RESET(arrit);
}
}
// 设置返回值为 0
exit:
// 释放迭代器对象的引用计数
Py_XDECREF(selfit);
Py_XDECREF(arrit);
Py_XDECREF(arr);
// 返回函数的执行结果
return retval;
}
// 定义静态函数 `array_transpose_get`,用于获取数组的转置
static PyObject *
array_transpose_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 调用 NumPy 提供的函数 PyArray_Transpose 对数组进行转置操作
return PyArray_Transpose(self, NULL);
}
// 定义静态函数 `array_matrix_transpose_get`,用于获取矩阵的转置
static PyObject *
array_matrix_transpose_get(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 调用 NumPy 提供的函数 PyArray_MatrixTranspose 对矩阵进行转置操作
return PyArray_MatrixTranspose(self);
}
// 定义静态函数 `array_ptp`,处理 ptp 属性被移除的错误情况
static PyObject *
array_ptp(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 设置异常消息,说明 ptp 属性在 NumPy 2.0 版本中被移除,提供替代方法的建议
PyErr_SetString(PyExc_AttributeError,
"`ptp` was removed from the ndarray class in NumPy 2.0. "
"Use np.ptp(arr, ...) instead.");
// 返回空指针,表示操作失败
return NULL;
}
// 定义静态函数 `array_newbyteorder`,处理 newbyteorder 属性被移除的错误情况
static PyObject *
array_newbyteorder(PyArrayObject *self, PyObject *args)
{
// 设置异常消息,说明 newbyteorder 属性在 NumPy 2.0 版本中被移除,提供替代方法的建议
PyErr_SetString(PyExc_AttributeError,
"`newbyteorder` was removed from the ndarray class "
"in NumPy 2.0. "
"Use `arr.view(arr.dtype.newbyteorder(order))` instead.");
// 返回空指针,表示操作失败
return NULL;
}
// 定义静态函数 `array_itemset`,处理 itemset 属性被移除的错误情况
static PyObject *
array_itemset(PyArrayObject *self, PyObject *args)
{
// 设置异常消息,说明 itemset 属性在 NumPy 2.0 版本中被移除,提供替代方法的建议
PyErr_SetString(PyExc_AttributeError,
"`itemset` was removed from the ndarray class in "
"NumPy 2.0. Use `arr[index] = value` instead.");
// 返回空指针,表示操作失败
return NULL;
}
// 定义静态函数 `array_device`,返回数组对象所在的设备类型
static PyObject *
array_device(PyArrayObject *self, void *NPY_UNUSED(ignored))
{
// 返回一个 PyUnicode 对象,表示该数组位于 CPU 设备上
return PyUnicode_FromString("cpu");
}
{"itemset",
(getter)array_itemset,
NULL,
NULL, NULL},
# 创建一个元组项,包含属性名称 "itemset" 和一个函数指针 (getter)array_itemset,
# 后面的三个元素都设置为 NULL,表示没有其他特定的设置。
{"device",
(getter)array_device,
NULL,
NULL, NULL},
# 创建一个元组项,包含属性名称 "device" 和一个函数指针 (getter)array_device,
# 后面的三个元素都设置为 NULL,表示没有其他特定的设置。
{"__array_interface__",
(getter)array_interface_get,
NULL,
NULL, NULL},
# 创建一个元组项,包含属性名称 "__array_interface__" 和一个函数指针 (getter)array_interface_get,
# 后面的三个元素都设置为 NULL,表示没有其他特定的设置。
{"__array_struct__",
(getter)array_struct_get,
NULL,
NULL, NULL},
# 创建一个元组项,包含属性名称 "__array_struct__" 和一个函数指针 (getter)array_struct_get,
# 后面的三个元素都设置为 NULL,表示没有其他特定的设置。
{"__array_priority__",
(getter)array_priority_get,
NULL,
NULL, NULL},
# 创建一个元组项,包含属性名称 "__array_priority__" 和一个函数指针 (getter)array_priority_get,
# 后面的三个元素都设置为 NULL,表示没有其他特定的设置。
{NULL, NULL, NULL, NULL, NULL}, /* Sentinel */
# 创建一个 Sentinel(哨兵)项,所有的元素都为 NULL,用于标志元组的结束。
};
/****************** end of attribute get and set routines *******************/
注释:
};
这是一个注释行,表示以下代码段是 JavaScript 或类似语言中的对象或结构的结束。
/****************** end of attribute get and set routines *******************/
这是一个注释行,标记了上面代码段的结束,说明之前的代码段包含了属性的获取和设置相关的例程。
.\numpy\numpy\_core\src\multiarray\getset.h
// 定义了一个条件编译指令,用于防止重复包含本文件内容
// 如果 NUMPY_CORE_SRC_MULTIARRAY_GETSET_H_ 这个宏未定义,则执行下面的内容
extern NPY_NO_EXPORT PyGetSetDef array_getsetlist[];
.\numpy\numpy\_core\src\multiarray\hashdescr.c
/*
* Define NPY_NO_DEPRECATED_API to the current NPY_API_VERSION
* and _MULTIARRAYMODULE for compilation purposes.
*/
/*
* Clean PY_SSIZE_T_CLEAN to ensure Python.h defines Py_ssize_t.
* Include necessary headers for Python and NumPy.
*/
/*
* Include the configuration header for NumPy and hash descriptor header.
*/
/*
* How does this work ? The hash is computed from a list which contains all the
* information specific to a type. The hard work is to build the list
* (_array_descr_walk). The list is built as follows:
* * If the dtype is builtin (no fields, no subarray), then the list
* contains 6 items which uniquely define one dtype (_array_descr_builtin)
* * If the dtype is a compound array, one walk on each field. For each
* field, we append title, names, offset to the final list used for
* hashing, and then append the list recursively built for each
* corresponding dtype (_array_descr_walk_fields)
* * If the dtype is a subarray, one adds the shape tuple to the list, and
* then append the list recursively built for each corresponding dtype
* (_array_descr_walk_subarray)
*/
/*
* Static function declarations for internal use.
*/
static int _is_array_descr_builtin(PyArray_Descr* descr);
static int _array_descr_walk(PyArray_Descr* descr, PyObject *l);
static int _array_descr_walk_fields(PyObject *names, PyObject* fields, PyObject* l);
static int _array_descr_builtin(PyArray_Descr* descr, PyObject *l);
/*
* Normalize endian character: always return 'I', '<' or '>'
*/
static char _normalize_byteorder(char byteorder)
{
switch(byteorder) {
case '=':
if (PyArray_GetEndianness() == NPY_CPU_BIG) {
return '>';
}
else {
return '<';
}
default:
return byteorder;
}
}
/*
* Return true if descr is a builtin type
*/
static int _is_array_descr_builtin(PyArray_Descr* descr)
{
if (PyDataType_HASFIELDS(descr)) {
return 0;
}
if (PyDataType_HASSUBARRAY(descr)) {
return 0;
}
return 1;
}
/*
* Add to l all the items which uniquely define a builtin type
*/
static int _array_descr_builtin(PyArray_Descr* descr, PyObject *l)
{
Py_ssize_t i;
PyObject *t, *item;
char nbyteorder = _normalize_byteorder(descr->byteorder);
/*
* For builtin type, hash relies on : kind + byteorder + flags +
* type_num + elsize + alignment
*/
t = Py_BuildValue("(cccii)", descr->kind, nbyteorder,
descr->flags, descr->elsize, descr->alignment);
for(i = 0; i < PyTuple_Size(t); ++i) {
item = PyTuple_GetItem(t, i);
if (item == NULL) {
PyErr_SetString(PyExc_SystemError,
"(Hash) Error while computing builtin hash");
goto clean_t;
}
PyList_Append(l, item);
}
Py_DECREF(t);
return 0;
clean_t:
Py_DECREF(t);
return -1;
}
/*
* Walk inside the fields and add every item which will be used for hashing
* into the list l
*
* Return 0 on success
*/
/*
* 遍历数组描述字段,将字段名、描述符、偏移量添加到列表 l 中
*
* 如果 names 不是元组,则设置异常并返回 -1
*/
static int _array_descr_walk_fields(PyObject *names, PyObject* fields, PyObject* l)
{
PyObject *key, *value, *foffset, *fdescr, *ftitle;
Py_ssize_t pos = 0;
int st;
if (!PyTuple_Check(names)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) names is not a tuple ???");
return -1;
}
/*
* 如果 fields 不是字典,则设置异常并返回 -1
*/
if (!PyDict_Check(fields)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) fields is not a dict ???");
return -1;
}
for (pos = 0; pos < PyTuple_GET_SIZE(names); pos++) {
/*
* 对于每个字段,将键、描述符、偏移量添加到 l 中
*/
key = PyTuple_GET_ITEM(names, pos);
value = PyDict_GetItem(fields, key);
/* XXX: 这些检查是否必要? */
/*
* 如果值为空,则设置异常并返回 -1
*/
if (value == NULL) {
PyErr_SetString(PyExc_SystemError,
"(Hash) names and fields inconsistent ???");
return -1;
}
/*
* 如果键不是 Unicode 字符串,则设置异常并返回 -1
*/
if (!PyUnicode_Check(key)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) key of dtype dict not a string ???");
return -1;
}
/*
* 如果值不是元组,则设置异常并返回 -1
*/
if (!PyTuple_Check(value)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) value of dtype dict not a dtype ???");
return -1;
}
/*
* 如果元组中的项少于 2 个,则设置异常并返回 -1
*/
if (PyTuple_GET_SIZE(value) < 2) {
PyErr_SetString(PyExc_SystemError,
"(Hash) Less than 2 items in dtype dict ???");
return -1;
}
/*
* 将键添加到列表 l 中
*/
PyList_Append(l, key);
/*
* 获取元组中的描述符,检查是否为有效的描述符
*/
fdescr = PyTuple_GET_ITEM(value, 0);
if (!PyArray_DescrCheck(fdescr)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) First item in compound dtype tuple not a descr ???");
return -1;
}
else {
Py_INCREF(fdescr);
/*
* 递归调用 _array_descr_walk 处理描述符
*/
st = _array_descr_walk((PyArray_Descr*)fdescr, l);
Py_DECREF(fdescr);
if (st) {
return -1;
}
}
/*
* 获取元组中的偏移量,检查是否为整数类型
*/
foffset = PyTuple_GET_ITEM(value, 1);
if (!PyLong_Check(foffset)) {
PyErr_SetString(PyExc_SystemError,
"(Hash) Second item in compound dtype tuple not an int ???");
return -1;
}
else {
/*
* 将偏移量添加到列表 l 中
*/
PyList_Append(l, foffset);
}
/*
* 如果元组中的项数大于 2,则获取并添加第三项到列表 l 中
*/
if (PyTuple_GET_SIZE(value) > 2) {
ftitle = PyTuple_GET_ITEM(value, 2);
PyList_Append(l, ftitle);
}
}
return 0;
}
/*
* 遍历子数组描述,将其形状和描述符本身添加到列表 l 中
*
* 成功时返回 0
*/
static int _array_descr_walk_subarray(PyArray_ArrayDescr* adescr, PyObject *l)
{
PyObject *item;
Py_ssize_t i;
int st;
/*
* 将形状和描述符本身添加到要哈希的对象列表中
*/
// 检查 adescr 结构体中的 shape 成员是否为元组类型
if (PyTuple_Check(adescr->shape)) {
// 遍历元组中的每个元素
for(i = 0; i < PyTuple_Size(adescr->shape); ++i) {
// 获取元组中的第 i 个元素
item = PyTuple_GetItem(adescr->shape, i);
// 如果获取的元素为空,则设置异常并返回错误码 -1
if (item == NULL) {
PyErr_SetString(PyExc_SystemError,
"(Hash) Error while getting shape item of subarray dtype ???");
return -1;
}
// 将获取的元素添加到列表 l 中
PyList_Append(l, item);
}
}
// 如果 adescr 结构体中的 shape 成员是整数类型
else if (PyLong_Check(adescr->shape)) {
// 将整数类型的 shape 添加到列表 l 中
PyList_Append(l, adescr->shape);
}
// 如果 shape 不是元组也不是整数类型,则设置异常并返回错误码 -1
else {
PyErr_SetString(PyExc_SystemError,
"(Hash) Shape of subarray dtype neither a tuple or int ???");
return -1;
}
// 增加 adescr 结构体中 base 成员的引用计数
Py_INCREF(adescr->base);
// 递归调用 _array_descr_walk 函数,处理 adescr 结构体中 base 成员,将结果保存在列表 l 中
st = _array_descr_walk(adescr->base, l);
// 减少 adescr 结构体中 base 成员的引用计数
Py_DECREF(adescr->base);
// 返回 _array_descr_walk 函数的返回值
return st;
/*
* 'Root' function to walk into a dtype. May be called recursively
*/
static int _array_descr_walk(PyArray_Descr* descr, PyObject *l)
{
int st;
// 检查描述符是否是内置数组描述符,如果是则调用内置函数处理
if (_is_array_descr_builtin(descr)) {
return _array_descr_builtin(descr, l);
}
else {
// 将描述符转换为旧版数组描述符对象
_PyArray_LegacyDescr *ldescr = (_PyArray_LegacyDescr *)descr;
// 如果字段不为空,则遍历字段进行处理
if(ldescr->fields != NULL && ldescr->fields != Py_None) {
st = _array_descr_walk_fields(ldescr->names, ldescr->fields, l);
if (st) {
return -1;
}
}
// 如果存在子数组描述符,则递归处理子数组描述符
if(ldescr->subarray != NULL) {
st = _array_descr_walk_subarray(ldescr->subarray, l);
if (st) {
return -1;
}
}
}
return 0;
}
/*
* Return 0 if successful
*/
static int _PyArray_DescrHashImp(PyArray_Descr *descr, npy_hash_t *hash)
{
PyObject *l, *tl;
int st;
// 创建一个空列表对象
l = PyList_New(0);
if (l == NULL) {
return -1;
}
// 递归遍历描述符结构,将信息存入列表对象
st = _array_descr_walk(descr, l);
if (st) {
Py_DECREF(l);
return -1;
}
/*
* Convert the list to tuple and compute the tuple hash using python
* builtin function
*/
// 将列表对象转换为元组对象,并使用 Python 内置的哈希函数计算元组的哈希值
tl = PyList_AsTuple(l);
Py_DECREF(l);
if (tl == NULL)
return -1;
// 将计算得到的哈希值存入指定的变量
*hash = PyObject_Hash(tl);
Py_DECREF(tl);
if (*hash == -1) {
/* XXX: does PyObject_Hash set an exception on failure ? */
PyErr_SetString(PyExc_SystemError,
"(Hash) Error while hashing final tuple");
return -1;
}
return 0;
}
NPY_NO_EXPORT npy_hash_t
PyArray_DescrHash(PyObject* odescr)
{
PyArray_Descr *descr;
int st;
// 检查输入的对象是否为有效的数组描述符,否则报错
if (!PyArray_DescrCheck(odescr)) {
PyErr_SetString(PyExc_ValueError,
"PyArray_DescrHash argument must be a type descriptor");
return -1;
}
descr = (PyArray_Descr*)odescr;
// 如果描述符的哈希值为-1,表示尚未计算哈希值,需要计算
if (descr->hash == -1) {
// 调用内部函数计算描述符的哈希值
st = _PyArray_DescrHashImp(descr, &descr->hash);
if (st) {
return -1;
}
}
// 返回计算得到的哈希值
return descr->hash;
}
.\numpy\numpy\_core\src\multiarray\hashdescr.h
NPY_NO_EXPORT npy_hash_t
PyArray_DescrHash(PyObject* odescr);
.\numpy\numpy\_core\src\multiarray\item_selection.c
/*
* 定义 NPY_NO_DEPRECATED_API 以及 _MULTIARRAYMODULE 宏
*/
/*
* 定义 PY_SSIZE_T_CLEAN 宏,并包含必要的头文件
*/
/*
* 包含其他头文件,用于数组操作和类型处理
*/
/*
* 定义静态内联函数 npy_fasttake_impl
* 实现快速取值操作,支持多线程处理
*/
static NPY_GCC_OPT_3 inline int
npy_fasttake_impl(
char *dest, char *src, const npy_intp *indices,
npy_intp n, npy_intp m, npy_intp max_item,
npy_intp nelem, npy_intp chunk,
NPY_CLIPMODE clipmode, npy_intp itemsize, int needs_refcounting,
PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, int axis)
{
NPY_BEGIN_THREADS_DEF; /* 定义多线程开始 */
NPY_cast_info cast_info; /* 定义类型转换信息结构体 */
NPY_ARRAYMETHOD_FLAGS flags; /* 定义数组方法标志 */
NPY_cast_info_init(&cast_info); /* 初始化类型转换信息 */
if (!needs_refcounting) {
/* 如果不需要引用计数,直接使用 memcpy 进行简单的拷贝 */
NPY_BEGIN_THREADS; /* 开始多线程处理 */
}
else {
/* 如果需要引用计数 */
if (PyArray_GetDTypeTransferFunction(
1, itemsize, itemsize, src_dtype, dst_dtype, 0,
&cast_info, &flags) < 0) {
return -1; /* 获取数据类型转换函数失败,返回错误 */
}
if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
NPY_BEGIN_THREADS; /* 开始多线程处理 */
}
}
/* 多线程结束 */
NPY_END_THREADS;
NPY_cast_info_xfree(&cast_info); /* 释放类型转换信息结构体内存 */
return 0; /* 返回成功 */
fail:
/* 失败时,已经确保多线程结束 */
NPY_cast_info_xfree(&cast_info); /* 释放类型转换信息结构体内存 */
return -1; /* 返回失败 */
}
/*
* 辅助函数,实例化 npy_fasttake_impl 在不同分支中以优化每个特定的 itemsize
*/
static NPY_GCC_OPT_3 int
npy_fasttake(
char *dest, char *src, const npy_intp *indices,
npy_intp n, npy_intp m, npy_intp max_item,
npy_intp nelem, npy_intp chunk,
NPY_CLIPMODE clipmode, npy_intp itemsize, int needs_refcounting,
PyArray_Descr *src_dtype, PyArray_Descr *dst_dtype, int axis)
{
if (!needs_refcounting) {
if (chunk == 1) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
if (chunk == 2) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
if (chunk == 4) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
if (chunk == 8) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
if (chunk == 16) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
if (chunk == 32) {
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
}
return npy_fasttake_impl(
dest, src, indices, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_dtype,
dst_dtype, axis);
}
/*NUMPY_API
* Take
*/
/* 定义 PyArray_TakeFrom 函数,接受一个 NumPy 数组对象 self0,一个索引数组对象 indices0,
* 一个轴 axis,一个输出数组对象 out,以及一个剪切模式 clipmode */
NPY_NO_EXPORT PyObject *
PyArray_TakeFrom(PyArrayObject *self0, PyObject *indices0, int axis,
PyArrayObject *out, NPY_CLIPMODE clipmode)
{
PyArray_Descr *dtype; /* PyArray_TakeFrom 函数中用到的数组描述符 */
PyArrayObject *obj = NULL, *self, *indices; /* 定义 PyArrayObject 类型的指针变量 obj, self, indices */
npy_intp nd, i, n, m, max_item, chunk, itemsize, nelem; /* 定义 numpy 的整型数据类型变量 */
npy_intp shape[NPY_MAXDIMS]; /* 定义数组形状的数组 */
npy_bool needs_refcounting; /* 是否需要引用计数 */
indices = NULL; /* 将 indices 初始化为 NULL */
/* 将 self0 转换为 PyArrayObject 类型,并检查轴,返回只读的 C 数组类型的 self,若失败则返回 NULL */
self = (PyArrayObject *)PyArray_CheckAxis(self0, &axis,
NPY_ARRAY_CARRAY_RO);
if (self == NULL) {
return NULL; /* 如果 self 为 NULL,则直接返回 NULL */
}
/* 将 indices0 转换为 PyArrayObject 类型,要求数据类型为 NPY_INTP,若失败则跳转到 fail 标签处 */
indices = (PyArrayObject *)PyArray_FromAny(indices0,
PyArray_DescrFromType(NPY_INTP),
0, 0,
NPY_ARRAY_SAME_KIND_CASTING | NPY_ARRAY_DEFAULT,
NULL);
if (indices == NULL) {
goto fail; /* 如果 indices 为 NULL,则跳转到 fail 标签处 */
}
n = m = chunk = 1; /* 初始化 n, m, chunk 为 1 */
nd = PyArray_NDIM(self) + PyArray_NDIM(indices) - 1; /* 计算结果数组的维度 */
/* 遍历计算结果数组的形状 */
for (i = 0; i < nd; i++) {
if (i < axis) {
shape[i] = PyArray_DIMS(self)[i]; /* 如果 i 小于 axis,则取 self 对应维度的大小 */
n *= shape[i]; /* 计算 n */
}
else {
if (i < axis+PyArray_NDIM(indices)) {
shape[i] = PyArray_DIMS(indices)[i-axis]; /* 计算 indices 对应维度的大小 */
m *= shape[i]; /* 计算 m */
}
else {
shape[i] = PyArray_DIMS(self)[i-PyArray_NDIM(indices)+1]; /* 计算剩余维度的大小 */
chunk *= shape[i]; /* 计算 chunk */
}
}
}
/* 如果没有指定输出数组 out */
if (!out) {
dtype = PyArray_DESCR(self); /* 获取 self 的数据类型描述符 */
Py_INCREF(dtype); /* 增加数据类型描述符的引用计数 */
/* 使用给定的描述符创建新的数组对象 obj */
obj = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(self),
dtype,
nd, shape,
NULL, NULL, 0,
(PyObject *)self);
if (obj == NULL) {
goto fail; /* 如果创建 obj 失败,则跳转到 fail 标签处 */
}
}
else {
int flags = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY; /* 设置数组的标志 */
/* 检查输出数组 out 的维度是否与结果数组的维度相同,若不同则抛出 ValueError 异常 */
if ((PyArray_NDIM(out) != nd) ||
!PyArray_CompareLists(PyArray_DIMS(out), shape, nd)) {
PyErr_SetString(PyExc_ValueError,
"output array does not match result of ndarray.take");
goto fail; /* 如果维度不匹配,则跳转到 fail 标签处 */
}
/* 如果 out 和 self 有重叠的部分,则设置标志为 NPY_ARRAY_ENSURECOPY */
if (arrays_overlap(out, self)) {
flags |= NPY_ARRAY_ENSURECOPY;
}
/* 如果剪切模式为 NPY_RAISE,则需要确保获取副本 */
if (clipmode == NPY_RAISE) {
/*
* 我们需要确保获取一个副本
* 这样在调用错误之前不会改变输入数组
*/
flags |= NPY_ARRAY_ENSURECOPY;
}
dtype = PyArray_DESCR(self); /* 获取 self 的数据类型描述符 */
Py_INCREF(dtype); /* 增加数据类型描述符的引用计数 */
/* 使用给定的数组 out、描述符 dtype 和标志 flags 创建新的数组对象 obj */
obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags);
if (obj == NULL) {
goto fail; /* 如果创建 obj 失败,则跳转到 fail 标签处 */
}
}
max_item = PyArray_DIMS(self)[axis]; /* 获取 self 在轴 axis 上的大小 */
nelem = chunk; /* 设置 nelem 为 chunk */
itemsize = PyArray_ITEMSIZE(obj); /* 获取 obj 的单个元素大小 */
chunk = chunk * itemsize; /* 计算 chunk 的大小 */
char *src = PyArray_DATA(self); /* 获取 self 的数据指针 */
char *dest = PyArray_DATA(obj); /* 获取 obj 的数据指针 */
// 获取源数组的描述符
PyArray_Descr *src_descr = PyArray_DESCR(self);
// 获取目标数组的描述符
PyArray_Descr *dst_descr = PyArray_DESCR(obj);
// 检查是否需要引用计数,基于源数组的数据类型描述符
needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(self));
// 将索引数组的数据转换为 npy_intp 类型指针
npy_intp *indices_data = (npy_intp *)PyArray_DATA(indices);
// 如果 max_item 为 0 且目标数组不为空,则抛出索引错误
if ((max_item == 0) && (PyArray_SIZE(obj) != 0)) {
/* Index error, since that is the usual error for raise mode */
PyErr_SetString(PyExc_IndexError,
"cannot do a non-empty take from an empty axes.");
// 跳转到失败处理部分
goto fail;
}
// 调用 npy_fasttake 函数执行快速取值操作
if (npy_fasttake(
dest, src, indices_data, n, m, max_item, nelem, chunk,
clipmode, itemsize, needs_refcounting, src_descr, dst_descr,
axis) < 0) {
// 如果操作失败,跳转到失败处理部分
goto fail;
}
// 如果指定了输出数组且输出数组不等于原始数组
if (out != NULL && out != obj) {
// 尝试解析写回(writeback)操作,如果失败则跳转到失败处理部分
if (PyArray_ResolveWritebackIfCopy(obj) < 0) {
goto fail;
}
// 释放原始数组的引用
Py_DECREF(obj);
// 增加输出数组的引用
Py_INCREF(out);
// 更新 obj 指向输出数组
obj = out;
}
// 释放索引数组的引用
Py_XDECREF(indices);
// 释放 self 指向的数组的引用
Py_XDECREF(self);
// 返回 obj 对象(PyObject 类型的指针)
return (PyObject *)obj;
fail:
// 放弃写回操作(如果是复制的情况)
PyArray_DiscardWritebackIfCopy(obj);
// 释放 obj 对象的引用
Py_XDECREF(obj);
// 释放索引数组的引用
Py_XDECREF(indices);
// 释放 self 指向的数组的引用
Py_XDECREF(self);
// 返回 NULL 表示操作失败
return NULL;
}
/*NUMPY_API
* Put values into an array
*/
/* 将值放入数组中的函数定义,是 NumPy C API 中的一部分 */
NPY_NO_EXPORT PyObject *
PyArray_PutTo(PyArrayObject *self, PyObject* values0, PyObject *indices0,
NPY_CLIPMODE clipmode)
{
PyArrayObject *indices, *values;
npy_intp i, itemsize, ni, max_item, nv, tmp;
char *src, *dest;
int copied = 0;
int overlap = 0;
NPY_BEGIN_THREADS_DEF; /* 定义 NumPy 线程支持 */
NPY_cast_info cast_info; /* 定义类型转换信息结构体 */
NPY_ARRAYMETHOD_FLAGS flags; /* 定义数组方法的标志 */
NPY_cast_info_init(&cast_info); /* 初始化类型转换信息结构体 */
indices = NULL;
values = NULL;
if (!PyArray_Check(self)) {
PyErr_SetString(PyExc_TypeError,
"put: first argument must be an array");
return NULL;
}
/* 检查第一个参数是否为数组,若不是则设置类型错误并返回空指针 */
if (PyArray_FailUnlessWriteable(self, "put: output array") < 0) {
return NULL;
}
/* 确保输出数组可写,若不可写则返回空指针 */
indices = (PyArrayObject *)PyArray_ContiguousFromAny(indices0,
NPY_INTP, 0, 0);
/* 将 indices0 转换为连续的整型数组对象 */
if (indices == NULL) {
goto fail;
}
/* 若转换失败则跳转到 fail 标签处 */
ni = PyArray_SIZE(indices); /* 获取 indices 数组的大小 */
if ((ni > 0) && (PyArray_Size((PyObject *)self) == 0)) {
PyErr_SetString(PyExc_IndexError,
"cannot replace elements of an empty array");
goto fail;
}
/* 若 indices 非空且 self 是空数组,则设置索引错误并跳转到 fail 标签处 */
Py_INCREF(PyArray_DESCR(self)); /* 增加 self 数组的类型描述的引用计数 */
values = (PyArrayObject *)PyArray_FromAny(values0, PyArray_DESCR(self), 0, 0,
NPY_ARRAY_DEFAULT | NPY_ARRAY_FORCECAST, NULL);
/* 将 values0 转换为具有指定描述符的数组对象 */
if (values == NULL) {
goto fail;
}
/* 若转换失败则跳转到 fail 标签处 */
nv = PyArray_SIZE(values); /* 获取 values 数组的大小 */
if (nv <= 0) {
goto finish;
}
/* 若 values 为空则跳转到 finish 标签处 */
overlap = arrays_overlap(self, values) || arrays_overlap(self, indices);
/* 检查 self 与 values 或 indices 是否有重叠 */
if (overlap || !PyArray_ISCONTIGUOUS(self)) {
/* 若有重叠或 self 不是连续的数组 */
PyArrayObject *obj;
int flags = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY |
NPY_ARRAY_ENSURECOPY;
Py_INCREF(PyArray_DESCR(self)); /* 增加 self 数组的类型描述的引用计数 */
obj = (PyArrayObject *)PyArray_FromArray(self,
PyArray_DESCR(self), flags);
/* 根据给定数组创建新的数组对象 */
copied = 1;
assert(self != obj); /* 断言 self 与 obj 不相同 */
self = obj;
}
max_item = PyArray_SIZE(self); /* 获取 self 数组的大小 */
dest = PyArray_DATA(self); /* 获取 self 数组的数据指针 */
itemsize = PyArray_ITEMSIZE(self); /* 获取 self 数组的单个元素大小 */
int has_references = PyDataType_REFCHK(PyArray_DESCR(self));
/* 检查 self 数组的数据类型是否包含引用 */
if (!has_references) {
/* 若数据类型不包含引用,则直接使用 memcpy 进行简单复制 */
NPY_BEGIN_THREADS_THRESHOLDED(ni); /* 根据线程阈值启动线程 */
}
else {
PyArray_Descr *dtype = PyArray_DESCR(self);
if (PyArray_GetDTypeTransferFunction(
PyArray_ISALIGNED(self), itemsize, itemsize, dtype, dtype, 0,
&cast_info, &flags) < 0) {
goto fail;
}
/* 获取数据类型之间的传输函数信息 */
if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
NPY_BEGIN_THREADS_THRESHOLDED(ni); /* 根据线程阈值启动线程 */
}
}
```
// 如果有引用存在时执行以下代码块
if (has_references) {
// 定义常量 one 为 npy_intp 类型的 1
const npy_intp one = 1;
// 定义步长数组 strides,包含两个元素,每个元素都是 itemsize
const npy_intp strides[2] = {itemsize, itemsize};
// 根据 clipmode 的不同情况执行不同的处理
switch(clipmode) {
// 如果 clipmode 是 NPY_RAISE
case NPY_RAISE:
// 遍历索引数组中的元素
for (i = 0; i < ni; i++) {
// 计算 src 指针指向的位置
src = PyArray_BYTES(values) + itemsize*(i % nv);
// 获取索引数组中的值作为 tmp
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
// 检查并调整索引 tmp,如果失败则跳转至 fail 标签处
if (check_and_adjust_index(&tmp, max_item, 0, _save) < 0) {
goto fail;
}
// 定义包含两个指针的 data 数组
char *data[2] = {src, dest + tmp*itemsize};
// 调用 cast_info.func 进行类型转换,处理 data 数据
if (cast_info.func(
&cast_info.context, data, &one, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
}
break;
// 如果 clipmode 是 NPY_WRAP
case NPY_WRAP:
// 遍历索引数组中的元素
for (i = 0; i < ni; i++) {
// 计算 src 指针指向的位置
src = PyArray_BYTES(values) + itemsize * (i % nv);
// 获取索引数组中的值作为 tmp
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
// 如果 tmp 小于 0,则循环增加 tmp 直到大于等于 0
if (tmp < 0) {
while (tmp < 0) {
tmp += max_item;
}
}
// 如果 tmp 大于等于 max_item,则循环减小 tmp 直到小于 max_item
else if (tmp >= max_item) {
while (tmp >= max_item) {
tmp -= max_item;
}
}
// 定义包含两个指针的 data 数组
char *data[2] = {src, dest + tmp*itemsize};
// 调用 cast_info.func 进行类型转换,处理 data 数据
if (cast_info.func(
&cast_info.context, data, &one, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
}
break;
// 如果 clipmode 是 NPY_CLIP
case NPY_CLIP:
// 遍历索引数组中的元素
for (i = 0; i < ni; i++) {
// 计算 src 指针指向的位置
src = PyArray_BYTES(values) + itemsize * (i % nv);
// 获取索引数组中的值作为 tmp
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
// 如果 tmp 小于 0,则将 tmp 设为 0
if (tmp < 0) {
tmp = 0;
}
// 如果 tmp 大于等于 max_item,则将 tmp 设为 max_item - 1
else if (tmp >= max_item) {
tmp = max_item - 1;
}
// 定义包含两个指针的 data 数组
char *data[2] = {src, dest + tmp*itemsize};
// 调用 cast_info.func 进行类型转换,处理 data 数据
if (cast_info.func(
&cast_info.context, data, &one, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
goto fail;
}
}
break;
}
}
else {
switch(clipmode) {
case NPY_RAISE:
// 如果 clipmode 是 NPY_RAISE,进行严格模式处理
for (i = 0; i < ni; i++) {
// 计算源数据的位置
src = PyArray_BYTES(values) + itemsize * (i % nv);
// 获取索引值并进行调整
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
if (check_and_adjust_index(&tmp, max_item, 0, _save) < 0) {
// 若索引越界则跳转至失败处理标签
goto fail;
}
// 将数据从 src 复制到 dest 中的指定位置
memmove(dest + tmp * itemsize, src, itemsize);
}
break;
case NPY_WRAP:
// 如果 clipmode 是 NPY_WRAP,进行环绕模式处理
for (i = 0; i < ni; i++) {
// 计算源数据的位置
src = PyArray_BYTES(values) + itemsize * (i % nv);
// 获取索引值
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
if (tmp < 0) {
// 处理负索引,使其在范围内
while (tmp < 0) {
tmp += max_item;
}
}
else if (tmp >= max_item) {
// 处理超出最大索引,使其在范围内
while (tmp >= max_item) {
tmp -= max_item;
}
}
// 将数据从 src 复制到 dest 中的指定位置
memmove(dest + tmp * itemsize, src, itemsize);
}
break;
case NPY_CLIP:
// 如果 clipmode 是 NPY_CLIP,进行截断模式处理
for (i = 0; i < ni; i++) {
// 计算源数据的位置
src = PyArray_BYTES(values) + itemsize * (i % nv);
// 获取索引值并将其限制在合法范围内
tmp = ((npy_intp *)(PyArray_DATA(indices)))[i];
if (tmp < 0) {
tmp = 0;
}
else if (tmp >= max_item) {
tmp = max_item - 1;
}
// 将数据从 src 复制到 dest 中的指定位置
memmove(dest + tmp * itemsize, src, itemsize);
}
break;
}
}
// 结束多线程操作
NPY_END_THREADS;
finish:
// 释放类型转换信息资源
NPY_cast_info_xfree(&cast_info);
// 释放引用的对象
Py_XDECREF(values);
Py_XDECREF(indices);
if (copied) {
// 若有复制操作,解析写回(writeback)并释放数组对象
PyArray_ResolveWritebackIfCopy(self);
Py_DECREF(self);
}
// 返回 None
Py_RETURN_NONE;
fail:
// 失败处理,释放类型转换信息资源
NPY_cast_info_xfree(&cast_info);
// 释放引用的对象
Py_XDECREF(indices);
Py_XDECREF(values);
if (copied) {
// 若有复制操作,丢弃写回(writeback)并释放数组对象
PyArray_DiscardWritebackIfCopy(self);
Py_XDECREF(self);
}
// 返回 NULL 表示失败
return NULL;
/*NUMPY_API
* Put values into an array according to a mask.
*/
NPY_NO_EXPORT PyObject *
PyArray_PutMask(PyArrayObject *self, PyObject* values0, PyObject* mask0)
{
PyArrayObject *mask, *values;
PyArray_Descr *dtype;
npy_intp itemsize, ni, nv;
char *src, *dest;
npy_bool *mask_data;
int copied = 0;
int overlap = 0;
NPY_BEGIN_THREADS_DEF;
mask = NULL; // 初始化 mask 为 NULL
values = NULL; // 初始化 values 为 NULL
// 检查 self 是否为一个数组,如果不是,设置错误信息并返回 NULL
if (!PyArray_Check(self)) {
PyErr_SetString(PyExc_TypeError,
"putmask: first argument must "
"be an array");
return NULL;
}
// 检查 self 是否可写,如果不可写,返回 NULL
if (PyArray_FailUnlessWriteable(self, "putmask: output array") < 0) {
return NULL;
}
// 将 mask0 转换为 NPY_BOOL 类型的 PyArrayObject,要求 C 连续存储,并强制转换
mask = (PyArrayObject *)PyArray_FROM_OTF(mask0, NPY_BOOL,
NPY_ARRAY_CARRAY | NPY_ARRAY_FORCECAST);
// 如果转换失败,跳转到 fail 标签处理错误
if (mask == NULL) {
goto fail;
}
// 获取 mask 的元素个数 ni,与 self 的元素个数进行比较,如果不相等,设置错误信息并跳转到 fail 标签
ni = PyArray_SIZE(mask);
if (ni != PyArray_SIZE(self)) {
PyErr_SetString(PyExc_ValueError,
"putmask: mask and data must be "
"the same size");
goto fail;
}
// 获取 mask 数据的指针
mask_data = PyArray_DATA(mask);
// 获取 self 的数据类型描述符,并增加其引用计数
dtype = PyArray_DESCR(self);
Py_INCREF(dtype);
// 将 values0 转换为 PyArrayObject 类型的数组对象,使用给定的 dtype
// 如果转换失败,则跳转到错误处理标签 fail
values = (PyArrayObject *)PyArray_FromAny(values0, dtype,
0, 0, NPY_ARRAY_CARRAY, NULL);
if (values == NULL) {
goto fail;
}
// 获取数组 values 的元素个数 nv,如果数组为空,则 nv 为零
nv = PyArray_SIZE(values); /* 如果数组为空则为零 */
if (nv <= 0) {
// 如果 nv 小于等于 0,释放 values 和 mask 的引用,并返回 None
Py_XDECREF(values);
Py_XDECREF(mask);
Py_RETURN_NONE;
}
// 获取数组 values 的数据指针 src
src = PyArray_DATA(values);
// 检查 self 和 values 或 mask 是否有重叠,或者 self 是否非连续存储
overlap = arrays_overlap(self, values) || arrays_overlap(self, mask);
if (overlap || !PyArray_ISCONTIGUOUS(self)) {
// 如果有重叠或者 self 不是连续存储,则创建一个新的数组对象 obj
// 使用 flags 标志指定数组属性,保留原始数据或者创建数据的拷贝
int flags = NPY_ARRAY_CARRAY | NPY_ARRAY_WRITEBACKIFCOPY;
PyArrayObject *obj;
if (overlap) {
flags |= NPY_ARRAY_ENSURECOPY;
}
// 获取 self 的数据类型 dtype,并增加其引用计数
dtype = PyArray_DESCR(self);
Py_INCREF(dtype);
// 根据 self 的数据类型和 flags 创建一个新的数组对象 obj
obj = (PyArrayObject *)PyArray_FromArray(self, dtype, flags);
// 如果 obj 不等于 self,则表示进行了数据拷贝
if (obj != self) {
copied = 1;
}
// 将 self 指向新创建的数组对象 obj
self = obj;
}
// 获取 self 中每个元素的字节大小 itemsize 和数据指针 dest
itemsize = PyArray_ITEMSIZE(self);
dest = PyArray_DATA(self);
// 检查 self 的数据类型是否需要引用计数检查
if (PyDataType_REFCHK(PyArray_DESCR(self))) {
NPY_cast_info cast_info;
NPY_ARRAYMETHOD_FLAGS flags;
const npy_intp one = 1;
const npy_intp strides[2] = {itemsize, itemsize};
// 初始化类型转换信息 cast_info
NPY_cast_info_init(&cast_info);
// 获取数据类型之间的转换函数,如果失败则跳转到错误处理标签 fail
if (PyArray_GetDTypeTransferFunction(
PyArray_ISALIGNED(self), itemsize, itemsize, dtype, dtype, 0,
&cast_info, &flags) < 0) {
goto fail;
}
// 如果转换函数不需要 Python API,则启动线程
if (!(flags & NPY_METH_REQUIRES_PYAPI)) {
NPY_BEGIN_THREADS;
}
// 遍历 mask_data,根据 mask_data[i] 决定是否进行类型转换
for (npy_intp i = 0, j = 0; i < ni; i++, j++) {
// 如果 j 超过了 values 的长度 nv,则重置为 0
if (j >= nv) {
j = 0;
}
// 如果 mask_data[i] 为真,则进行类型转换
if (mask_data[i]) {
char *data[2] = {src + j*itemsize, dest + i*itemsize};
// 调用转换函数进行数据转换,如果失败则跳转到错误处理标签 fail
if (cast_info.func(
&cast_info.context, data, &one, strides,
cast_info.auxdata) < 0) {
NPY_END_THREADS;
NPY_cast_info_xfree(&cast_info);
goto fail;
}
}
}
// 释放类型转换信息 cast_info
NPY_cast_info_xfree(&cast_info);
}
else {
// 如果 self 的数据类型不需要引用计数检查,则启动线程进行快速的数据更新
NPY_BEGIN_THREADS;
npy_fastputmask(dest, src, mask_data, ni, nv, itemsize);
}
// 结束线程
NPY_END_THREADS;
// 释放 values 和 mask 的引用
Py_XDECREF(values);
Py_XDECREF(mask);
// 如果进行了数据拷贝,则解析写回数据并释放 self
if (copied) {
PyArray_ResolveWritebackIfCopy(self);
Py_DECREF(self);
}
// 返回 None
Py_RETURN_NONE;
fail:
// 错误处理,释放 mask 和 values 的引用,并根据是否拷贝了数据进行处理
Py_XDECREF(mask);
Py_XDECREF(values);
if (copied) {
PyArray_DiscardWritebackIfCopy(self);
Py_XDECREF(self);
}
// 返回 NULL 指示出现错误
return NULL;
}
static NPY_GCC_OPT_3 inline int
npy_fastrepeat_impl(
npy_intp n_outer, npy_intp n, npy_intp nel, npy_intp chunk,
npy_bool broadcast, npy_intp* counts, char* new_data, char* old_data,
npy_intp elsize, NPY_cast_info cast_info, int needs_refcounting)
{
// 外层循环,循环次数为 n_outer
npy_intp i, j, k;
for (i = 0; i < n_outer; i++) {
// 内层循环,循环次数为 n
for (j = 0; j < n; j++) {
// 计算重复次数 tmp,如果是广播则取 counts[0],否则取 counts[j]
npy_intp tmp = broadcast ? counts[0] : counts[j];
// 根据 tmp 执行数据复制操作的循环
for (k = 0; k < tmp; k++) {
// 如果不需要引用计数,直接使用 memcpy 复制数据块
if (!needs_refcounting) {
memcpy(new_data, old_data, chunk);
}
// 否则,执行类型转换并复制数据块
else {
char *data[2] = {old_data, new_data};
npy_intp strides[2] = {elsize, elsize};
// 调用类型转换函数
if (cast_info.func(&cast_info.context, data, &nel,
strides, cast_info.auxdata) < 0) {
return -1;
}
}
// 更新 new_data 的位置到下一个数据块的起始位置
new_data += chunk;
}
// 更新 old_data 的位置到下一个数据块的起始位置
old_data += chunk;
}
}
// 函数执行成功返回 0
return 0;
}
static NPY_GCC_OPT_3 int
npy_fastrepeat(
npy_intp n_outer, npy_intp n, npy_intp nel, npy_intp chunk,
npy_bool broadcast, npy_intp* counts, char* new_data, char* old_data,
npy_intp elsize, NPY_cast_info cast_info, int needs_refcounting)
{
// 如果不需要引用计数,直接调用 npy_fastrepeat_impl 函数
if (!needs_refcounting) {
// 根据不同的 chunk 大小调用 npy_fastrepeat_impl 函数
if (chunk == 1 || chunk == 2 || chunk == 4 || chunk == 8 ||
chunk == 16 || chunk == 32) {
return npy_fastrepeat_impl(
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data,
elsize, cast_info, needs_refcounting);
}
}
// 否则,无论 chunk 大小,都调用 npy_fastrepeat_impl 函数
return npy_fastrepeat_impl(
n_outer, n, nel, chunk, broadcast, counts, new_data, old_data, elsize,
cast_info, needs_refcounting);
}
/*NUMPY_API
* Repeat the array.
*/
NPY_NO_EXPORT PyObject *
PyArray_Repeat(PyArrayObject *aop, PyObject *op, int axis)
{
// 声明和初始化变量
npy_intp *counts;
npy_intp i, j, n, n_outer, chunk, elsize, nel;
npy_intp total = 0;
npy_bool broadcast = NPY_FALSE;
PyArrayObject *repeats = NULL;
PyObject *ap = NULL;
PyArrayObject *ret = NULL;
char *new_data, *old_data;
NPY_cast_info cast_info;
NPY_ARRAYMETHOD_FLAGS flags;
int needs_refcounting;
repeats = (PyArrayObject *)PyArray_ContiguousFromAny(op, NPY_INTP, 0, 1);
if (repeats == NULL) {
return NULL;
}
/*
* 标量和大小为1的'repeat'数组可以广播到任何形状,对于所有其他输入,维度必须完全匹配。
*/
if (PyArray_NDIM(repeats) == 0 || PyArray_SIZE(repeats) == 1) {
broadcast = NPY_TRUE;
}
counts = (npy_intp *)PyArray_DATA(repeats);
if ((ap = PyArray_CheckAxis(aop, &axis, NPY_ARRAY_CARRAY)) == NULL) {
Py_DECREF(repeats);
return NULL;
}
aop = (PyArrayObject *)ap;
n = PyArray_DIM(aop, axis);
NPY_cast_info_init(&cast_info);
needs_refcounting = PyDataType_REFCHK(PyArray_DESCR(aop));
if (!broadcast && PyArray_SIZE(repeats) != n) {
PyErr_Format(PyExc_ValueError,
"operands could not be broadcast together "
"with shape (%zd,) (%zd,)", n, PyArray_DIM(repeats, 0));
goto fail;
}
if (broadcast) {
total = counts[0] * n;
}
else {
for (j = 0; j < n; j++) {
if (counts[j] < 0) {
PyErr_SetString(PyExc_ValueError,
"repeats may not contain negative values.");
goto fail;
}
total += counts[j];
}
}
/* 构建新的数组 */
PyArray_DIMS(aop)[axis] = total;
Py_INCREF(PyArray_DESCR(aop));
ret = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(aop),
PyArray_DESCR(aop),
PyArray_NDIM(aop),
PyArray_DIMS(aop),
NULL, NULL, 0,
(PyObject *)aop);
PyArray_DIMS(aop)[axis] = n;
if (ret == NULL) {
goto fail;
}
new_data = PyArray_DATA(ret);
old_data = PyArray_DATA(aop);
nel = 1;
elsize = PyArray_ITEMSIZE(aop);
for(i = axis + 1; i < PyArray_NDIM(aop); i++) {
nel *= PyArray_DIMS(aop)[i];
}
chunk = nel*elsize;
n_outer = 1;
for (i = 0; i < axis; i++) {
n_outer *= PyArray_DIMS(aop)[i];
}
if (needs_refcounting) {
if (PyArray_GetDTypeTransferFunction(
1, elsize, elsize, PyArray_DESCR(aop), PyArray_DESCR(aop), 0,
&cast_info, &flags) < 0) {
goto fail;
}
}
if (npy_fastrepeat(n_outer, n, nel, chunk, broadcast, counts, new_data,
old_data, elsize, cast_info, needs_refcounting) < 0) {
goto fail;
}
Py_DECREF(repeats);
Py_XDECREF(aop);
NPY_cast_info_xfree(&cast_info);
return (PyObject *)ret;
fail:
Py_DECREF(repeats);
Py_XDECREF(aop);
Py_XDECREF(ret);
NPY_cast_info_xfree(&cast_info);
return NULL;
/*
* 转换所有输入为相同类型的数组
* 同时使它们变为 C 连续数组
*/
mps = PyArray_ConvertToCommonType(op, &n);
if (mps == NULL) {
return NULL;
}
for (i = 0; i < n; i++) {
if (mps[i] == NULL) {
goto fail;
}
}
ap = (PyArrayObject *)PyArray_FROM_OT((PyObject *)ip, NPY_INTP);
if (ap == NULL) {
goto fail;
}
/* 将所有数组广播到彼此,最后是索引数组 */
multi = (PyArrayMultiIterObject *)
PyArray_MultiIterFromObjects((PyObject **)mps, n, 1, ap);
if (multi == NULL) {
goto fail;
}
dtype = PyArray_DESCR(mps[0]);
/* 设置返回数组 */
if (out == NULL) {
Py_INCREF(dtype);
obj = (PyArrayObject *)PyArray_NewFromDescr(Py_TYPE(ap),
dtype,
multi->nd,
multi->dimensions,
NULL, NULL, 0,
(PyObject *)ap);
}
else {
int flags = NPY_ARRAY_CARRAY |
NPY_ARRAY_WRITEBACKIFCOPY |
NPY_ARRAY_FORCECAST;
if ((PyArray_NDIM(out) != multi->nd)
|| !PyArray_CompareLists(PyArray_DIMS(out),
multi->dimensions,
multi->nd)) {
PyErr_SetString(PyExc_TypeError,
"choose: invalid shape for output array.");
goto fail;
}
for (i = 0; i < n; i++) {
if (arrays_overlap(out, mps[i])) {
flags |= NPY_ARRAY_ENSURECOPY;
}
}
if (clipmode == NPY_RAISE) {
/*
* 需要确保并获取一个副本,
* 以便在调用错误之前不更改输入数组
*/
flags |= NPY_ARRAY_ENSURECOPY;
}
Py_INCREF(dtype);
obj = (PyArrayObject *)PyArray_FromArray(out, dtype, flags);
}
if (obj == NULL) {
goto fail;
}
elsize = dtype->elsize;
ret_data = PyArray_DATA(obj);
npy_intp transfer_strides[2] = {elsize, elsize};
npy_intp one = 1;
NPY_ARRAYMETHOD_FLAGS transfer_flags = 0;
// 检查数据类型是否需要引用计数,如果是则执行以下操作
if (PyDataType_REFCHK(dtype)) {
// 检查对象是否按无符号整数对齐
int is_aligned = IsUintAligned(obj);
// 获取数据类型转换函数,并设置转换的相关信息
PyArray_GetDTypeTransferFunction(
is_aligned,
dtype->elsize,
dtype->elsize,
dtype,
dtype, 0, &cast_info,
&transfer_flags);
}
// 循环遍历多重迭代器,直到迭代完成
while (PyArray_MultiIter_NOTDONE(multi)) {
// 获取当前迭代的索引
mi = *((npy_intp *)PyArray_MultiIter_DATA(multi, n));
// 如果索引超出了有效范围,则根据剪切模式进行处理
if (mi < 0 || mi >= n) {
switch(clipmode) {
case NPY_RAISE:
// 如果剪切模式为 NPY_RAISE,则引发值错误异常
PyErr_SetString(PyExc_ValueError,
"invalid entry in choice "\
"array");
// 跳转到失败处理标签
goto fail;
case NPY_WRAP:
// 如果剪切模式为 NPY_WRAP,则根据循环方式处理超出范围的索引
if (mi < 0) {
while (mi < 0) {
mi += n;
}
}
else {
while (mi >= n) {
mi -= n;
}
}
break;
case NPY_CLIP:
// 如果剪切模式为 NPY_CLIP,则将超出范围的索引调整到合法范围内
if (mi < 0) {
mi = 0;
}
else if (mi >= n) {
mi = n - 1;
}
break;
}
}
// 如果转换信息中的函数为空,则使用 memcpy 复制数据,因为内存不重叠
if (cast_info.func == NULL) {
/* We ensure memory doesn't overlap, so can use memcpy */
memcpy(ret_data, PyArray_MultiIter_DATA(multi, mi), elsize);
}
else {
// 否则,调用转换函数进行数据转换
char *args[2] = {PyArray_MultiIter_DATA(multi, mi), ret_data};
if (cast_info.func(&cast_info.context, args, &one,
transfer_strides, cast_info.auxdata) < 0) {
// 如果转换失败,则跳转到失败处理标签
goto fail;
}
}
// 更新返回数据的指针位置
ret_data += elsize;
// 更新多重迭代器的状态,移动到下一个迭代位置
PyArray_MultiIter_NEXT(multi);
}
// 释放转换信息所占用的内存
NPY_cast_info_xfree(&cast_info);
// 释放多重迭代器对象
Py_DECREF(multi);
// 释放所有的缓冲数组对象
for (i = 0; i < n; i++) {
Py_XDECREF(mps[i]);
}
// 释放输入数组对象
Py_DECREF(ap);
// 释放缓冲数组对象的内存
PyDataMem_FREE(mps);
// 如果有输出对象且不是输入对象的引用,则增加输出对象的引用计数
if (out != NULL && out != obj) {
Py_INCREF(out);
// 解析写回(如果有必要)
PyArray_ResolveWritebackIfCopy(obj);
// 释放输入对象的引用
Py_DECREF(obj);
// 将输出对象设置为当前操作对象
obj = out;
}
// 返回 Python 对象指针类型的对象
return (PyObject *)obj;
fail:
// 在失败的情况下,释放转换信息所占用的内存
NPY_cast_info_xfree(&cast_info);
// 减少多重迭代器对象的引用计数
Py_XDECREF(multi);
// 释放所有的缓冲数组对象
for (i = 0; i < n; i++) {
Py_XDECREF(mps[i]);
}
// 减少输入数组对象的引用计数
Py_XDECREF(ap);
// 释放缓冲数组对象的内存
PyDataMem_FREE(mps);
// 放弃写回(如果有必要)
PyArray_DiscardWritebackIfCopy(obj);
// 减少当前操作对象的引用计数
Py_XDECREF(obj);
// 返回空指针,表示操作失败
return NULL;
/*
* These algorithms use special sorting. They are not called unless the
* underlying sort function for the type is available. Note that axis is
* already valid. The sort functions require 1-d contiguous and well-behaved
* data. Therefore, a copy will be made of the data if needed before handing
* it to the sorting routine. An iterator is constructed and adjusted to walk
* over all but the desired sorting axis.
*/
static int
_new_sortlike(PyArrayObject *op, int axis, PyArray_SortFunc *sort,
PyArray_PartitionFunc *part, npy_intp const *kth, npy_intp nkth)
{
// 获取指定轴向的数组维度大小
npy_intp N = PyArray_DIM(op, axis);
// 获取数组元素的大小(字节数)
npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
// 获取指定轴向的数组步长
npy_intp astride = PyArray_STRIDE(op, axis);
// 检查数组是否被字节交换过
int swap = PyArray_ISBYTESWAPPED(op);
// 检查数组是否内存对齐
int is_aligned = IsAligned(op);
// 判断是否需要复制数据
int needcopy = !is_aligned || swap || astride != elsize;
// 检查数组描述符是否需要 Python API 支持
int needs_api = PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_PYAPI);
// 缓冲区指针初始化为空
char *buffer = NULL;
// 迭代器对象初始化为空
PyArrayIterObject *it;
// 迭代器对象大小初始化
npy_intp size;
// 返回值初始化为 0
int ret = 0;
// 获取数组描述符
PyArray_Descr *descr = PyArray_DESCR(op);
// 原数组描述符初始化为空
PyArray_Descr *odescr = NULL;
// 转换信息初始化
NPY_cast_info to_cast_info = {.func = NULL};
NPY_cast_info from_cast_info = {.func = NULL};
// 多线程操作开始
NPY_BEGIN_THREADS_DEF;
/* Check if there is any sorting to do */
// 检查是否需要进行排序操作
if (N <= 1 || PyArray_SIZE(op) == 0) {
return 0;
}
// 获取内存处理器句柄
PyObject *mem_handler = PyDataMem_GetHandler();
// 如果内存处理器句柄为空,返回错误
if (mem_handler == NULL) {
return -1;
}
// 构建除指定轴外的所有维度的迭代器
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
// 如果迭代器为空,释放内存处理器句柄并返回错误
if (it == NULL) {
Py_DECREF(mem_handler);
return -1;
}
// 获取迭代器的大小
size = it->size;
// 如果需要复制数据
if (needcopy) {
// 根据内存处理器分配缓冲区
buffer = PyDataMem_UserNEW(N * elsize, mem_handler);
// 如果分配缓冲区失败,设置返回值为 -1 并跳转到失败处理标签
if (buffer == NULL) {
ret = -1;
goto fail;
}
// 如果数组描述符标记需要初始化,将缓冲区初始化为零
if (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) {
memset(buffer, 0, N * elsize);
}
// 如果数组被交换过字节顺序,创建新的字节顺序描述符
if (swap) {
odescr = PyArray_DescrNewByteorder(descr, NPY_SWAP);
}
// 否则直接使用原数组描述符
else {
odescr = descr;
Py_INCREF(odescr);
}
// 获取数据类型转换函数和转换标志信息
NPY_ARRAYMETHOD_FLAGS to_transfer_flags;
if (PyArray_GetDTypeTransferFunction(
is_aligned, astride, elsize, descr, odescr, 0, &to_cast_info,
&to_transfer_flags) != NPY_SUCCEED) {
goto fail;
}
// 获取数据类型转换函数和转换标志信息(反向转换)
NPY_ARRAYMETHOD_FLAGS from_transfer_flags;
if (PyArray_GetDTypeTransferFunction(
is_aligned, elsize, astride, odescr, descr, 0, &from_cast_info,
&from_transfer_flags) != NPY_SUCCEED) {
goto fail;
}
}
// 多线程操作开始(使用数组描述符)
NPY_BEGIN_THREADS_DESCR(descr);
// 循环,每次迭代减小 size 的值
while (size--) {
// 获取当前迭代器指向的数据指针
char *bufptr = it->dataptr;
// 如果需要复制数据
if (needcopy) {
// 设置函数参数和步幅以进行类型转换
char *args[2] = {it->dataptr, buffer};
npy_intp strides[2] = {astride, elsize};
// 调用类型转换函数,如果返回值小于 0,则跳转到失败标签
if (NPY_UNLIKELY(to_cast_info.func(
&to_cast_info.context, args, &N, strides,
to_cast_info.auxdata) < 0)) {
goto fail;
}
// 更新 bufptr 指向 buffer
bufptr = buffer;
}
/*
* TODO: 如果输入数组进行了字节交换,但是是连续且对齐的,
* 可以在不复制到缓冲区的情况下直接进行交换(稍后再进行还原)。
* 在调用 sort 或 part 函数时,需要确保即使调用出错,仍然能够在返回之前还原交换。
*/
// 如果 part 函数为 NULL,则调用 sort 函数对 bufptr 所指向的数据进行排序
if (part == NULL) {
ret = sort(bufptr, N, op);
// 如果需要 API 并且发生了异常,则设置返回值为 -1
if (needs_api && PyErr_Occurred()) {
ret = -1;
}
// 如果 sort 函数返回值小于 0,则跳转到失败标签
if (ret < 0) {
goto fail;
}
}
// 否则,调用 part 函数对 bufptr 所指向的数据进行分区处理
else {
npy_intp pivots[NPY_MAX_PIVOT_STACK];
npy_intp npiv = 0;
npy_intp i;
// 对 kth 中的每个值调用 part 函数
for (i = 0; i < nkth; ++i) {
ret = part(bufptr, N, kth[i], pivots, &npiv, nkth, op);
// 如果需要 API 并且发生了异常,则设置返回值为 -1
if (needs_api && PyErr_Occurred()) {
ret = -1;
}
// 如果 part 函数返回值小于 0,则跳转到失败标签
if (ret < 0) {
goto fail;
}
}
}
// 如果需要复制数据,则进行逆类型转换
if (needcopy) {
// 设置函数参数和步幅以进行逆类型转换
char *args[2] = {buffer, it->dataptr};
npy_intp strides[2] = {elsize, astride};
// 调用逆类型转换函数,如果返回值小于 0,则跳转到失败标签
if (NPY_UNLIKELY(from_cast_info.func(
&from_cast_info.context, args, &N, strides,
from_cast_info.auxdata) < 0)) {
goto fail;
}
}
// 更新迭代器以处理下一个元素
PyArray_ITER_NEXT(it);
}
fail:
# 结束可能存在的线程
NPY_END_THREADS_DESCR(descr);
/* cleanup internal buffer */
# 如果需要拷贝数据,清理内部缓冲区
if (needcopy) {
PyArray_ClearBuffer(odescr, buffer, elsize, N, 1);
PyDataMem_UserFREE(buffer, N * elsize, mem_handler);
Py_DECREF(odescr);
}
# 如果返回值小于0且没有发生Python错误
if (ret < 0 && !PyErr_Occurred()) {
/* Out of memory during sorting or buffer creation */
# 在排序或缓冲区创建期间内存不足
PyErr_NoMemory();
}
// if an error happened with a dtype that doesn't hold the GIL, need
// to make sure we return an error value from this function.
// note: only the first error is ever reported, subsequent errors
// must *not* set the error handler.
// 如果使用了不持有GIL的数据类型,并且发生了错误,确保从该函数返回错误值。
// 注意:只有第一个错误会被报告,后续错误不应设置错误处理程序。
if (PyErr_Occurred() && ret == 0) {
ret = -1;
}
Py_DECREF(it);
Py_DECREF(mem_handler);
NPY_cast_info_xfree(&to_cast_info);
NPY_cast_info_xfree(&from_cast_info);
return ret;
}
static PyObject*
_new_argsortlike(PyArrayObject *op, int axis, PyArray_ArgSortFunc *argsort,
PyArray_ArgPartitionFunc *argpart,
npy_intp const *kth, npy_intp nkth)
{
npy_intp N = PyArray_DIM(op, axis);
npy_intp elsize = (npy_intp)PyArray_ITEMSIZE(op);
npy_intp astride = PyArray_STRIDE(op, axis);
int swap = PyArray_ISBYTESWAPPED(op);
int is_aligned = IsAligned(op);
int needcopy = !is_aligned || swap || astride != elsize;
int needs_api = PyDataType_FLAGCHK(PyArray_DESCR(op), NPY_NEEDS_PYAPI);
int needidxbuffer;
char *valbuffer = NULL;
npy_intp *idxbuffer = NULL;
PyArrayObject *rop;
npy_intp rstride;
PyArrayIterObject *it, *rit;
npy_intp size;
int ret = 0;
PyArray_Descr *descr = PyArray_DESCR(op);
PyArray_Descr *odescr = NULL;
NPY_ARRAYMETHOD_FLAGS transfer_flags;
NPY_cast_info cast_info = {.func = NULL};
NPY_BEGIN_THREADS_DEF;
PyObject *mem_handler = PyDataMem_GetHandler();
if (mem_handler == NULL) {
return NULL;
}
rop = (PyArrayObject *)PyArray_NewFromDescr(
Py_TYPE(op), PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(op), PyArray_DIMS(op), NULL, NULL,
0, (PyObject *)op);
if (rop == NULL) {
Py_DECREF(mem_handler);
return NULL;
}
rstride = PyArray_STRIDE(rop, axis);
needidxbuffer = rstride != sizeof(npy_intp);
/* Check if there is any argsorting to do */
if (N <= 1 || PyArray_SIZE(op) == 0) {
Py_DECREF(mem_handler);
memset(PyArray_DATA(rop), 0, PyArray_NBYTES(rop));
return (PyObject *)rop;
}
it = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)op, &axis);
rit = (PyArrayIterObject *)PyArray_IterAllButAxis((PyObject *)rop, &axis);
if (it == NULL || rit == NULL) {
ret = -1;
goto fail;
}
size = it->size;
// 如果需要进行复制操作
if (needcopy) {
// 使用自定义的内存分配函数分配空间给valbuffer
valbuffer = PyDataMem_UserNEW(N * elsize, mem_handler);
// 检查分配是否成功
if (valbuffer == NULL) {
ret = -1;
goto fail;
}
// 如果描述符需要初始化,使用memset将valbuffer清零
if (PyDataType_FLAGCHK(descr, NPY_NEEDS_INIT)) {
memset(valbuffer, 0, N * elsize);
}
// 如果需要进行字节顺序转换
if (swap) {
// 创建一个新的描述符,指定为需要交换字节顺序
odescr = PyArray_DescrNewByteorder(descr, NPY_SWAP);
}
else {
// 否则直接使用当前的描述符
odescr = descr;
Py_INCREF(odescr);
}
// 获取适合描述符转换的函数,初始化转换信息
if (PyArray_GetDTypeTransferFunction(
is_aligned, astride, elsize, descr, odescr, 0, &cast_info,
&transfer_flags) != NPY_SUCCEED) {
goto fail;
}
}
// 如果需要索引缓冲区
if (needidxbuffer) {
// 使用自定义的内存分配函数分配空间给idxbuffer
idxbuffer = (npy_intp *)PyDataMem_UserNEW(N * sizeof(npy_intp),
mem_handler);
// 检查分配是否成功
if (idxbuffer == NULL) {
ret = -1;
goto fail;
}
}
// 开始线程安全操作,使用给定的描述符
NPY_BEGIN_THREADS_DESCR(descr);
// 迭代处理数据的每一项
while (size--) {
// 获取当前迭代器指向的值指针和索引指针
char *valptr = it->dataptr;
npy_intp *idxptr = (npy_intp *)rit->dataptr;
npy_intp *iptr, i;
// 如果需要复制操作
if (needcopy) {
// 设置参数数组和步长数组,调用类型转换函数
char *args[2] = {it->dataptr, valbuffer};
npy_intp strides[2] = {astride, elsize};
// 如果转换失败,则跳转到错误处理步骤
if (NPY_UNLIKELY(cast_info.func(
&cast_info.context, args, &N, strides,
cast_info.auxdata) < 0)) {
goto fail;
}
// 更新值指针为valbuffer,表示使用复制后的数据
valptr = valbuffer;
}
// 如果需要索引缓冲区,更新索引指针
if (needidxbuffer) {
idxptr = idxbuffer;
}
// 初始化iptr为idxptr,然后为每个元素设置递增索引
iptr = idxptr;
for (i = 0; i < N; ++i) {
*iptr++ = i;
}
// 如果未提供argpart函数,调用argsort函数进行排序
if (argpart == NULL) {
ret = argsort(valptr, idxptr, N, op);
/* 在Python 3中,对象比较可能引发异常 */
if (needs_api && PyErr_Occurred()) {
ret = -1;
}
// 如果排序操作返回小于0的值,跳转到错误处理步骤
if (ret < 0) {
goto fail;
}
}
// 否则,使用argpart函数进行部分排序
else {
npy_intp pivots[NPY_MAX_PIVOT_STACK];
npy_intp npiv = 0;
for (i = 0; i < nkth; ++i) {
ret = argpart(valptr, idxptr, N, kth[i], pivots, &npiv, nkth, op);
/* 在Python 3中,对象比较可能引发异常 */
if (needs_api && PyErr_Occurred()) {
ret = -1;
}
// 如果排序操作返回小于0的值,跳转到错误处理步骤
if (ret < 0) {
goto fail;
}
}
}
// 如果需要索引缓冲区,将排序后的索引写回原始数据中
if (needidxbuffer) {
char *rptr = rit->dataptr;
iptr = idxbuffer;
for (i = 0; i < N; ++i) {
*(npy_intp *)rptr = *iptr++;
rptr += rstride;
}
}
// 更新迭代器,使其指向下一个元素
PyArray_ITER_NEXT(it);
PyArray_ITER_NEXT(rit);
}
fail:
// 调用描述符的结束线程函数,完成线程操作
NPY_END_THREADS_DESCR(descr);
/* 清理内部缓冲区 */
// 如果需要复制数据,则清理值缓冲区,并释放相关内存
if (needcopy) {
PyArray_ClearBuffer(odescr, valbuffer, elsize, N, 1);
PyDataMem_UserFREE(valbuffer, N * elsize, mem_handler);
Py_DECREF(odescr);
}
// 释放索引缓冲区的内存
PyDataMem_UserFREE(idxbuffer, N * sizeof(npy_intp), mem_handler);
// 如果返回值小于0,表示出现错误
if (ret < 0) {
// 如果没有设置异常,则设置内存不足的异常
if (!PyErr_Occurred()) {
/* 在排序或缓冲区创建过程中内存不足 */
PyErr_NoMemory();
}
// 释放结果对象的引用
Py_XDECREF(rop);
rop = NULL;
}
// 释放迭代器对象的引用
Py_XDECREF(it);
Py_XDECREF(rit);
// 释放内存处理器对象的引用
Py_DECREF(mem_handler);
// 释放类型转换信息的内存
NPY_cast_info_xfree(&cast_info);
// 返回排序后的结果对象
return (PyObject *)rop;
}
/*NUMPY_API
* 对数组进行原地排序
*/
NPY_NO_EXPORT int
PyArray_Sort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
PyArray_SortFunc *sort = NULL;
int n = PyArray_NDIM(op);
// 检查并调整轴的值
if (check_and_adjust_axis(&axis, n) < 0) {
return -1;
}
// 确保数组可写
if (PyArray_FailUnlessWriteable(op, "sort array") < 0) {
return -1;
}
// 检查排序类型的有效性
if (which < 0 || which >= NPY_NSORTS) {
PyErr_SetString(PyExc_ValueError, "not a valid sort kind");
return -1;
}
// 获取排序函数
sort = PyDataType_GetArrFuncs(PyArray_DESCR(op))->sort[which];
// 如果排序函数为空
if (sort == NULL) {
// 如果类型有比较函数,则根据排序类型选择默认的排序算法
if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) {
switch (which) {
default:
case NPY_QUICKSORT:
sort = npy_quicksort;
break;
case NPY_HEAPSORT:
sort = npy_heapsort;
break;
case NPY_STABLESORT:
sort = npy_timsort;
break;
}
}
else {
// 类型没有比较函数则设置类型错误异常
PyErr_SetString(PyExc_TypeError,
"type does not have compare function");
return -1;
}
}
// 调用新的排序函数进行排序
return _new_sortlike(op, axis, sort, NULL, NULL, 0);
}
/*
* 使第k个数组元素为正数,展平并排序
*/
static PyArrayObject *
partition_prep_kth_array(PyArrayObject * ktharray,
PyArrayObject * op,
int axis)
{
const npy_intp * shape = PyArray_SHAPE(op);
PyArrayObject * kthrvl;
npy_intp * kth;
npy_intp nkth, i;
// 如果ktharray是布尔类型则发出警告,并返回NULL
if (PyArray_ISBOOL(ktharray)) {
/* 2021-09-29, NumPy 1.22 */
if (DEPRECATE(
"Passing booleans as partition index is deprecated"
" (warning added in NumPy 1.22)") < 0) {
return NULL;
}
}
// 如果ktharray不是整数类型则设置类型错误并返回NULL
else if (!PyArray_ISINTEGER(ktharray)) {
PyErr_Format(PyExc_TypeError, "Partition index must be integer");
return NULL;
}
// 如果ktharray的维度大于1则设置值错误并返回NULL
if (PyArray_NDIM(ktharray) > 1) {
PyErr_Format(PyExc_ValueError, "kth array must have dimension <= 1");
return NULL;
}
// 将ktharray转换为整型数组
kthrvl = (PyArrayObject *)PyArray_Cast(ktharray, NPY_INTP);
if (kthrvl == NULL)
return NULL;
// 获取kth数组的数据指针和大小
kth = PyArray_DATA(kthrvl);
nkth = PyArray_SIZE(kthrvl);
for (i = 0; i < nkth; i++) {
if (kth[i] < 0) {
kth[i] += shape[axis];
}
if (PyArray_SIZE(op) != 0 &&
(kth[i] < 0 || kth[i] >= shape[axis])) {
PyErr_Format(PyExc_ValueError, "kth(=%zd) out of bounds (%zd)",
kth[i], shape[axis]);
Py_XDECREF(kthrvl);
return NULL;
}
}
/*
* 对 kthrvl 数组进行排序,以确保分区不会相互重叠
*/
if (PyArray_SIZE(kthrvl) > 1) {
PyArray_Sort(kthrvl, -1, NPY_QUICKSORT);
}
return kthrvl;
/*NUMPY_API
* Partition an array in-place
*/
NPY_NO_EXPORT int
PyArray_Partition(PyArrayObject *op, PyArrayObject * ktharray, int axis,
NPY_SELECTKIND which)
{
PyArrayObject *kthrvl; // 声明 kthrvl 变量,用于存储处理后的 ktharray 数组对象
PyArray_PartitionFunc *part; // 声明 part 变量,用于存储分区函数指针
PyArray_SortFunc *sort; // 声明 sort 变量,用于存储排序函数指针
int n = PyArray_NDIM(op); // 获取数组 op 的维度数,存入 n 中
int ret; // 声明 ret 变量,用于存储函数返回值
if (check_and_adjust_axis(&axis, n) < 0) { // 检查并调整轴的值,确保其有效性
return -1; // 如果检查失败,返回错误码
}
if (PyArray_FailUnlessWriteable(op, "partition array") < 0) { // 检查数组 op 是否可写
return -1; // 如果不可写,返回错误码
}
if (which < 0 || which >= NPY_NSELECTS) { // 检查 which 参数是否有效
PyErr_SetString(PyExc_ValueError, "not a valid partition kind"); // 设置错误信息
return -1; // 返回错误码
}
part = get_partition_func(PyArray_TYPE(op), which); // 获取分区函数,并存入 part 中
if (part == NULL) { // 如果获取失败
/* Use sorting, slower but equivalent */ // 使用排序代替分区,虽然更慢但功能相同
if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) { // 检查类型是否具有比较函数
sort = npy_quicksort; // 设置快速排序函数指针
}
else {
PyErr_SetString(PyExc_TypeError,
"type does not have compare function"); // 设置类型不具有比较函数的错误信息
return -1; // 返回错误码
}
}
/* Process ktharray even if using sorting to do bounds checking */
kthrvl = partition_prep_kth_array(ktharray, op, axis); // 准备 ktharray 数组的处理,进行边界检查
if (kthrvl == NULL) { // 如果处理失败
return -1; // 返回错误码
}
ret = _new_sortlike(op, axis, sort, part,
PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl)); // 执行排序或分区操作
Py_DECREF(kthrvl); // 释放 kthrvl 对象的引用计数
return ret; // 返回操作结果
}
/*NUMPY_API
* ArgSort an array
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgSort(PyArrayObject *op, int axis, NPY_SORTKIND which)
{
PyArrayObject *op2; // 声明 op2 变量,用于存储处理后的数组对象
PyArray_ArgSortFunc *argsort = NULL; // 声明 argsort 变量,用于存储排序函数指针
PyObject *ret; // 声明 ret 变量,用于存储函数返回值
argsort = PyDataType_GetArrFuncs(PyArray_DESCR(op))->argsort[which]; // 获取指定排序类型的函数指针
if (argsort == NULL) { // 如果未找到相应的排序函数
if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) { // 检查类型是否具有比较函数
switch (which) { // 根据排序类型选择排序函数
default:
case NPY_QUICKSORT:
argsort = npy_aquicksort; // 快速排序函数
break;
case NPY_HEAPSORT:
argsort = npy_aheapsort; // 堆排序函数
break;
case NPY_STABLESORT:
argsort = npy_atimsort; // 稳定排序函数
break;
}
}
else {
PyErr_SetString(PyExc_TypeError,
"type does not have compare function"); // 设置类型不具有比较函数的错误信息
return NULL; // 返回空指针表示错误
}
}
op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0); // 检查并返回处理后的数组对象 op2
if (op2 == NULL) { // 如果检查失败
return NULL; // 返回空指针表示错误
}
ret = _new_argsortlike(op2, axis, argsort, NULL, NULL, 0); // 执行类似于排序的操作
Py_DECREF(op2); // 释放 op2 对象的引用计数
return ret; // 返回操作结果
}
/*NUMPY_API
* ArgPartition an array
*/
NPY_NO_EXPORT PyObject *
PyArray_ArgPartition(PyArrayObject *op, PyArrayObject *ktharray, int axis,
NPY_SELECTKIND which)
{
PyArrayObject *op2, *kthrvl; // 声明 op2 和 kthrvl 变量,用于存储处理后的数组对象
PyArray_ArgPartitionFunc *argpart; // 声明 argpart 变量,用于存储分区函数指针
PyArray_ArgSortFunc *argsort; // 声明 argsort 变量,用于存储排序函数指针
PyObject *ret; // 声明 ret 变量,用于存储函数返回值
/*
* As a C-exported function, enum NPY_SELECTKIND loses its enum property
* Check the values to make sure they are in range
*/
// 作为 C 导出函数,枚举 NPY_SELECTKIND 失去其枚举属性
// 检查值以确保其在有效范围内
if ((int)which < 0 || (int)which >= NPY_NSELECTS) {
PyErr_SetString(PyExc_ValueError,
"not a valid partition kind");
return NULL;
}
argpart = get_argpartition_func(PyArray_TYPE(op), which);
if (argpart == NULL) {
/* 如果没有找到相应的分区函数,则使用排序(更慢但等效的方法) */
if (PyDataType_GetArrFuncs(PyArray_DESCR(op))->compare) {
argsort = npy_aquicksort;
}
else {
PyErr_SetString(PyExc_TypeError,
"type does not have compare function");
return NULL;
}
}
op2 = (PyArrayObject *)PyArray_CheckAxis(op, &axis, 0);
if (op2 == NULL) {
return NULL;
}
kthrvl = partition_prep_kth_array(ktharray, op2, axis);
if (kthrvl == NULL) {
Py_DECREF(op2);
return NULL;
}
ret = _new_argsortlike(op2, axis, argsort, argpart,
PyArray_DATA(kthrvl), PyArray_SIZE(kthrvl));
Py_DECREF(kthrvl);
Py_DECREF(op2);
return ret;
/*NUMPY_API
*LexSort an array providing indices that will sort a collection of arrays
*lexicographically. The first key is sorted on first, followed by the second key
*-- requires that arg"merge"sort is available for each sort_key
*
*Returns an index array that shows the indexes for the lexicographic sort along
*the given axis.
*/
NPY_NO_EXPORT PyObject *
PyArray_LexSort(PyObject *sort_keys, int axis)
{
PyArrayObject **mps; // 指向排序键的数组对象的指针数组
PyArrayIterObject **its; // 指向排序键的迭代器对象的指针数组
PyArrayObject *ret = NULL; // 返回的排序后的索引数组对象
PyArrayIterObject *rit = NULL; // 返回的排序后的索引数组的迭代器对象
npy_intp n, N, size, i, j; // 整数变量声明
npy_intp astride, rstride, *iptr; // 整数变量声明
int nd; // 数组的维数
int needcopy = 0; // 是否需要复制的标志位
int elsize; // 元素大小
int maxelsize; // 最大元素大小
int object = 0; // 是否包含对象数组的标志位
PyArray_ArgSortFunc *argsort; // 排序函数指针
NPY_BEGIN_THREADS_DEF; // 多线程宏定义的开始
if (!PySequence_Check(sort_keys) // 检查排序键是否是序列对象
|| ((n = PySequence_Size(sort_keys)) <= 0)) { // 获取排序键的长度并检查是否大于0
PyErr_SetString(PyExc_TypeError,
"need sequence of keys with len > 0 in lexsort"); // 设置错误信息并返回空指针
return NULL;
}
mps = (PyArrayObject **) PyArray_malloc(n * sizeof(PyArrayObject *)); // 分配排序键数组对象的指针数组内存
if (mps == NULL) { // 内存分配失败处理
return PyErr_NoMemory(); // 返回内存错误信息
}
its = (PyArrayIterObject **) PyArray_malloc(n * sizeof(PyArrayIterObject *)); // 分配排序键迭代器对象的指针数组内存
if (its == NULL) { // 内存分配失败处理
PyArray_free(mps); // 释放之前分配的mps内存
return PyErr_NoMemory(); // 返回内存错误信息
}
for (i = 0; i < n; i++) { // 遍历排序键数目
mps[i] = NULL; // 初始化每个排序键数组对象指针为空
its[i] = NULL; // 初始化每个排序键迭代器对象指针为空
}
for (i = 0; i < n; i++) { // 再次遍历排序键数目
PyObject *obj; // Python对象指针声明
obj = PySequence_GetItem(sort_keys, i); // 获取排序键序列中第i个对象
if (obj == NULL) { // 获取失败处理
goto fail; // 跳转到错误处理标签
}
mps[i] = (PyArrayObject *)PyArray_FROM_O(obj); // 从Python对象创建排序键的数组对象
Py_DECREF(obj); // 减少对象的引用计数
if (mps[i] == NULL) { // 创建失败处理
goto fail; // 跳转到错误处理标签
}
if (i > 0) { // 对于非第一个排序键
if ((PyArray_NDIM(mps[i]) != PyArray_NDIM(mps[0])) // 检查维度是否相同
|| (!PyArray_CompareLists(PyArray_DIMS(mps[i]), // 检查维度列表是否相同
PyArray_DIMS(mps[0]),
PyArray_NDIM(mps[0])))) {
PyErr_SetString(PyExc_ValueError,
"all keys need to be the same shape"); // 设置错误信息并返回空指针
goto fail; // 跳转到错误处理标签
}
}
if (!PyDataType_GetArrFuncs(PyArray_DESCR(mps[i]))->argsort[NPY_STABLESORT] // 检查是否支持稳定排序
&& !PyDataType_GetArrFuncs(PyArray_DESCR(mps[i]))->compare) { // 检查是否有比较函数
PyErr_Format(PyExc_TypeError,
"item %zd type does not have compare function", i); // 设置错误信息并返回空指针
goto fail; // 跳转到错误处理标签
}
if (!object // 如果不是对象数组
&& PyDataType_FLAGCHK(PyArray_DESCR(mps[i]), NPY_NEEDS_PYAPI)) { // 检查是否需要Python API
object = 1; // 设置对象标志位
}
}
/* Now we can check the axis */
nd = PyArray_NDIM(mps[0]); // 获取第一个排序键的维数
/*
* Special case letting axis={-1,0} slip through for scalars,
* for backwards compatibility reasons.
*/
if (nd == 0 && (axis == 0 || axis == -1)) {
/* TODO: can we deprecate this? */
}
else if (check_and_adjust_axis(&axis, nd) < 0) { // 检查并调整轴的有效性
goto fail; // 跳转到错误处理标签
}
if ((nd == 0) || (PyArray_SIZE(mps[0]) <= 1)) {
/* empty/single element case */
// 如果输入数组的维度为0或者第一个数组的元素数量小于等于1,则处理空数组或单元素情况
// 创建一个新的数组对象,用于返回结果
ret = (PyArrayObject *)PyArray_NewFromDescr(
&PyArray_Type, PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(mps[0]), PyArray_DIMS(mps[0]), NULL, NULL,
0, NULL);
// 检查数组对象是否创建成功
if (ret == NULL) {
goto fail;
}
// 如果第一个数组的元素数量大于0,将第一个元素设为0
if (PyArray_SIZE(mps[0]) > 0) {
*((npy_intp *)(PyArray_DATA(ret))) = 0;
}
// 跳转到完成处理的标签位置
goto finish;
}
// 为每个输入数组创建迭代器对象
for (i = 0; i < n; i++) {
its[i] = (PyArrayIterObject *)PyArray_IterAllButAxis(
(PyObject *)mps[i], &axis);
// 检查迭代器对象是否创建成功
if (its[i] == NULL) {
goto fail;
}
}
/* Now do the sorting */
// 创建一个新的整数数组对象用于排序结果
ret = (PyArrayObject *)PyArray_NewFromDescr(
&PyArray_Type, PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(mps[0]), PyArray_DIMS(mps[0]), NULL, NULL,
0, NULL);
if (ret == NULL) {
goto fail;
}
// 创建一个迭代器对象用于排序结果
rit = (PyArrayIterObject *)
PyArray_IterAllButAxis((PyObject *)ret, &axis);
if (rit == NULL) {
goto fail;
}
// 如果不是对象数组,则开始线程处理
if (!object) {
NPY_BEGIN_THREADS;
}
// 初始化变量
size = rit->size;
N = PyArray_DIMS(mps[0])[axis];
rstride = PyArray_STRIDE(ret, axis);
maxelsize = PyArray_ITEMSIZE(mps[0]);
needcopy = (rstride != sizeof(npy_intp));
// 检查是否需要复制数据
for (j = 0; j < n; j++) {
needcopy = needcopy
|| PyArray_ISBYTESWAPPED(mps[j])
|| !(PyArray_FLAGS(mps[j]) & NPY_ARRAY_ALIGNED)
|| (PyArray_STRIDES(mps[j])[axis] != (npy_intp)PyArray_ITEMSIZE(mps[j]));
// 更新最大元素大小
if (PyArray_ITEMSIZE(mps[j]) > maxelsize) {
maxelsize = PyArray_ITEMSIZE(mps[j]);
}
}
if (needcopy) {
// 如果需要进行复制操作,则进入此条件分支
char *valbuffer, *indbuffer;
int *swaps;
// 确保 N 大于 0,这是由 indbuffer 保证的前提条件
assert(N > 0);
// 计算要分配的 valbuffer 的大小,并确保至少为 1,避免空的分配
npy_intp valbufsize = N * maxelsize;
if (NPY_UNLIKELY(valbufsize) == 0) {
valbufsize = 1;
}
// 分配 valbuffer 内存空间
valbuffer = PyDataMem_NEW(valbufsize);
if (valbuffer == NULL) {
// 分配失败时跳转到 fail 标签处处理
goto fail;
}
// 分配 indbuffer 内存空间
indbuffer = PyDataMem_NEW(N * sizeof(npy_intp));
if (indbuffer == NULL) {
// 分配失败时释放之前分配的 valbuffer,并跳转到 fail 标签处处理
PyDataMem_FREE(valbuffer);
goto fail;
}
// 分配 swaps 数组的内存空间
swaps = malloc(NPY_LIKELY(n > 0) ? n * sizeof(int) : 1);
if (swaps == NULL) {
// 分配失败时释放之前分配的 valbuffer 和 indbuffer,并跳转到 fail 标签处处理
PyDataMem_FREE(valbuffer);
PyDataMem_FREE(indbuffer);
goto fail;
}
// 对 swaps 数组进行初始化,判断是否需要字节交换
for (j = 0; j < n; j++) {
swaps[j] = PyArray_ISBYTESWAPPED(mps[j]);
}
// 处理每一个元素
while (size--) {
iptr = (npy_intp *)indbuffer;
// 初始化 indbuffer 数组
for (i = 0; i < N; i++) {
*iptr++ = i;
}
// 对每一个数组进行排序操作
for (j = 0; j < n; j++) {
int rcode;
elsize = PyArray_ITEMSIZE(mps[j]);
astride = PyArray_STRIDES(mps[j])[axis];
// 获取排序函数,如果未找到则使用默认的排序函数 npy_atimsort
argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[NPY_STABLESORT];
if(argsort == NULL) {
argsort = npy_atimsort;
}
// 复制数据到 valbuffer,并进行可能的字节交换
_unaligned_strided_byte_copy(valbuffer, (npy_intp) elsize,
its[j]->dataptr, astride, N, elsize);
if (swaps[j]) {
_strided_byte_swap(valbuffer, (npy_intp) elsize, N, elsize);
}
// 调用排序函数进行排序
rcode = argsort(valbuffer, (npy_intp *)indbuffer, N, mps[j]);
if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
&& PyErr_Occurred())) {
// 排序失败时释放所有内存,并跳转到 fail 标签处处理
PyDataMem_FREE(valbuffer);
PyDataMem_FREE(indbuffer);
free(swaps);
goto fail;
}
// 移动到下一个数组元素
PyArray_ITER_NEXT(its[j]);
}
// 将排序后的索引数据复制到结果数组中
_unaligned_strided_byte_copy(rit->dataptr, rstride, indbuffer,
sizeof(npy_intp), N, sizeof(npy_intp));
// 移动到结果数组的下一个位置
PyArray_ITER_NEXT(rit);
}
// 完成所有操作后释放内存
PyDataMem_FREE(valbuffer);
PyDataMem_FREE(indbuffer);
free(swaps);
}
else {
while (size--) {
// 获取当前迭代器的数据指针,并将其转换为整数指针
iptr = (npy_intp *)rit->dataptr;
// 对当前迭代器的数据指针进行赋值操作,从0到N-1
for (i = 0; i < N; i++) {
*iptr++ = i;
}
// 遍历mps数组,对每个元素进行排序操作
for (j = 0; j < n; j++) {
int rcode;
// 获取排序函数,如果为NULL,则使用默认排序函数npy_atimsort
argsort = PyDataType_GetArrFuncs(PyArray_DESCR(mps[j]))->argsort[NPY_STABLESORT];
if(argsort == NULL) {
argsort = npy_atimsort;
}
// 调用排序函数进行排序
rcode = argsort(its[j]->dataptr, (npy_intp *)rit->dataptr, N, mps[j]);
// 检查排序操作是否成功,如果失败则跳转到fail标签处理
if (rcode < 0 || (PyDataType_REFCHK(PyArray_DESCR(mps[j]))
&& PyErr_Occurred())) {
goto fail;
}
// 移动到下一个迭代器
PyArray_ITER_NEXT(its[j]);
}
// 移动到下一个迭代器
PyArray_ITER_NEXT(rit);
}
}
// 如果object为假值,结束多线程状态
if (!object) {
NPY_END_THREADS;
}
finish:
// 释放mps和its数组的每个元素的引用计数
for (i = 0; i < n; i++) {
Py_XDECREF(mps[i]);
Py_XDECREF(its[i]);
}
// 释放rit迭代器的引用计数
Py_XDECREF(rit);
// 释放mps和its数组的内存
PyArray_free(mps);
PyArray_free(its);
// 返回ret对象
return (PyObject *)ret;
fail:
// 失败处理:结束多线程状态
NPY_END_THREADS;
// 如果没有设置错误状态,则设置内存分配失败的错误状态
if (!PyErr_Occurred()) {
/* Out of memory during sorting or buffer creation */
PyErr_NoMemory();
}
// 释放rit迭代器的引用计数
Py_XDECREF(rit);
// 释放ret对象的引用计数
Py_XDECREF(ret);
// 释放mps和its数组的每个元素的引用计数
for (i = 0; i < n; i++) {
Py_XDECREF(mps[i]);
Py_XDECREF(its[i]);
}
// 释放mps和its数组的内存
PyArray_free(mps);
PyArray_free(its);
// 返回空值
return NULL;
/*NUMPY_API
*
* Search the sorted array op1 for the location of the items in op2. The
* result is an array of indexes, one for each element in op2, such that if
* the item were to be inserted in op1 just before that index the array
* would still be in sorted order.
*
* Parameters
* ----------
* op1 : PyArrayObject *
* Array to be searched, must be 1-D.
* op2 : PyObject *
* Array of items whose insertion indexes in op1 are wanted
* side : {NPY_SEARCHLEFT, NPY_SEARCHRIGHT}
* If NPY_SEARCHLEFT, return first valid insertion indexes
* If NPY_SEARCHRIGHT, return last valid insertion indexes
* perm : PyObject *
* Permutation array that sorts op1 (optional)
*
* Returns
* -------
* ret : PyObject *
* New reference to npy_intp array containing indexes where items in op2
* could be validly inserted into op1. NULL on error.
*
* Notes
* -----
* Binary search is used to find the indexes.
*/
NPY_NO_EXPORT PyObject *
PyArray_SearchSorted(PyArrayObject *op1, PyObject *op2,
NPY_SEARCHSIDE side, PyObject *perm)
{
PyArrayObject *ap1 = NULL;
PyArrayObject *ap2 = NULL;
PyArrayObject *ap3 = NULL;
PyArrayObject *sorter = NULL;
PyArrayObject *ret = NULL;
PyArray_Descr *dtype;
int ap1_flags = NPY_ARRAY_NOTSWAPPED | NPY_ARRAY_ALIGNED;
PyArray_BinSearchFunc *binsearch = NULL;
PyArray_ArgBinSearchFunc *argbinsearch = NULL;
NPY_BEGIN_THREADS_DEF;
// 寻找与 op2 的共同类型
dtype = PyArray_DescrFromObject((PyObject *)op2, PyArray_DESCR(op1));
if (dtype == NULL) {
return NULL;
}
/* refs to dtype we own = 1 */
// 查找二分搜索函数
if (perm) {
argbinsearch = get_argbinsearch_func(dtype, side);
}
else {
binsearch = get_binsearch_func(dtype, side);
}
if (binsearch == NULL && argbinsearch == NULL) {
PyErr_SetString(PyExc_TypeError, "compare not supported for type");
/* refs to dtype we own = 1 */
Py_DECREF(dtype);
/* refs to dtype we own = 0 */
return NULL;
}
// 需要将 ap2 转换为连续数组并且是正确的类型
/* refs to dtype we own = 1 */
Py_INCREF(dtype);
/* refs to dtype we own = 2 */
ap2 = (PyArrayObject *)PyArray_CheckFromAny(op2, dtype,
0, 0,
NPY_ARRAY_CARRAY_RO | NPY_ARRAY_NOTSWAPPED,
NULL);
/* refs to dtype we own = 1, array creation steals one even on failure */
if (ap2 == NULL) {
Py_DECREF(dtype);
/* refs to dtype we own = 0 */
return NULL;
}
/*
* 如果要查找的元素 (ap2) 大于待查找的数组 (op1),我们将待查找数组复制到一个连续的数组以提高缓存利用率。
*/
if (PyArray_SIZE(ap2) > PyArray_SIZE(op1)) {
ap1_flags |= NPY_ARRAY_CARRAY_RO;
}
ap1 = (PyArrayObject *)PyArray_CheckFromAny((PyObject *)op1, dtype,
1, 1, ap1_flags, NULL);
/* 检查 op1 是否可以转换为指定的 NumPy 数组对象,要求一维数组,数据类型由 dtype 指定 */
/* 当 ap1 为 NULL 时,跳转到错误处理部分 */
if (ap1 == NULL) {
goto fail;
}
if (perm) {
/* 将 perm 转换为一维的、对齐的、未交换字节序的 NumPy 数组对象 */
ap3 = (PyArrayObject *)PyArray_CheckFromAny(perm, NULL,
1, 1,
NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED,
NULL);
/* 当 ap3 为 NULL 时,设置类型错误异常,并跳转到错误处理部分 */
if (ap3 == NULL) {
PyErr_SetString(PyExc_TypeError,
"could not parse sorter argument");
goto fail;
}
/* 当 ap3 不是整数类型的数组时,设置类型错误异常,并跳转到错误处理部分 */
if (!PyArray_ISINTEGER(ap3)) {
PyErr_SetString(PyExc_TypeError,
"sorter must only contain integers");
goto fail;
}
/* 将 ap3 转换为已知的整数类型数组 */
sorter = (PyArrayObject *)PyArray_FromArray(ap3,
PyArray_DescrFromType(NPY_INTP),
NPY_ARRAY_ALIGNED | NPY_ARRAY_NOTSWAPPED);
/* 当 sorter 为 NULL 时,设置数值错误异常,并跳转到错误处理部分 */
if (sorter == NULL) {
PyErr_SetString(PyExc_ValueError,
"could not parse sorter argument");
goto fail;
}
/* 检查 sorter 的大小是否与 ap1 的大小相等,不相等则设置数值错误异常,并跳转到错误处理部分 */
if (PyArray_SIZE(sorter) != PyArray_SIZE(ap1)) {
PyErr_SetString(PyExc_ValueError,
"sorter.size must equal a.size");
goto fail;
}
}
/* 创建一个整数类型的连续数组 ret,用于存储返回的索引 */
ret = (PyArrayObject *)PyArray_NewFromDescr(
&PyArray_Type, PyArray_DescrFromType(NPY_INTP),
PyArray_NDIM(ap2), PyArray_DIMS(ap2), NULL, NULL,
0, (PyObject *)ap2);
/* 当 ret 为 NULL 时,跳转到错误处理部分 */
if (ret == NULL) {
goto fail;
}
if (ap3 == NULL) {
/* 执行常规的二分查找 */
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
binsearch((const char *)PyArray_DATA(ap1),
(const char *)PyArray_DATA(ap2),
(char *)PyArray_DATA(ret),
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
PyArray_STRIDES(ap1)[0], PyArray_ITEMSIZE(ap2),
NPY_SIZEOF_INTP, ap2);
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
}
else {
/* 使用排序数组进行二分查找 */
// 定义错误变量
int error = 0;
// 开始线程安全操作,根据 ap2 的描述符
NPY_BEGIN_THREADS_DESCR(PyArray_DESCR(ap2));
// 调用二分查找函数 argbinsearch,处理 ap1、ap2、sorter 和 ret 的数据
error = argbinsearch((const char *)PyArray_DATA(ap1),
(const char *)PyArray_DATA(ap2),
(const char *)PyArray_DATA(sorter),
(char *)PyArray_DATA(ret),
PyArray_SIZE(ap1), PyArray_SIZE(ap2),
PyArray_STRIDES(ap1)[0],
PyArray_ITEMSIZE(ap2),
PyArray_STRIDES(sorter)[0], NPY_SIZEOF_INTP, ap2);
// 结束线程安全操作,根据 ap2 的描述符
NPY_END_THREADS_DESCR(PyArray_DESCR(ap2));
// 如果二分查找出错,设置异常并跳转到 fail 标签
if (error < 0) {
PyErr_SetString(PyExc_ValueError,
"Sorter index out of range.");
goto fail;
}
// 释放对象引用:ap3 和 sorter
Py_DECREF(ap3);
Py_DECREF(sorter);
}
// 释放对象引用:ap1 和 ap2
Py_DECREF(ap1);
Py_DECREF(ap2);
// 返回 ret 对象的 PyObject 指针
return (PyObject *)ret;
fail:
// 在发生失败时释放对象引用:ap1、ap2、ap3、sorter 和 ret
Py_XDECREF(ap1);
Py_XDECREF(ap2);
Py_XDECREF(ap3);
Py_XDECREF(sorter);
Py_XDECREF(ret);
// 返回 NULL 指针表示函数执行失败
return NULL;
/*NUMPY_API
* Diagonal
*
* In NumPy versions prior to 1.7, this function always returned a copy of
* the diagonal array. In 1.7, the code has been updated to compute a view
* onto 'self', but it still copies this array before returning, as well as
* setting the internal WARN_ON_WRITE flag. In a future version, it will
* simply return a view onto self.
*/
NPY_NO_EXPORT PyObject *
PyArray_Diagonal(PyArrayObject *self, int offset, int axis1, int axis2)
{
int i, idim, ndim = PyArray_NDIM(self);
npy_intp *strides;
npy_intp stride1, stride2, offset_stride;
npy_intp *shape, dim1, dim2;
char *data;
npy_intp diag_size;
PyArray_Descr *dtype;
PyObject *ret;
npy_intp ret_shape[NPY_MAXDIMS], ret_strides[NPY_MAXDIMS];
if (ndim < 2) {
PyErr_SetString(PyExc_ValueError,
"diag requires an array of at least two dimensions");
return NULL;
}
/* Handle negative axes with standard Python indexing rules */
if (check_and_adjust_axis_msg(&axis1, ndim, npy_interned_str.axis1) < 0) {
return NULL;
}
if (check_and_adjust_axis_msg(&axis2, ndim, npy_interned_str.axis2) < 0) {
return NULL;
}
if (axis1 == axis2) {
PyErr_SetString(PyExc_ValueError,
"axis1 and axis2 cannot be the same");
return NULL;
}
/* Get the shape and strides of the two axes */
shape = PyArray_SHAPE(self);
dim1 = shape[axis1];
dim2 = shape[axis2];
strides = PyArray_STRIDES(self);
stride1 = strides[axis1];
stride2 = strides[axis2];
/* Compute the data pointers and diag_size for the view */
data = PyArray_DATA(self);
if (offset >= 0) {
offset_stride = stride2;
dim2 -= offset;
}
else {
offset = -offset;
offset_stride = stride1;
dim1 -= offset;
}
diag_size = dim2 < dim1 ? dim2 : dim1;
if (diag_size < 0) {
diag_size = 0;
}
else {
data += offset * offset_stride;
}
/* Build the new shape and strides for the main data */
i = 0;
for (idim = 0; idim < ndim; ++idim) {
if (idim != axis1 && idim != axis2) {
ret_shape[i] = shape[idim];
ret_strides[i] = strides[idim];
++i;
}
}
ret_shape[ndim-2] = diag_size;
ret_strides[ndim-2] = stride1 + stride2;
/* Create the diagonal view */
dtype = PyArray_DTYPE(self);
Py_INCREF(dtype);
ret = PyArray_NewFromDescrAndBase(
Py_TYPE(self), dtype,
ndim-1, ret_shape, ret_strides, data,
PyArray_FLAGS(self), (PyObject *)self, (PyObject *)self);
if (ret == NULL) {
return NULL;
}
/*
* For numpy 1.9 the diagonal view is not writeable.
* This line needs to be removed in 1.10.
*/
PyArray_CLEARFLAGS((PyArrayObject *)ret, NPY_ARRAY_WRITEABLE);
return ret;
}
/* 压缩数组元素,将满足条件的元素压缩成一个数组,并返回结果 */
PyArray_Compress(PyArrayObject *self, PyObject *condition, int axis,
PyArrayObject *out)
{
PyArrayObject *cond; // 条件数组对象
PyObject *res, *ret; // 结果对象和返回对象
if (PyArray_Check(condition)) {
cond = (PyArrayObject *)condition; // 如果条件是数组,则直接使用
Py_INCREF(cond); // 增加条件数组的引用计数
}
else {
// 如果条件不是数组,则创建一个布尔类型的数组
PyArray_Descr *dtype = PyArray_DescrFromType(NPY_BOOL);
if (dtype == NULL) {
return NULL; // 如果创建描述符失败则返回空
}
cond = (PyArrayObject *)PyArray_FromAny(condition, dtype,
0, 0, 0, NULL); // 将条件转换为布尔类型数组
if (cond == NULL) {
return NULL; // 如果转换失败则返回空
}
}
if (PyArray_NDIM(cond) != 1) {
Py_DECREF(cond); // 如果条件数组维度不是1,则释放条件数组
PyErr_SetString(PyExc_ValueError,
"condition must be a 1-d array"); // 抛出值错误异常
return NULL; // 返回空
}
res = PyArray_Nonzero(cond); // 找出条件数组中非零元素的索引
Py_DECREF(cond); // 释放条件数组
if (res == NULL) {
return res; // 如果结果为空,则直接返回空
}
ret = PyArray_TakeFrom(self, PyTuple_GET_ITEM(res, 0), axis,
out, NPY_RAISE); // 从数组中按索引取出元素形成新数组
Py_DECREF(res); // 释放结果对象
return ret; // 返回结果对象
}
/*
* 计算 48 字节块中非零字节的数量
* w 必须按 8 字节对齐
*
* 即使它使用 64 位类型,它比 32 位平台上的逐字节求和更快
* 但是在这些平台上,使用 32 位类型版本将使其更快
*/
static inline npy_intp
count_nonzero_bytes_384(const npy_uint64 * w)
{
const npy_uint64 w1 = w[0];
const npy_uint64 w2 = w[1];
const npy_uint64 w3 = w[2];
const npy_uint64 w4 = w[3];
const npy_uint64 w5 = w[4];
const npy_uint64 w6 = w[5];
npy_intp r;
/*
* 最后部分的横向加法和 popcount,前三个二分可以跳过,因为我们正在处理字节。
* 乘法等同于 (x + (x>>8) + (x>>16) + (x>>24)) & 0xFF
* 无符号类型的乘法溢出在定义上是良好的。
* w1 + w2 确保不会溢出,因为数据只有 0 和 1。
*/
r = ((w1 + w2 + w3 + w4 + w5 + w6) * 0x0101010101010101ULL) >> 56ULL;
/*
* 字节不全为 0 或 1,则逐个求和。
* 只有在视图或外部缓冲区中做了奇怪的操作时才会发生。
* 在乐观计算之后执行此操作允许节省寄存器并实现更好的流水线处理。
*/
if (NPY_UNLIKELY(
((w1 | w2 | w3 | w4 | w5 | w6) & 0xFEFEFEFEFEFEFEFEULL) != 0)) {
/* 重新加载指针以避免与 gcc 的不必要的堆栈溢出 */
const char * c = (const char *)w;
npy_uintp i, count = 0;
for (i = 0; i < 48; i++) {
count += (c[i] != 0); // 统计非零字节的数量
}
return count; // 返回统计结果
}
return r; // 返回快速计算的结果
}
/* 计算 `*d` 和 `end` 之间的零字节数量,更新 `*d` 指向下一个要计算的位置 */
NPY_FINLINE NPY_GCC_OPT_3 npyv_u8
count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end, npy_uint8 max_count)
{
const npyv_u8 vone = npyv_setall_u8(1); // 创建所有元素为 1 的向量
const npyv_u8 vzero = npyv_zero_u8(); // 创建所有元素为 0 的向量
npy_intp lane_max = 0; // 最大车道数
npyv_u8 vsum8 = npyv_zero_u8();
while (*d < end && lane_max <= max_count - 1) {
npyv_u8 vt = npyv_cvt_u8_b8(npyv_cmpeq_u8(npyv_load_u8(*d), vzero));
vt = npyv_and_u8(vt, vone);
vsum8 = npyv_add_u8(vsum8, vt);
*d += npyv_nlanes_u8;
lane_max += 1;
}
return vsum8;
/*
* Counts the number of non-zero values in a raw array of unsigned 16-bit integers.
* Depending on SIMD availability, it uses vectorized operations for efficient counting.
*/
static inline NPY_GCC_OPT_3 npy_intp
count_nonzero_u16(const char *data, npy_intp bstride, npy_uintp len)
{
npy_intp count = 0;
// Check if SIMD (Single Instruction, Multiple Data) optimization is available
// If bstride is 1, perform SIMD operations for optimal counting
if (bstride == 1) {
npy_uintp len_m = len & -npyv_nlanes_u8;
npy_uintp zcount = 0;
// Process the data in chunks based on SIMD vectorization
for (const char *end = data + len_m; data < end;) {
// Count zero bytes using SIMD for 16-bit integers
npyv_u16x2 vsum16 = count_zero_bytes_u16((const npy_uint8**)&data, (const npy_uint8*)end, NPY_MAX_UINT16);
// Expand the 16-bit sums to 32-bit integers
npyv_u32x2 sum_32_0 = npyv_expand_u32_u16(vsum16.val[0]);
npyv_u32x2 sum_32_1 = npyv_expand_u32_u16(vsum16.val[1]);
// Sum the 32-bit values to get the count of non-zero elements
zcount += npyv_sum_u32(npyv_add_u32(
npyv_add_u32(sum_32_0.val[0], sum_32_0.val[1]),
npyv_add_u32(sum_32_1.val[0], sum_32_1.val[1])
));
}
// Adjust the remaining length after SIMD processing
len -= len_m;
// Calculate the total count of non-zero elements
count = len_m - zcount;
} else {
// If bstride is not 1 and SIMD is available, but the stride is not optimal,
// fall back to non-SIMD approach if alignment conditions are not met
if (!NPY_ALIGNMENT_REQUIRED || npy_is_aligned(data, sizeof(npy_uint64))) {
// Define a step size to process data in chunks of 6 * sizeof(npy_uint64)
int step = 6 * sizeof(npy_uint64);
// Calculate the remaining bytes to process after the aligned chunks
int left_bytes = len % step;
// Process the aligned chunks using a specialized function
for (const char *end = data + len; data < end - left_bytes; data += step) {
count += count_nonzero_bytes_384((const npy_uint64 *)data);
}
// Process the remaining bytes
len = left_bytes;
}
}
// If SIMD is not available, or bstride != 1 and alignment is required, use a fallback
// If SIMD is not available, or if bstride != 1 and alignment is required,
// fall back to a sequential non-SIMD approach
// Count non-zero elements sequentially for the remaining data
for (; len > 0; --len, data += bstride) {
count += (*data != 0);
}
// Return the total count of non-zero elements in the array
return count;
}
if (bstride == sizeof(npy_uint16)) {
npy_uintp zcount = 0, len_m = len & -npyv_nlanes_u16;
const npyv_u16 vone = npyv_setall_u16(1);
const npyv_u16 vzero = npyv_zero_u16();
for (npy_uintp lenx = len_m; lenx > 0;) {
npyv_u16 vsum16 = npyv_zero_u16();
npy_uintp max16 = PyArray_MIN(lenx, NPY_MAX_UINT16*npyv_nlanes_u16);
for (const char *end = data + max16*bstride; data < end; data += NPY_SIMD_WIDTH) {
npyv_u16 mask = npyv_cvt_u16_b16(npyv_cmpeq_u16(npyv_load_u16((npy_uint16*)data), vzero));
mask = npyv_and_u16(mask, vone);
vsum16 = npyv_add_u16(vsum16, mask);
}
lenx -= max16;
zcount += npyv_sumup_u16(vsum16);
}
len -= len_m;
count = len_m - zcount;
}
/*
* 在条件编译指令结束后,开始定义函数count_nonzero_u32,用于计算非零元素个数
*/
static inline NPY_GCC_OPT_3 npy_intp
count_nonzero_u32(const char *data, npy_intp bstride, npy_uintp len)
{
npy_intp count = 0;
// 如果步长等于4字节(即sizeof(npy_uint32)),则启用SIMD优化
if (bstride == sizeof(npy_uint32)) {
// 计算最大迭代次数
const npy_uintp max_iter = NPY_MAX_UINT32 * npyv_nlanes_u32;
// 计算实际处理的数据长度
const npy_uintp len_m = (len > max_iter ? max_iter : len) & -npyv_nlanes_u32;
// 创建SIMD向量,所有元素初始化为1和0
const npyv_u32 vone = npyv_setall_u32(1);
const npyv_u32 vzero = npyv_zero_u32();
npyv_u32 vsum32 = npyv_zero_u32();
// 使用SIMD进行循环,逐步处理数据
for (const char *end = data + len_m * bstride; data < end; data += NPY_SIMD_WIDTH) {
// 加载数据并进行比较,生成掩码
npyv_u32 mask = npyv_cvt_u32_b32(npyv_cmpeq_u32(npyv_load_u32((npy_uint32*)data), vzero));
// 掩码与全1向量进行与运算
mask = npyv_and_u32(mask, vone);
// 向量加法,累加掩码结果
vsum32 = npyv_add_u32(vsum32, mask);
}
// 对奇偶向量进行处理,计算总的非零元素个数
const npyv_u32 maskevn = npyv_reinterpret_u32_u64(npyv_setall_u64(0xffffffffULL));
npyv_u64 odd = npyv_shri_u64(npyv_reinterpret_u64_u32(vsum32), 32);
npyv_u64 even = npyv_reinterpret_u64_u32(npyv_and_u32(vsum32, maskevn));
count = len_m - npyv_sum_u64(npyv_add_u64(odd, even));
// 更新剩余长度
len -= len_m;
}
// 普通循环,处理剩余数据
for (; len > 0; --len, data += bstride) {
count += (*(npy_uint32*)data != 0);
}
// 返回非零元素个数
return count;
}
/*
* 在条件编译指令结束后,开始定义函数count_nonzero_u64,用于计算非零元素个数
*/
static inline NPY_GCC_OPT_3 npy_intp
count_nonzero_u64(const char *data, npy_intp bstride, npy_uintp len)
{
npy_intp count = 0;
// 如果步长等于8字节(即sizeof(npy_uint64)),则启用SIMD优化
if (bstride == sizeof(npy_uint64)) {
// 计算实际处理的数据长度
const npy_uintp len_m = len & -npyv_nlanes_u64;
// 创建SIMD向量,所有元素初始化为1和0
const npyv_u64 vone = npyv_setall_u64(1);
const npyv_u64 vzero = npyv_zero_u64();
npyv_u64 vsum64 = npyv_zero_u64();
// 使用SIMD进行循环,逐步处理数据
for (const char *end = data + len_m * bstride; data < end; data += NPY_SIMD_WIDTH) {
// 加载数据并进行比较,生成掩码
npyv_u64 mask = npyv_cvt_u64_b64(npyv_cmpeq_u64(npyv_load_u64((npy_uint64*)data), vzero));
// 掩码与全1向量进行与运算
mask = npyv_and_u64(mask, vone);
// 向量加法,累加掩码结果
vsum64 = npyv_add_u64(vsum64, mask);
}
// 计算总的非零元素个数
count = len_m - npyv_sum_u64(vsum64);
// 更新剩余长度
len -= len_m;
}
// 普通循环,处理剩余数据
for (; len > 0; --len, data += bstride) {
count += (*(npy_uint64*)data != 0);
}
// 返回非零元素个数
return count;
}
/*
* 在函数定义之前添加注释,描述该函数的功能和返回值含义
*/
static NPY_GCC_OPT_3 npy_intp
count_nonzero_int(int ndim, char *data, const npy_intp *ashape, const npy_intp *astrides, int elsize)
{
assert(elsize <= 8);
int idim;
npy_intp shape[NPY_MAXDIMS], strides[NPY_MAXDIMS];
npy_intp coord[NPY_MAXDIMS];
// 使用原始迭代处理,无堆内存分配
if (PyArray_PrepareOneRawArrayIter(
ndim, ashape,
data, astrides,
&ndim, shape,
&data, strides) < 0) {
return -1;
}
// 处理长度为零的数组情况,如果数组第一个维度的长度为零,则直接返回计数为零
if (shape[0] == 0) {
return 0;
}
// 开始多线程操作的宏定义,根据条件决定是否开启多线程
NPY_BEGIN_THREADS_DEF;
NPY_BEGIN_THREADS_THRESHOLDED(shape[0]);
// 定义宏 NONZERO_CASE,根据元素大小不同进行不同的非零计数操作
case LEN: \
// 使用原始迭代器开始迭代,遍历数组中的元素,计算非零元素的个数
NPY_RAW_ITER_START(idim, ndim, coord, shape) { \
count += count_nonzero_
} NPY_RAW_ITER_ONE_NEXT(idim, ndim, coord, shape, data, strides); \
// 每个元素大小情况下的操作结束
break
// 初始化计数器 count
npy_intp count = 0;
// 根据元素大小 elsize 的不同,选择不同的 NONZERO_CASE 宏处理
switch(elsize) {
NONZERO_CASE(1, u8);
NONZERO_CASE(2, u16);
NONZERO_CASE(4, u32);
NONZERO_CASE(8, u64);
}
// 取消 NONZERO_CASE 宏的定义
// 结束多线程操作
NPY_END_THREADS;
// 返回计数值
return count;
/*
* Counts the number of True values in a raw boolean array. This
* is a low-overhead function which does no heap allocations.
*
* Returns -1 on error.
*/
NPY_NO_EXPORT NPY_GCC_OPT_3 npy_intp
count_boolean_trues(int ndim, char *data, npy_intp const *ashape, npy_intp const *astrides)
{
// 使用 count_nonzero_int 函数计算布尔数组中的 True 值数量
return count_nonzero_int(ndim, data, ashape, astrides, 1);
}
/*NUMPY_API
* Counts the number of non-zero elements in the array.
*
* Returns -1 on error.
*/
NPY_NO_EXPORT npy_intp
PyArray_CountNonzero(PyArrayObject *self)
{
PyArray_NonzeroFunc *nonzero;
char *data;
npy_intp stride, count;
npy_intp nonzero_count = 0;
int needs_api = 0;
PyArray_Descr *dtype;
// 获取数组的数据类型描述符
dtype = PyArray_DESCR(self);
/* Special low-overhead version specific to the boolean/int types */
// 如果数组对齐且是布尔或整数类型,则调用特定的低开销版本
if (PyArray_ISALIGNED(self) && (
PyDataType_ISBOOL(dtype) || PyDataType_ISINTEGER(dtype))) {
return count_nonzero_int(
PyArray_NDIM(self), PyArray_BYTES(self), PyArray_DIMS(self),
PyArray_STRIDES(self), dtype->elsize
);
}
// 获取非零元素计数的函数指针
nonzero = PyDataType_GetArrFuncs(PyArray_DESCR(self))->nonzero;
/* If it's a trivial one-dimensional loop, don't use an iterator */
// 如果是简单的一维循环,则不使用迭代器
if (PyArray_TRIVIALLY_ITERABLE(self)) {
// 检查是否需要 Python API 支持
needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI);
// 准备简单迭代器
PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride);
// 根据需要使用多线程处理
if (needs_api) {
while (count--) {
// 调用非零元素判定函数,统计非零元素数量
if (nonzero(data, self)) {
++nonzero_count;
}
// 检查是否有 Python 异常发生
if (PyErr_Occurred()) {
return -1;
}
// 移动数据指针到下一个元素
data += stride;
}
} else {
// 多线程处理非零元素判定
NPY_BEGIN_THREADS_THRESHOLDED(count);
while (count--) {
if (nonzero(data, self)) {
++nonzero_count;
}
data += stride;
}
NPY_END_THREADS;
}
// 返回统计到的非零元素数量
return nonzero_count;
}
/*
* If the array has size zero, return zero (the iterator rejects
* size zero arrays)
*/
// 如果数组大小为零,则直接返回零
if (PyArray_SIZE(self) == 0) {
return 0;
}
/*
* Otherwise create and use an iterator to count the nonzeros.
*/
// 创建迭代器来统计非零元素数量
iter = NpyIter_New(self, NPY_ITER_READONLY |
NPY_ITER_EXTERNAL_LOOP |
NPY_ITER_REFS_OK,
NPY_KEEPORDER, NPY_NO_CASTING,
NULL);
if (iter == NULL) {
return -1;
}
// 检查是否需要 Python API 支持
needs_api = NpyIter_IterationNeedsAPI(iter);
// 获取迭代器的下一步函数指针
iternext = NpyIter_GetIterNext(iter, NULL);
if (iternext == NULL) {
NpyIter_Deallocate(iter);
return -1;
}
// 开始多线程处理迭代器
NPY_BEGIN_THREADS_NDITER(iter);
dataptr = NpyIter_GetDataPtrArray(iter);
strideptr = NpyIter_GetInnerStrideArray(iter);
innersizeptr = NpyIter_GetInnerLoopSizePtr(iter);
/* 遍历所有元素以计算非零元素数量 */
do {
data = *dataptr;
stride = *strideptr;
count = *innersizeptr;
while (count--) {
if (nonzero(data, self)) {
++nonzero_count;
}
if (needs_api && PyErr_Occurred()) {
nonzero_count = -1;
goto finish;
}
data += stride;
}
} while(iternext(iter));
finish:
NPY_END_THREADS; // 结束线程
NpyIter_Deallocate(iter); // 释放迭代器资源
return nonzero_count; // 返回非零元素的数量
}
/*NUMPY_API
* Nonzero
*
* TODO: In NumPy 2.0, should make the iteration order a parameter.
*/
NPY_NO_EXPORT PyObject *
PyArray_Nonzero(PyArrayObject *self)
{
int i, ndim = PyArray_NDIM(self); // 获取数组的维度数
// 检查数组是否为零维,如果是,设置错误信息并返回 NULL
if (ndim == 0) {
char const* msg;
if (PyArray_ISBOOL(self)) {
msg =
"Calling nonzero on 0d arrays is not allowed. "
"Use np.atleast_1d(scalar).nonzero() instead. "
"If the context of this error is of the form "
"`arr[nonzero(cond)]`, just use `arr[cond]`.";
} else {
msg =
"Calling nonzero on 0d arrays is not allowed. "
"Use np.atleast_1d(scalar).nonzero() instead.";
}
PyErr_SetString(PyExc_ValueError, msg); // 设置错误类型和消息
return NULL; // 返回 NULL 表示错误
}
PyArrayObject *ret = NULL; // 结果数组对象指针
PyObject *ret_tuple; // 结果元组对象
npy_intp ret_dims[2]; // 结果数组的维度
PyArray_NonzeroFunc *nonzero; // 非零元素查找函数指针
PyArray_Descr *dtype; // 数组的数据类型描述符
npy_intp nonzero_count; // 非零元素的数量
npy_intp added_count = 0; // 已添加的计数
int needs_api; // 是否需要 Python API
int is_bool; // 数组是否为布尔类型
NpyIter *iter; // 迭代器对象指针
NpyIter_IterNextFunc *iternext; // 迭代器下一步函数指针
NpyIter_GetMultiIndexFunc *get_multi_index; // 获取多重索引函数指针
char **dataptr; // 数据指针数组
dtype = PyArray_DESCR(self); // 获取数组的数据类型描述符
nonzero = PyDataType_GetArrFuncs(dtype)->nonzero; // 获取非零元素查找函数
needs_api = PyDataType_FLAGCHK(dtype, NPY_NEEDS_PYAPI); // 检查是否需要 Python API
/*
* First count the number of non-zeros in 'self'.
*/
nonzero_count = PyArray_CountNonzero(self); // 计算数组中非零元素的数量
if (nonzero_count < 0) {
return NULL; // 如果计算出错,返回 NULL
}
is_bool = PyArray_ISBOOL(self); // 检查数组是否为布尔类型
/* Allocate the result as a 2D array */
ret_dims[0] = nonzero_count; // 第一维度是非零元素的数量
ret_dims[1] = ndim; // 第二维度是数组的维度数
ret = (PyArrayObject *)PyArray_NewFromDescr(
&PyArray_Type, PyArray_DescrFromType(NPY_INTP),
2, ret_dims, NULL, NULL,
0, NULL); // 从描述符创建一个新的数组对象作为结果
if (ret == NULL) {
return NULL; // 如果创建失败,返回 NULL
}
/* If it's a one-dimensional result, don't use an iterator */
if (ndim == 1) {
npy_intp * multi_index = (npy_intp *)PyArray_DATA(ret);
char * data = PyArray_BYTES(self);
npy_intp stride = PyArray_STRIDE(self, 0);
npy_intp count = PyArray_DIM(self, 0);
NPY_BEGIN_THREADS_DEF;
/* 无需处理 */
if (nonzero_count == 0) {
goto finish;
}
if (!needs_api) {
NPY_BEGIN_THREADS_THRESHOLDED(count);
}
/* 针对布尔类型避免函数调用 */
if (is_bool) {
/*
* 对于稀疏数据,使用快速的 memchr 变体,参见 gh-4370
* 在这种稀疏路径后的快速布尔计数比结合两个循环更快,即使对于更大的数组也是如此
*/
if (((double)nonzero_count / count) <= 0.1) {
npy_intp subsize;
npy_intp j = 0;
while (1) {
npy_memchr(data + j * stride, 0, stride, count - j,
&subsize, 1);
j += subsize;
if (j >= count) {
break;
}
*multi_index++ = j++;
}
}
/*
* 为了避免分支预测错误导致的性能下降,在这里使用无分支策略
*/
else {
npy_intp *multi_index_end = multi_index + nonzero_count;
npy_intp j = 0;
/* 手动展开循环以便于 GCC 和可能的其他编译器 */
while (multi_index + 4 < multi_index_end) {
*multi_index = j;
multi_index += data[0] != 0;
*multi_index = j + 1;
multi_index += data[stride] != 0;
*multi_index = j + 2;
multi_index += data[stride * 2] != 0;
*multi_index = j + 3;
multi_index += data[stride * 3] != 0;
data += stride * 4;
j += 4;
}
while (multi_index < multi_index_end) {
*multi_index = j;
multi_index += *data != 0;
data += stride;
++j;
}
}
}
else {
npy_intp j;
for (j = 0; j < count; ++j) {
if (nonzero(data, self)) {
if (++added_count > nonzero_count) {
break;
}
*multi_index++ = j;
}
if (needs_api && PyErr_Occurred()) {
break;
}
data += stride;
}
}
NPY_END_THREADS;
goto finish;
}
/*
* 以 C 顺序构建一个迭代器来跟踪多维索引。
*/
iter = NpyIter_New(self, NPY_ITER_READONLY |
NPY_ITER_MULTI_INDEX |
NPY_ITER_ZEROSIZE_OK |
NPY_ITER_REFS_OK,
NPY_CORDER, NPY_NO_CASTING,
NULL);
创建一个 Numpy 迭代器对象 `iter`,用于遍历数组元素。
if (iter == NULL) {
Py_DECREF(ret);
return NULL;
}
检查迭代器是否成功创建,如果创建失败,则释放之前分配的资源并返回空指针。
if (NpyIter_GetIterSize(iter) != 0) {
检查迭代器中的元素数量是否不为零,即数组不为空。
npy_intp * multi_index;
NPY_BEGIN_THREADS_DEF;
声明多索引数组 `multi_index` 和 Numpy 线程宏 `NPY_BEGIN_THREADS_DEF`。
/* Get the pointers for inner loop iteration */
iternext = NpyIter_GetIterNext(iter, NULL);
获取迭代器的下一个迭代函数 `iternext`,用于迭代内部循环。
if (iternext == NULL) {
NpyIter_Deallocate(iter);
Py_DECREF(ret);
return NULL;
}
如果获取迭代函数失败,则释放迭代器资源、释放之前分配的 Python 对象并返回空指针。
get_multi_index = NpyIter_GetGetMultiIndex(iter, NULL);
获取获取多索引函数 `get_multi_index`,用于获取多维数组中元素的索引。
if (get_multi_index == NULL) {
NpyIter_Deallocate(iter);
Py_DECREF(ret);
return NULL;
}
如果获取获取多索引函数失败,则释放迭代器资源、释放之前分配的 Python 对象并返回空指针。
needs_api = NpyIter_IterationNeedsAPI(iter);
检查迭代器是否需要 Python API 支持,并将结果存储在 `needs_api` 中。
NPY_BEGIN_THREADS_NDITER(iter);
开始 Numpy 线程化迭代器循环。
dataptr = NpyIter_GetDataPtrArray(iter);
获取数据指针数组 `dataptr`,用于访问数组元素数据。
multi_index = (npy_intp *)PyArray_DATA(ret);
将返回对象 `ret` 中的数据指针类型转换为 `npy_intp*` 类型,并赋值给 `multi_index`。
/* Get the multi-index for each non-zero element */
if (is_bool) {
/* avoid function call for bool */
do {
if (**dataptr != 0) {
get_multi_index(iter, multi_index);
multi_index += ndim;
}
} while(iternext(iter));
}
else {
do {
if (nonzero(*dataptr, self)) {
if (++added_count > nonzero_count) {
break;
}
get_multi_index(iter, multi_index);
multi_index += ndim;
}
if (needs_api && PyErr_Occurred()) {
break;
}
} while(iternext(iter));
}
根据数组元素的情况,获取每个非零元素的多索引:
- 如果数组是布尔类型,则直接检查元素值是否不为零,然后获取其多索引。
- 否则,调用 `nonzero` 函数检查元素是否非零,并根据需要获取多索引。同时检查是否需要 Python API 支持和是否发生了异常。
NPY_END_THREADS;
结束 Numpy 线程化迭代器循环。
}
NpyIter_Deallocate(iter);
循环结束后,释放 Numpy 迭代器资源。
finish:
// 检查是否有 Python 异常发生,如果有则清理返回空并释放之前创建的对象
if (PyErr_Occurred()) {
Py_DECREF(ret);
return NULL;
}
/* if executed `nonzero()` check for miscount due to side-effect */
// 如果执行了 `nonzero()` 函数检查由于副作用导致的计数错误
if (!is_bool && added_count != nonzero_count) {
PyErr_SetString(PyExc_RuntimeError,
"number of non-zero array elements "
"changed during function execution.");
Py_DECREF(ret);
return NULL;
}
// 创建一个包含 ndim 个元素的元组对象 ret_tuple
ret_tuple = PyTuple_New(ndim);
// 如果创建元组对象失败则清理返回空并释放之前创建的对象
if (ret_tuple == NULL) {
Py_DECREF(ret);
return NULL;
}
/* Create views into ret, one for each dimension */
// 为 ret 中的每个维度创建视图
for (i = 0; i < ndim; ++i) {
npy_intp stride = ndim * NPY_SIZEOF_INTP;
/* the result is an empty array, the view must point to valid memory */
// 如果结果是一个空数组,视图必须指向有效的内存
npy_intp data_offset = nonzero_count == 0 ? 0 : i * NPY_SIZEOF_INTP;
// 创建一个 PyArrayObject 类型的视图对象 view
PyArrayObject *view = (PyArrayObject *)PyArray_NewFromDescrAndBase(
Py_TYPE(ret), PyArray_DescrFromType(NPY_INTP),
1, &nonzero_count, &stride, PyArray_BYTES(ret) + data_offset,
PyArray_FLAGS(ret), (PyObject *)ret, (PyObject *)ret);
// 如果创建视图对象失败则清理返回空并释放之前创建的对象
if (view == NULL) {
Py_DECREF(ret);
Py_DECREF(ret_tuple);
return NULL;
}
// 将视图对象 view 添加到元组 ret_tuple 的第 i 个位置
PyTuple_SET_ITEM(ret_tuple, i, (PyObject *)view);
}
// 清理返回对象 ret,因为其引用已经被传递给视图对象
Py_DECREF(ret);
// 返回包含视图对象的元组 ret_tuple
return ret_tuple;
}
/*
* Gets a single item from the array, based on a single multi-index
* array of values, which must be of length PyArray_NDIM(self).
*/
NPY_NO_EXPORT PyObject *
PyArray_MultiIndexGetItem(PyArrayObject *self, const npy_intp *multi_index)
{
int idim, ndim = PyArray_NDIM(self);
char *data = PyArray_DATA(self);
npy_intp *shape = PyArray_SHAPE(self);
npy_intp *strides = PyArray_STRIDES(self);
/* Get the data pointer */
// 获取数据指针,根据多重索引 multi_index 访问数组元素
for (idim = 0; idim < ndim; ++idim) {
npy_intp shapevalue = shape[idim];
npy_intp ind = multi_index[idim];
// 检查并调整索引 ind,确保在有效范围内
if (check_and_adjust_index(&ind, shapevalue, idim, NULL) < 0) {
return NULL;
}
// 根据索引计算数据指针的偏移量
data += ind * strides[idim];
}
// 返回数组中指定位置的元素对象
return PyArray_GETITEM(self, data);
}
/*
* Sets a single item in the array, based on a single multi-index
* array of values, which must be of length PyArray_NDIM(self).
*
* Returns 0 on success, -1 on failure.
*/
NPY_NO_EXPORT int
PyArray_MultiIndexSetItem(PyArrayObject *self, const npy_intp *multi_index,
PyObject *obj)
{
int idim, ndim = PyArray_NDIM(self);
char *data = PyArray_DATA(self);
npy_intp *shape = PyArray_SHAPE(self);
npy_intp *strides = PyArray_STRIDES(self);
/* Get the data pointer */
// 获取数据指针,根据多重索引 multi_index 访问数组元素
for (idim = 0; idim < ndim; ++idim) {
npy_intp shapevalue = shape[idim];
npy_intp ind = multi_index[idim];
// 检查并调整索引 ind,确保在有效范围内
if (check_and_adjust_index(&ind, shapevalue, idim, NULL) < 0) {
return -1;
}
// 根据索引计算数据指针的偏移量
data += ind * strides[idim];
}
// 将对象 obj 设置到数组中指定位置,并返回操作结果
return PyArray_Pack(PyArray_DESCR(self), data, obj);
}