NumPy-源码解析-十六-

76 阅读37分钟

NumPy 源码解析(十六)

.\numpy\numpy\lib\tests\test_regression.py

import os  # 导入操作系统模块

import numpy as np  # 导入NumPy库
from numpy.testing import (  # 导入NumPy测试模块中的函数和类
    assert_, assert_equal, assert_array_equal, assert_array_almost_equal,
    assert_raises, _assert_valid_refcount,
    )
import pytest  # 导入pytest测试框架


class TestRegression:
    def test_poly1d(self):
        # Ticket #28
        # 测试 np.poly1d 函数的行为,验证多项式减法
        assert_equal(np.poly1d([1]) - np.poly1d([1, 0]),
                     np.poly1d([-1, 1]))

    def test_cov_parameters(self):
        # Ticket #91
        # 创建随机矩阵 x,并复制到 y
        x = np.random.random((3, 3))
        y = x.copy()
        # 分别计算 x 和 y 的协方差矩阵,验证结果一致性
        np.cov(x, rowvar=True)
        np.cov(y, rowvar=False)
        assert_array_equal(x, y)

    def test_mem_digitize(self):
        # Ticket #95
        # 循环进行数字化处理,验证 np.digitize 函数的内存使用情况
        for i in range(100):
            np.digitize([1, 2, 3, 4], [1, 3])
            np.digitize([0, 1, 2, 3, 4], [1, 3])

    def test_unique_zero_sized(self):
        # Ticket #205
        # 测试空数组的唯一值,验证 np.unique 函数的行为
        assert_array_equal([], np.unique(np.array([])))

    def test_mem_vectorise(self):
        # Ticket #325
        # 使用 np.vectorize 函数创建向量化函数,并验证其内存使用情况
        vt = np.vectorize(lambda *args: args)
        vt(np.zeros((1, 2, 1)), np.zeros((2, 1, 1)), np.zeros((1, 1, 2)))
        vt(np.zeros((1, 2, 1)), np.zeros((2, 1, 1)), np.zeros((1,
           1, 2)), np.zeros((2, 2)))

    def test_mgrid_single_element(self):
        # Ticket #339
        # 验证 np.mgrid 在只有一个元素时的行为
        assert_array_equal(np.mgrid[0:0:1j], [0])
        assert_array_equal(np.mgrid[0:0], [])

    def test_refcount_vectorize(self):
        # Ticket #378
        # 定义函数 p,并使用 np.vectorize 进行向量化处理,验证其引用计数
        def p(x, y):
            return 123
        v = np.vectorize(p)
        _assert_valid_refcount(v)

    def test_poly1d_nan_roots(self):
        # Ticket #396
        # 创建具有 NaN 根的多项式,验证 np.poly1d 函数的异常处理
        p = np.poly1d([np.nan, np.nan, 1], r=False)
        assert_raises(np.linalg.LinAlgError, getattr, p, "r")

    def test_mem_polymul(self):
        # Ticket #448
        # 验证空列表输入时 np.polymul 的内存使用情况
        np.polymul([], [1.])

    def test_mem_string_concat(self):
        # Ticket #469
        # 创建空数组 x,并向其附加字符串,验证 np.append 函数的行为
        x = np.array([])
        np.append(x, 'asdasd\tasdasd')

    def test_poly_div(self):
        # Ticket #553
        # 创建两个多项式 u 和 v,并验证 np.polydiv 函数的行为
        u = np.poly1d([1, 2, 3])
        v = np.poly1d([1, 2, 3, 4, 5])
        q, r = np.polydiv(u, v)
        assert_equal(q*v + r, u)

    def test_poly_eq(self):
        # Ticket #554
        # 创建两个多项式 x 和 y,并验证其相等性
        x = np.poly1d([1, 2, 3])
        y = np.poly1d([3, 4])
        assert_(x != y)
        assert_(x == x)
    def test_polyfit_build(self):
        # Ticket #628
        # 参考值,多项式拟合的期望系数数组
        ref = [-1.06123820e-06, 5.70886914e-04, -1.13822012e-01,
               9.95368241e+00, -3.14526520e+02]
        # x 数据点
        x = [90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103,
             104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115,
             116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 129,
             130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
             146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157,
             158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
             170, 171, 172, 173, 174, 175, 176]
        # y 数据点
        y = [9.0, 3.0, 7.0, 4.0, 4.0, 8.0, 6.0, 11.0, 9.0, 8.0, 11.0, 5.0,
             6.0, 5.0, 9.0, 8.0, 6.0, 10.0, 6.0, 10.0, 7.0, 6.0, 6.0, 6.0,
             13.0, 4.0, 9.0, 11.0, 4.0, 5.0, 8.0, 5.0, 7.0, 7.0, 6.0, 12.0,
             7.0, 7.0, 9.0, 4.0, 12.0, 6.0, 6.0, 4.0, 3.0, 9.0, 8.0, 8.0,
             6.0, 7.0, 9.0, 10.0, 6.0, 8.0, 4.0, 7.0, 7.0, 10.0, 8.0, 8.0,
             6.0, 3.0, 8.0, 4.0, 5.0, 7.0, 8.0, 6.0, 6.0, 4.0, 12.0, 9.0,
             8.0, 8.0, 8.0, 6.0, 7.0, 4.0, 4.0, 5.0, 7.0]
        # 测试多项式拟合
        tested = np.polyfit(x, y, 4)
        # 断言拟合结果与参考值接近
        assert_array_almost_equal(ref, tested)

    def test_polydiv_type(self):
        # 使 polydiv 支持复数类型
        msg = "Wrong type, should be complex"
        x = np.ones(3, dtype=complex)
        # 对复数类型进行多项式除法
        q, r = np.polydiv(x, x)
        # 断言结果的数据类型是复数
        assert_(q.dtype == complex, msg)
        msg = "Wrong type, should be float"
        x = np.ones(3, dtype=int)
        # 对整数类型进行多项式除法
        q, r = np.polydiv(x, x)
        # 断言结果的数据类型是浮点数
        assert_(q.dtype == float, msg)

    def test_histogramdd_too_many_bins(self):
        # Ticket 928.
        # 检查 np.histogramdd 处理过多的 bins 时是否引发 ValueError
        assert_raises(ValueError, np.histogramdd, np.ones((1, 10)), bins=2**10)

    def test_polyint_type(self):
        # Ticket #944
        msg = "Wrong type, should be complex"
        x = np.ones(3, dtype=complex)
        # 对复数类型进行积分操作
        assert_(np.polyint(x).dtype == complex, msg)
        msg = "Wrong type, should be float"
        x = np.ones(3, dtype=int)
        # 对整数类型进行积分操作
        assert_(np.polyint(x).dtype == float, msg)

    def test_ndenumerate_crash(self):
        # Ticket 1140
        # 不应该导致崩溃的测试:对空数组使用 np.ndenumerate
        list(np.ndenumerate(np.array([[]])))

    def test_large_fancy_indexing(self):
        # 大规模的 fancy indexing,在 64 位系统上可能会失败
        nbits = np.dtype(np.intp).itemsize * 8
        thesize = int((2**nbits)**(1.0/5.0)+1)

        def dp():
            n = 3
            a = np.ones((n,)*5)
            i = np.random.randint(0, n, size=thesize)
            # 在数组 a 上进行大规模 fancy indexing
            a[np.ix_(i, i, i, i, i)] = 0

        def dp2():
            n = 3
            a = np.ones((n,)*5)
            i = np.random.randint(0, n, size=thesize)
            # 尝试进行大规模 fancy indexing
            a[np.ix_(i, i, i, i, i)]

        # 断言大规模 fancy indexing 会引发 ValueError
        assert_raises(ValueError, dp)
        assert_raises(ValueError, dp2)

    def test_void_coercion(self):
        dt = np.dtype([('a', 'f4'), ('b', 'i4')])
        x = np.zeros((1,), dt)
        # 测试结构化数组的拼接
        assert_(np.r_[x, x].dtype == dt)
    def test_include_dirs(self):
        # 检查函数 get_include 是否包含合理的内容,作为健全性检查
        # 相关于 ticket #1405 的部分
        include_dirs = [np.get_include()]
        # 遍历 include_dirs 中的每个路径
        for path in include_dirs:
            # 断言路径是字符串类型
            assert_(isinstance(path, str))
            # 断言路径不为空字符串
            assert_(path != '')

    def test_polyder_return_type(self):
        # Ticket #1249 的测试
        # 断言 np.polyder 函数返回的对象类型为 np.poly1d
        assert_(isinstance(np.polyder(np.poly1d([1]), 0), np.poly1d))
        # 断言 np.polyder 函数对于列表输入返回的对象类型为 np.ndarray
        assert_(isinstance(np.polyder([1], 0), np.ndarray))
        # 断言 np.polyder 函数进行一阶导数操作后返回的对象类型为 np.poly1d
        assert_(isinstance(np.polyder(np.poly1d([1]), 1), np.poly1d))
        # 断言 np.polyder 函数对于列表输入进行一阶导数操作返回的对象类型为 np.ndarray
        assert_(isinstance(np.polyder([1], 1), np.ndarray))

    def test_append_fields_dtype_list(self):
        # Ticket #1676 的测试
        from numpy.lib.recfunctions import append_fields

        # 创建一个基础数组
        base = np.array([1, 2, 3], dtype=np.int32)
        # 字段名称列表
        names = ['a', 'b', 'c']
        # 数据是单位矩阵的整数形式
        data = np.eye(3).astype(np.int32)
        # 数据类型列表
        dlist = [np.float64, np.int32, np.int32]
        try:
            # 尝试使用 append_fields 函数
            append_fields(base, names, data, dlist)
        except Exception:
            # 如果出现异常则抛出断言错误
            raise AssertionError()

    def test_loadtxt_fields_subarrays(self):
        # 对 ticket #1936 的测试
        from io import StringIO

        # 定义结构化数据类型
        dt = [("a", 'u1', 2), ("b", 'u1', 2)]
        # 使用 loadtxt 从字符串流读取数据到结构化数组
        x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt)
        # 断言 x 的内容与给定的数组数据相等
        assert_equal(x, np.array([((0, 1), (2, 3))], dtype=dt))

        # 更复杂的结构化数据类型
        dt = [("a", [("a", 'u1', (1, 3)), ("b", 'u1')])]
        x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt)
        assert_equal(x, np.array([(((0, 1, 2), 3),)], dtype=dt))

        # 具有多维形状的结构化数据类型
        dt = [("a", 'u1', (2, 2))]
        x = np.loadtxt(StringIO("0 1 2 3"), dtype=dt)
        assert_equal(x, np.array([(((0, 1), (2, 3)),)], dtype=dt))

        # 更复杂的多维结构化数据类型
        dt = [("a", 'u1', (2, 3, 2))]
        x = np.loadtxt(StringIO("0 1 2 3 4 5 6 7 8 9 10 11"), dtype=dt)
        data = [((((0, 1), (2, 3), (4, 5)), ((6, 7), (8, 9), (10, 11))),)]
        assert_equal(x, np.array(data, dtype=dt))

    def test_nansum_with_boolean(self):
        # 对 gh-2978 的测试
        # 创建一个布尔类型的零数组
        a = np.zeros(2, dtype=bool)
        try:
            # 尝试使用 np.nansum 函数
            np.nansum(a)
        except Exception:
            # 如果出现异常则抛出断言错误
            raise AssertionError()

    def test_py3_compat(self):
        # 对 gh-2561 的测试
        # 测试在 Python 3 中是否绕过了旧式类测试
        class C():
            """Python 2 中的旧式类,在 Python 3 中是普通类"""
            pass

        # 打开空设备文件,用于输出
        out = open(os.devnull, 'w')
        try:
            # 尝试使用 np.info 函数
            np.info(C(), output=out)
        except AttributeError:
            # 如果出现属性错误则抛出断言错误
            raise AssertionError()
        finally:
            # 关闭输出文件
            out.close()

.\numpy\numpy\lib\tests\test_shape_base.py

# 导入所需的库和模块
import numpy as np  # 导入 NumPy 库,用于数值计算
import functools  # 导入 functools 模块,用于函数操作
import sys  # 导入 sys 模块,用于系统相关操作
import pytest  # 导入 pytest 模块,用于编写和运行测试

from numpy import (  # 导入 NumPy 中的多个函数和类
    apply_along_axis, apply_over_axes, array_split, split, hsplit, dsplit,
    vsplit, dstack, column_stack, kron, tile, expand_dims, take_along_axis,
    put_along_axis
    )
from numpy.exceptions import AxisError  # 导入 AxisError 异常类,用于处理轴异常
from numpy.testing import (  # 导入 NumPy 测试相关的函数和类
    assert_, assert_equal, assert_array_equal, assert_raises, assert_warns
    )


IS_64BIT = sys.maxsize > 2**32  # 判断系统位数是否为 64 位


def _add_keepdims(func):
    """ hack in keepdims behavior into a function taking an axis """
    @functools.wraps(func)
    def wrapped(a, axis, **kwargs):
        # 调用原始函数并添加 keepdims 行为
        res = func(a, axis=axis, **kwargs)
        if axis is None:
            axis = 0  # 如果结果是标量,可以在任意轴插入
        return np.expand_dims(res, axis=axis)  # 返回带有 keepdims 的结果数组
    return wrapped


class TestTakeAlongAxis:
    def test_argequivalent(self):
        """ Test it translates from arg<func> to <func> """
        from numpy.random import rand  # 从 NumPy 随机模块导入 rand 函数
        a = rand(3, 4, 5)  # 创建一个形状为 (3, 4, 5) 的随机数组 a

        funcs = [
            (np.sort, np.argsort, dict()),  # 元组列表,包含排序和排序索引函数
            (_add_keepdims(np.min), _add_keepdims(np.argmin), dict()),  # 使用带有 keepdims 的最小和最小索引函数
            (_add_keepdims(np.max), _add_keepdims(np.argmax), dict()),  # 使用带有 keepdims 的最大和最大索引函数
            #(np.partition, np.argpartition, dict(kth=2)),  # 部分排序和部分排序索引函数
        ]

        for func, argfunc, kwargs in funcs:  # 遍历函数列表
            for axis in list(range(a.ndim)) + [None]:  # 遍历数组的所有轴和 None
                a_func = func(a, axis=axis, **kwargs)  # 应用函数 func 到数组 a
                ai_func = argfunc(a, axis=axis, **kwargs)  # 应用函数 argfunc 到数组 a
                assert_equal(a_func, take_along_axis(a, ai_func, axis=axis))  # 断言函数的输出与 take_along_axis 的输出相等

    def test_invalid(self):
        """ Test it errors when indices has too few dimensions """
        a = np.ones((10, 10))  # 创建一个全为 1 的 10x10 数组 a
        ai = np.ones((10, 2), dtype=np.intp)  # 创建一个全为 1 的形状为 (10, 2) 的整数数组 ai

        # 确保正常工作
        take_along_axis(a, ai, axis=1)

        # 索引维度不足
        assert_raises(ValueError, take_along_axis, a, np.array(1), axis=1)
        # 不允许布尔数组
        assert_raises(IndexError, take_along_axis, a, ai.astype(bool), axis=1)
        # 不允许浮点数数组
        assert_raises(IndexError, take_along_axis, a, ai.astype(float), axis=1)
        # 无效的轴
        assert_raises(AxisError, take_along_axis, a, ai, axis=10)
        # 无效的索引
        assert_raises(ValueError, take_along_axis, a, ai, axis=None)

    def test_empty(self):
        """ Test everything is ok with empty results, even with inserted dims """
        a  = np.ones((3, 4, 5))  # 创建一个全为 1 的形状为 (3, 4, 5) 的数组 a
        ai = np.ones((3, 0, 5), dtype=np.intp)  # 创建一个形状为 (3, 0, 5) 的整数数组 ai

        actual = take_along_axis(a, ai, axis=1)  # 在轴 1 上执行 take_along_axis 操作
        assert_equal(actual.shape, ai.shape)  # 断言实际输出的形状与 ai 的形状相等

    def test_broadcast(self):
        """ Test that non-indexing dimensions are broadcast in both directions """
        a  = np.ones((3, 4, 1))  # 创建一个形状为 (3, 4, 1) 的全为 1 的数组 a
        ai = np.ones((1, 2, 5), dtype=np.intp)  # 创建一个形状为 (1, 2, 5) 的整数数组 ai
        actual = take_along_axis(a, ai, axis=1)  # 在轴 1 上执行 take_along_axis 操作
        assert_equal(actual.shape, (3, 2, 5))  # 断言实际输出的形状为 (3, 2, 5)

class TestPutAlongAxis:
    def test_replace_max(self):
        # 创建一个基础的二维 NumPy 数组
        a_base = np.array([[10, 30, 20], [60, 40, 50]])

        # 对数组的每个维度和整体进行迭代测试
        for axis in list(range(a_base.ndim)) + [None]:
            # 在循环中对数组进行深拷贝,以便进行修改
            a = a_base.copy()

            # 找到最大值的索引,并保持维度信息
            i_max = _add_keepdims(np.argmax)(a, axis=axis)
            # 使用指定的小值替换最大值
            put_along_axis(a, i_max, -99, axis=axis)

            # 寻找新的最小值索引,它应当等于之前找到的最大值索引
            i_min = _add_keepdims(np.argmin)(a, axis=axis)

            # 断言新的最小值索引等于最大值索引
            assert_equal(i_min, i_max)

    def test_broadcast(self):
        """ Test that non-indexing dimensions are broadcast in both directions """
        # 创建一个三维全为1的 NumPy 数组
        a = np.ones((3, 4, 1))
        # 创建一个用于索引的二维整数数组,确保在范围内
        ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4
        # 使用 put_along_axis 函数在指定轴向上设置值
        put_along_axis(a, ai, 20, axis=1)
        # 断言使用 take_along_axis 函数获取的值等于设定的值
        assert_equal(take_along_axis(a, ai, axis=1), 20)

    def test_invalid(self):
        """ Test invalid inputs """
        # 创建一个基础的二维 NumPy 数组
        a_base = np.array([[10, 30, 20], [60, 40, 50]])
        # 创建用于索引和值的数组
        indices = np.array([[0], [1]])
        values = np.array([[2], [1]])

        # 对数组进行深拷贝作为基础
        a = a_base.copy()
        # 使用 put_along_axis 函数在指定轴向上设置值
        put_along_axis(a, indices, values, axis=0)
        # 断言数组的所有元素等于预期的值
        assert np.all(a == [[2, 2, 2], [1, 1, 1]])

        # 测试无效的索引输入
        a = a_base.copy()
        # 使用 assert_raises 检测是否抛出 ValueError 异常
        with assert_raises(ValueError) as exc:
            put_along_axis(a, indices, values, axis=None)
        # 断言异常信息包含特定文本
        assert "single dimension" in str(exc.exception)
class TestApplyAlongAxis:
    # 测试类,用于测试 apply_along_axis 函数的不同用例

    def test_simple(self):
        # 简单测试用例:创建一个 20x10 的双精度浮点数数组 a,全为 1
        a = np.ones((20, 10), 'd')
        # 断言应用 len 函数在轴 0 上的结果与期望的数组相等
        assert_array_equal(
            apply_along_axis(len, 0, a), len(a)*np.ones(a.shape[1]))

    def test_simple101(self):
        # 另一个简单测试用例:创建一个 10x101 的双精度浮点数数组 a,全为 1
        a = np.ones((10, 101), 'd')
        # 断言应用 len 函数在轴 0 上的结果与期望的数组相等
        assert_array_equal(
            apply_along_axis(len, 0, a), len(a)*np.ones(a.shape[1]))

    def test_3d(self):
        # 3D 数组测试用例:创建一个形状为 (3, 3, 3) 的数组 a,其中元素为 0 到 26
        a = np.arange(27).reshape((3, 3, 3))
        # 断言应用 np.sum 函数在轴 0 上的结果与期望的数组相等
        assert_array_equal(apply_along_axis(np.sum, 0, a),
                           [[27, 30, 33], [36, 39, 42], [45, 48, 51]])

    def test_preserve_subclass(self):
        # 保留子类测试用例:定义一个函数 double,返回数组行的两倍
        def double(row):
            return row * 2
        
        # 定义一个继承自 np.ndarray 的子类 MyNDArray
        class MyNDArray(np.ndarray):
            pass
        
        # 创建一个 MyNDArray 类型的数组 m,形状为 (2, 2),元素为 0 到 3
        m = np.array([[0, 1], [2, 3]]).view(MyNDArray)
        # 期望的数组 expected,元素为 [[0, 2], [4, 6]]
        expected = np.array([[0, 2], [4, 6]]).view(MyNDArray)

        # 对数组 m 应用 double 函数在轴 0 上,断言结果是 MyNDArray 类型,并且与期望数组相等
        result = apply_along_axis(double, 0, m)
        assert_(isinstance(result, MyNDArray))
        assert_array_equal(result, expected)

        # 对数组 m 应用 double 函数在轴 1 上,断言结果是 MyNDArray 类型,并且与期望数组相等
        result = apply_along_axis(double, 1, m)
        assert_(isinstance(result, MyNDArray))
        assert_array_equal(result, expected)

    def test_subclass(self):
        # 子类测试用例:定义一个最小的子类 MinimalSubclass,继承自 np.ndarray
        class MinimalSubclass(np.ndarray):
            data = 1
        
        # 定义一个函数 minimal_function,返回数组的 data 属性
        def minimal_function(array):
            return array.data
        
        # 创建一个 MinimalSubclass 类型的数组 a,形状为 (6, 3),元素全为 0
        a = np.zeros((6, 3)).view(MinimalSubclass)

        # 断言应用 minimal_function 函数在轴 0 上的结果与期望的数组相等,期望数组为 [1, 1, 1]
        assert_array_equal(
            apply_along_axis(minimal_function, 0, a), np.array([1, 1, 1])
        )

    def test_scalar_array(self, cls=np.ndarray):
        # 标量数组测试用例:创建一个 6x3 的数组 a,元素全为 1,视图类型为 cls
        a = np.ones((6, 3)).view(cls)
        # 对数组 a 应用 np.sum 函数在轴 0 上,保存结果到 res
        res = apply_along_axis(np.sum, 0, a)
        # 断言 res 的类型是 cls,并且与期望的数组相等,期望数组为 [6, 6, 6]
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([6, 6, 6]).view(cls))

    def test_0d_array(self, cls=np.ndarray):
        # 零维数组测试用例:定义一个函数 sum_to_0d,对一维数组求和,并返回相同类的零维数组
        def sum_to_0d(x):
            """ Sum x, returning a 0d array of the same class """
            assert_equal(x.ndim, 1)
            return np.squeeze(np.sum(x, keepdims=True))
        
        # 创建一个 6x3 的数组 a,元素全为 1,视图类型为 cls
        a = np.ones((6, 3)).view(cls)
        
        # 对数组 a 应用 sum_to_0d 函数在轴 0 上,保存结果到 res
        res = apply_along_axis(sum_to_0d, 0, a)
        # 断言 res 的类型是 cls,并且与期望的数组相等,期望数组为 [6, 6, 6]
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([6, 6, 6]).view(cls))

        # 对数组 a 应用 sum_to_0d 函数在轴 1 上,保存结果到 res
        res = apply_along_axis(sum_to_0d, 1, a)
        # 断言 res 的类型是 cls,并且与期望的数组相等,期望数组为 [3, 3, 3, 3, 3, 3]
        assert_(isinstance(res, cls))
        assert_array_equal(res, np.array([3, 3, 3, 3, 3, 3]).view(cls))
    def test_axis_insertion(self, cls=np.ndarray):
        # 定义一个函数 f1to2,从输入向量 x 生成一个非对称非方阵
        def f1to2(x):
            """produces an asymmetric non-square matrix from x"""
            # 断言 x 的维度为 1
            assert_equal(x.ndim, 1)
            # 返回 x 的逆序乘以 x 的转置,并将结果视图化为指定类别 cls
            return (x[::-1] * x[1:,None]).view(cls)

        # 创建一个 6x3 的二维数组 a2d
        a2d = np.arange(6*3).reshape((6, 3))

        # 对第一个轴进行二维插入
        actual = apply_along_axis(f1to2, 0, a2d)
        # 期望结果是对每列应用 f1to2 函数并堆叠,最终视图化为指定类别 cls
        expected = np.stack([
            f1to2(a2d[:,i]) for i in range(a2d.shape[1])
        ], axis=-1).view(cls)
        # 断言实际结果和期望结果的类型相同
        assert_equal(type(actual), type(expected))
        # 断言实际结果和期望结果相等
        assert_equal(actual, expected)

        # 对最后一个轴进行二维插入
        actual = apply_along_axis(f1to2, 1, a2d)
        # 期望结果是对每行应用 f1to2 函数并堆叠,最终视图化为指定类别 cls
        expected = np.stack([
            f1to2(a2d[i,:]) for i in range(a2d.shape[0])
        ], axis=0).view(cls)
        # 断言实际结果和期望结果的类型相同
        assert_equal(type(actual), type(expected))
        # 断言实际结果和期望结果相等
        assert_equal(actual, expected)

        # 对中间轴进行三维插入
        a3d = np.arange(6*5*3).reshape((6, 5, 3))

        actual = apply_along_axis(f1to2, 1, a3d)
        # 期望结果是对每个深度切片应用 f1to2 函数并堆叠,最终视图化为指定类别 cls
        expected = np.stack([
            np.stack([
                f1to2(a3d[i,:,j]) for i in range(a3d.shape[0])
            ], axis=0)
            for j in range(a3d.shape[2])
        ], axis=-1).view(cls)
        # 断言实际结果和期望结果的类型相同
        assert_equal(type(actual), type(expected))
        # 断言实际结果和期望结果相等
        assert_equal(actual, expected)

    def test_subclass_preservation(self):
        # 定义一个简单的子类 MinimalSubclass 继承自 np.ndarray
        class MinimalSubclass(np.ndarray):
            pass
        # 分别测试标量数组、0维数组和轴插入函数在 MinimalSubclass 下的行为
        self.test_scalar_array(MinimalSubclass)
        self.test_0d_array(MinimalSubclass)
        self.test_axis_insertion(MinimalSubclass)

    def test_axis_insertion_ma(self):
        # 定义一个函数 f1to2,从输入向量 x 生成一个非对称非方阵的掩码数组
        def f1to2(x):
            """produces an asymmetric non-square matrix from x"""
            # 断言 x 的维度为 1
            assert_equal(x.ndim, 1)
            # 计算 x 的逆序乘以 x 的转置,将结果进行掩码处理(余数为 0 的位置被掩盖)
            res = x[::-1] * x[1:,None]
            return np.ma.masked_where(res%5==0, res)
        # 创建一个 6x3 的二维数组 a
        a = np.arange(6*3).reshape((6, 3))
        # 应用 f1to2 函数到 a 的每列,期望结果是一个掩码数组
        res = apply_along_axis(f1to2, 0, a)
        # 断言 res 的类型为 np.ma.masked_array
        assert_(isinstance(res, np.ma.masked_array))
        # 断言 res 的维度为 3
        assert_equal(res.ndim, 3)
        # 检查每个深度切片的掩码与对应列的 f1to2 结果的掩码相匹配
        assert_array_equal(res[:,:,0].mask, f1to2(a[:,0]).mask)
        assert_array_equal(res[:,:,1].mask, f1to2(a[:,1]).mask)
        assert_array_equal(res[:,:,2].mask, f1to2(a[:,2]).mask)

    def test_tuple_func1d(self):
        # 定义一个简单的函数 sample_1d,交换输入向量 x 的第一个和第二个元素
        def sample_1d(x):
            return x[1], x[0]
        # 对输入数组 [[1, 2], [3, 4]] 应用 sample_1d 函数,期望结果是 [[2, 1], [4, 3]]
        res = np.apply_along_axis(sample_1d, 1, np.array([[1, 2], [3, 4]]))
        # 断言 res 等于期望的结果数组
        assert_array_equal(res, np.array([[2, 1], [4, 3]]))
    def test_empty(self):
        # 定义一个永远不会被调用的函数,用于测试 apply_along_axis
        def never_call(x):
            assert_(False) # 应该永远不会执行到这里

        # 创建一个空的 numpy 数组
        a = np.empty((0, 0))
        # 测试在维度 0 上调用 apply_along_axis 是否会引发 ValueError 异常
        assert_raises(ValueError, np.apply_along_axis, never_call, 0, a)
        # 测试在维度 1 上调用 apply_along_axis 是否会引发 ValueError 异常
        assert_raises(ValueError, np.apply_along_axis, never_call, 1, a)

        # 但是在某些非零维度情况下,也可以正常工作
        # 定义一个将空数组转换为 1 的函数
        def empty_to_1(x):
            assert_(len(x) == 0)
            return 1

        # 创建一个具有 10 行和 0 列的空 numpy 数组
        a = np.empty((10, 0))
        # 在每行上应用 empty_to_1 函数,预期结果应为全为 1 的数组
        actual = np.apply_along_axis(empty_to_1, 1, a)
        assert_equal(actual, np.ones(10))
        # 测试在维度 0 上调用 apply_along_axis 是否会引发 ValueError 异常
        assert_raises(ValueError, np.apply_along_axis, empty_to_1, 0, a)

    def test_with_iterable_object(self):
        # 来自问题 5248
        # 创建一个包含集合的多维 numpy 数组
        d = np.array([
            [{1, 11}, {2, 22}, {3, 33}],
            [{4, 44}, {5, 55}, {6, 66}]
        ])
        # 在维度 0 上应用 lambda 函数,该函数执行集合的并集操作
        actual = np.apply_along_axis(lambda a: set.union(*a), 0, d)
        # 预期的结果数组,每个元素为合并对应位置集合的结果
        expected = np.array([{1, 11, 4, 44}, {2, 22, 5, 55}, {3, 33, 6, 66}])

        # 断言实际输出与预期输出相等
        assert_equal(actual, expected)

        # 问题 8642 - assert_equal 无法检测到此问题!
        # 遍历实际输出的每个元素,检查其类型是否与预期输出的类型相同
        for i in np.ndindex(actual.shape):
            assert_equal(type(actual[i]), type(expected[i]))
class TestApplyOverAxes:
    # 测试 apply_over_axes 函数的功能
    def test_simple(self):
        # 创建一个 2x3x4 的数组
        a = np.arange(24).reshape(2, 3, 4)
        # 对数组 a 沿指定轴向应用 np.sum 函数
        aoa_a = apply_over_axes(np.sum, a, [0, 2])
        # 断言 aoa_a 的结果与给定的数组相等
        assert_array_equal(aoa_a, np.array([[[60], [92], [124]]]))


class TestExpandDims:
    # 测试 expand_dims 函数的功能
    def test_functionality(self):
        s = (2, 3, 4, 5)
        a = np.empty(s)
        # 遍历可能的轴向范围
        for axis in range(-5, 4):
            # 在数组 a 上扩展维度
            b = expand_dims(a, axis)
            # 断言扩展后的维度在指定轴上为 1
            assert_(b.shape[axis] == 1)
            # 断言压缩后的数组形状与原始形状相同
            assert_(np.squeeze(b).shape == s)

    # 测试在元组形式下的多轴扩展
    def test_axis_tuple(self):
        a = np.empty((3, 3, 3))
        # 断言使用元组指定轴向进行扩展后的形状
        assert np.expand_dims(a, axis=(0, 1, 2)).shape == (1, 1, 1, 3, 3, 3)
        assert np.expand_dims(a, axis=(0, -1, -2)).shape == (1, 3, 3, 3, 1, 1)
        assert np.expand_dims(a, axis=(0, 3, 5)).shape == (1, 3, 3, 1, 3, 1)
        assert np.expand_dims(a, axis=(0, -3, -5)).shape == (1, 1, 3, 1, 3, 3)

    # 测试超出范围的轴向
    def test_axis_out_of_range(self):
        s = (2, 3, 4, 5)
        a = np.empty(s)
        # 断言当轴向超出数组维度范围时引发 AxisError
        assert_raises(AxisError, expand_dims, a, -6)
        assert_raises(AxisError, expand_dims, a, 5)

        a = np.empty((3, 3, 3))
        # 断言当轴向以元组形式且包含超出范围的索引时引发 AxisError
        assert_raises(AxisError, expand_dims, a, (0, -6))
        assert_raises(AxisError, expand_dims, a, (0, 5))

    # 测试重复指定的轴向
    def test_repeated_axis(self):
        a = np.empty((3, 3, 3))
        # 断言当重复指定相同轴向时引发 ValueError
        assert_raises(ValueError, expand_dims, a, axis=(1, 1))

    # 测试子类化的情况
    def test_subclasses(self):
        a = np.arange(10).reshape((2, 5))
        a = np.ma.array(a, mask=a % 3 == 0)

        # 对子类化的数组进行维度扩展
        expanded = np.expand_dims(a, axis=1)
        # 断言扩展后的对象仍然是 MaskedArray 类的实例
        assert_(isinstance(expanded, np.ma.MaskedArray))
        # 断言扩展后的数组形状
        assert_equal(expanded.shape, (2, 1, 5))
        # 断言扩展后的 mask 属性形状与数组形状相同
        assert_equal(expanded.mask.shape, (2, 1, 5))


class TestArraySplit:
    # 测试在分割数组时指定零分割点的情况
    def test_integer_0_split(self):
        a = np.arange(10)
        # 断言当指定分割点为零时,引发 ValueError
        assert_raises(ValueError, array_split, a, 0)
    # 定义一个测试方法,用于测试数组分割函数的不同情况
    def test_integer_split(self):
        # 创建一个包含 0 到 9 的数组
        a = np.arange(10)
        # 调用数组分割函数,将数组 a 分割成1个子数组
        res = array_split(a, 1)
        # 期望的结果是包含数组 a 整体的列表
        desired = [np.arange(10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成2个子数组
        res = array_split(a, 2)
        # 期望的结果是分别包含数组 a 的前5个元素和后5个元素的列表
        desired = [np.arange(5), np.arange(5, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成3个子数组
        res = array_split(a, 3)
        # 期望的结果是分别包含数组 a 的前4个元素、中间3个元素和最后3个元素的列表
        desired = [np.arange(4), np.arange(4, 7), np.arange(7, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成4个子数组
        res = array_split(a, 4)
        # 期望的结果是分别包含数组 a 的前3个元素、中间3个元素、接着的2个元素和最后2个元素的列表
        desired = [np.arange(3), np.arange(3, 6), np.arange(6, 8),
                   np.arange(8, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成5个子数组
        res = array_split(a, 5)
        # 期望的结果是分别包含数组 a 的前2个元素、接下来的2个元素、接着的2个元素、再接着的2个元素和最后2个元素的列表
        desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
                   np.arange(6, 8), np.arange(8, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成6个子数组
        res = array_split(a, 6)
        # 期望的结果是分别包含数组 a 的前2个元素、接下来的2个元素、接着的2个元素、再接着的2个元素、再接着的1个元素和最后1个元素的列表
        desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
                   np.arange(6, 8), np.arange(8, 9), np.arange(9, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成7个子数组
        res = array_split(a, 7)
        # 期望的结果是分别包含数组 a 的前2个元素、接下来的2个元素、接着的2个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素和最后1个元素的列表
        desired = [np.arange(2), np.arange(2, 4), np.arange(4, 6),
                   np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
                   np.arange(9, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成8个子数组
        res = array_split(a, 8)
        # 期望的结果是分别包含数组 a 的前2个元素、接下来的2个元素、接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、最后1个元素和最后1个元素的列表
        desired = [np.arange(2), np.arange(2, 4), np.arange(4, 5),
                   np.arange(5, 6), np.arange(6, 7), np.arange(7, 8),
                   np.arange(8, 9), np.arange(9, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成9个子数组
        res = array_split(a, 9)
        # 期望的结果是分别包含数组 a 的前2个元素、接下来的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、最后1个元素和最后1个元素的列表
        desired = [np.arange(2), np.arange(2, 3), np.arange(3, 4),
                   np.arange(4, 5), np.arange(5, 6), np.arange(6, 7),
                   np.arange(7, 8), np.arange(8, 9), np.arange(9, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成10个子数组
        res = array_split(a, 10)
        # 期望的结果是分别包含数组 a 的前1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、最后1个元素和最后1个元素的列表
        desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
                   np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
                   np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
                   np.arange(9, 10)]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)

        # 将数组 a 分割成11个子数组
        res = array_split(a, 11)
        # 期望的结果是分别包含数组 a 的前1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、再接着的1个元素、最后1个元素和空数组的列表
        desired = [np.arange(1), np.arange(1, 2), np.arange(2, 3),
                   np.arange(3, 4), np.arange(4, 5), np.arange(5, 6),
                   np.arange(6, 7), np.arange(7, 8), np.arange(8, 9),
                   np.arange(9, 10), np.array([])]
        # 比较函数输出和期望结果是否一致
        compare_results(res, desired)
    # 定义一个测试方法,用于测试在二维数组中按行进行整数切分
    def test_integer_split_2D_rows(self):
        # 创建一个包含两行每行有 0 到 9 的二维 NumPy 数组
        a = np.array([np.arange(10), np.arange(10)])
        # 在第一个轴上将数组 `a` 分成 3 部分
        res = array_split(a, 3, axis=0)
        # 目标结果是包含三个元素的列表,分别是包含 0 到 9 的数组、另一个包含 0 到 9 的数组、以及一个 0 行 10 列的零数组
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
                   np.zeros((0, 10))]
        # 比较函数,用于比较结果 `res` 和目标 `tgt`
        compare_results(res, tgt)
        # 断言 `a` 数组的数据类型与 `res` 列表中最后一个元素的数据类型相同
        assert_(a.dtype.type is res[-1].dtype.type)

        # 对手动切分也做同样的操作:
        # 在第一个轴上根据索引 [0, 1] 对数组 `a` 进行切分
        res = array_split(a, [0, 1], axis=0)
        # 目标结果是包含四个元素的列表,分别是一个空的 0 行 10 列的数组、一个包含 0 到 9 的数组、另一个包含 0 到 9 的数组
        tgt = [np.zeros((0, 10)), np.array([np.arange(10)]),
               np.array([np.arange(10)])]
        # 再次使用比较函数,比较结果 `res` 和目标 `tgt`
        compare_results(res, tgt)
        # 断言 `a` 数组的数据类型与 `res` 列表中最后一个元素的数据类型相同
        assert_(a.dtype.type is res[-1].dtype.type)

    # 定义一个测试方法,用于测试在二维数组中按列进行整数切分
    def test_integer_split_2D_cols(self):
        # 创建一个包含两行每行有 0 到 9 的二维 NumPy 数组
        a = np.array([np.arange(10), np.arange(10)])
        # 在最后一个轴(列)上将数组 `a` 分成 3 部分
        res = array_split(a, 3, axis=-1)
        # 目标结果是包含三个元素的列表,分别是包含 0 到 3 的两行两列数组、包含 4 到 6 的两行两列数组、包含 7 到 9 的两行两列数组
        desired = [np.array([np.arange(4), np.arange(4)]),
                   np.array([np.arange(4, 7), np.arange(4, 7)]),
                   np.array([np.arange(7, 10), np.arange(7, 10)])]
        # 使用比较函数,比较结果 `res` 和目标 `desired`
        compare_results(res, desired)

    # 定义一个测试方法,用于测试在二维数组中按默认轴进行整数切分
    def test_integer_split_2D_default(self):
        """ This will fail if we change default axis
        """
        # 创建一个包含两行每行有 0 到 9 的二维 NumPy 数组
        a = np.array([np.arange(10), np.arange(10)])
        # 在默认的第一个轴(行)上将数组 `a` 分成 3 部分
        res = array_split(a, 3)
        # 目标结果是包含三个元素的列表,分别是包含 0 到 9 的数组、另一个包含 0 到 9 的数组、以及一个 0 行 10 列的零数组
        tgt = [np.array([np.arange(10)]), np.array([np.arange(10)]),
                   np.zeros((0, 10))]
        # 使用比较函数,比较结果 `res` 和目标 `tgt`
        compare_results(res, tgt)
        # 断言 `a` 数组的数据类型与 `res` 列表中最后一个元素的数据类型相同
        assert_(a.dtype.type is res[-1].dtype.type)
        # 或许应该检查更高的维度

    # 根据条件跳过测试,如果不在 64 位平台上则跳过
    @pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
    def test_integer_split_2D_rows_greater_max_int32(self):
        # 创建一个形状为 (2^32, 2) 的广播数组,元素值均为 0
        a = np.broadcast_to([0], (1 << 32, 2))
        # 在第一个轴上将数组 `a` 分成 4 部分
        res = array_split(a, 4)
        # 创建一个形状为 (2^30, 2) 的广播数组,元素值均为 0,作为目标结果的一个分块
        chunk = np.broadcast_to([0], (1 << 30, 2))
        # 目标结果是包含四个元素的列表,每个元素都是形状为 (2^30, 2) 的广播数组
        tgt = [chunk] * 4
        # 遍历目标结果,断言结果 `res` 中每个分块的形状与目标结果中对应分块的形状相同
        for i in range(len(tgt)):
            assert_equal(res[i].shape, tgt[i].shape)

    # 定义一个测试方法,用于测试在一维数组中按索引进行切分
    def test_index_split_simple(self):
        # 创建一个包含 0 到 9 的一维 NumPy 数组
        a = np.arange(10)
        # 根据给定的索引 [1, 5, 7] 在最后一个轴上将数组 `a` 分成多个子数组
        res = array_split(a, [1, 5, 7], axis=-1)
        # 目标结果是包含四个元素的列表,分别是从 0 到 1、从 1 到 5、从 5 到 7、从 7 到 10 的子数组
        desired = [np.arange(0, 1), np.arange(1, 5), np.arange(5, 7),
                   np.arange(7, 10)]
        # 使用比较函数,比较结果 `res` 和目标 `desired`
        compare_results(res, desired)

    # 定义一个测试方法,用于测试在一维数组中按索引进行切分,其中低边界为 0
    def test_index_split_low_bound(self):
        # 创建一个包含 0 到 9 的一维 NumPy 数组
        a = np.arange(10)
        # 根据给定的索引 [0, 5, 7] 在最后一个轴上将数组 `a` 分成多个子数组
        res = array_split(a, [0, 5, 7], axis=-1)
        # 目标结果是包含四个元素的列表,分别是空数组、从 0 到 5、从 5 到 7、从 7 到 10 的子数组
        desired = [np.array([]), np.arange(0, 5), np.arange(5, 7),
                   np.arange(7, 10)]
        # 使用比较函数,比较结果 `res` 和目标 `desired`
        compare_results(res, desired)

    # 定义一个测试方法,用于测试在一维数组中按索引进行切分,其中高边界为数组长度
    def test_index_split_high_bound(self):
        # 创建一个包含 0 到 9 的一维 NumPy 数组
        a = np.arange(10)
        # 根据给定的索引 [0, 5, 7, 10, 12] 在最后一个轴上将数组 `a` 分成多个子数组
        res = array_split(a, [0, 5, 7, 10, 12], axis=-1)
        # 目标结果
class TestSplit:
    # The split function is essentially the same as array_split,
    # except that it test if splitting will result in an
    # equal split.  Only test for this case.

    def test_equal_split(self):
        # Create a NumPy array with values from 0 to 9
        a = np.arange(10)
        # Call the split function with array 'a' and split count 2
        res = split(a, 2)
        # Define the desired result as two arrays: [0, 1, 2, 3, 4] and [5, 6, 7, 8, 9]
        desired = [np.arange(5), np.arange(5, 10)]
        # Compare the result of split operation with the desired result
        compare_results(res, desired)

    def test_unequal_split(self):
        # Create a NumPy array with values from 0 to 9
        a = np.arange(10)
        # Assert that splitting array 'a' into 3 parts raises a ValueError
        assert_raises(ValueError, split, a, 3)


class TestColumnStack:
    def test_non_iterable(self):
        # Assert that passing a non-iterable argument (like '1') to column_stack raises a TypeError
        assert_raises(TypeError, column_stack, 1)

    def test_1D_arrays(self):
        # example from docstring
        # Create two 1-dimensional NumPy arrays 'a' and 'b'
        a = np.array((1, 2, 3))
        b = np.array((2, 3, 4))
        # Define the expected result of column_stack operation
        expected = np.array([[1, 2],
                             [2, 3],
                             [3, 4]])
        # Perform column_stack operation on arrays 'a' and 'b'
        actual = np.column_stack((a, b))
        # Assert that the actual result matches the expected result
        assert_equal(actual, expected)

    def test_2D_arrays(self):
        # same as hstack 2D docstring example
        # Create two 2-dimensional NumPy arrays 'a' and 'b'
        a = np.array([[1], [2], [3]])
        b = np.array([[2], [3], [4]])
        # Define the expected result of column_stack operation
        expected = np.array([[1, 2],
                             [2, 3],
                             [3, 4]])
        # Perform column_stack operation on arrays 'a' and 'b'
        actual = np.column_stack((a, b))
        # Assert that the actual result matches the expected result
        assert_equal(actual, expected)

    def test_generator(self):
        # Assert that passing a generator of arrays to column_stack raises a TypeError with a specific error message
        with pytest.raises(TypeError, match="arrays to stack must be"):
            column_stack((np.arange(3) for _ in range(2)))


class TestDstack:
    def test_non_iterable(self):
        # Assert that passing a non-iterable argument (like '1') to dstack raises a TypeError
        assert_raises(TypeError, dstack, 1)

    def test_0D_array(self):
        # Create a 0-dimensional NumPy array 'a'
        a = np.array(1)
        # Create another 0-dimensional NumPy array 'b'
        b = np.array(2)
        # Perform dstack operation on arrays 'a' and 'b'
        res = dstack([a, b])
        # Define the desired result of dstack operation
        desired = np.array([[[1, 2]]])
        # Assert that the result of dstack operation matches the desired result
        assert_array_equal(res, desired)

    def test_1D_array(self):
        # Create two 1-dimensional NumPy arrays 'a' and 'b'
        a = np.array([1])
        b = np.array([2])
        # Perform dstack operation on arrays 'a' and 'b'
        res = dstack([a, b])
        # Define the desired result of dstack operation
        desired = np.array([[[1, 2]]])
        # Assert that the result of dstack operation matches the desired result
        assert_array_equal(res, desired)

    def test_2D_array(self):
        # Create two 2-dimensional NumPy arrays 'a' and 'b'
        a = np.array([[1], [2]])
        b = np.array([[1], [2]])
        # Perform dstack operation on arrays 'a' and 'b'
        res = dstack([a, b])
        # Define the desired result of dstack operation
        desired = np.array([[[1, 1]], [[2, 2, ]]])
        # Assert that the result of dstack operation matches the desired result
        assert_array_equal(res, desired)

    def test_2D_array2(self):
        # Create two 1-dimensional NumPy arrays 'a' and 'b'
        a = np.array([1, 2])
        b = np.array([1, 2])
        # Perform dstack operation on arrays 'a' and 'b'
        res = dstack([a, b])
        # Define the desired result of dstack operation
        desired = np.array([[[1, 1], [2, 2]]])
        # Assert that the result of dstack operation matches the desired result
        assert_array_equal(res, desired)

    def test_generator(self):
        # Assert that passing a generator of arrays to dstack raises a TypeError with a specific error message
        with pytest.raises(TypeError, match="arrays to stack must be"):
            dstack((np.arange(3) for _ in range(2)))


# array_split has more comprehensive test of splitting.
# only do simple test on hsplit, vsplit, and dsplit
class TestHsplit:
    """Only testing for integer splits.

    """
    def test_non_iterable(self):
        # Assert that passing a non-iterable argument (like '1') to hsplit raises a ValueError
        assert_raises(ValueError, hsplit, 1, 1)

    def test_0D_array(self):
        # Create a 0-dimensional NumPy array 'a'
        a = np.array(1)
        try:
            # Attempt to perform hsplit operation on array 'a' with split count 2
            hsplit(a, 2)
            # If successful, assert an unexpected condition to fail the test
            assert_(0)
        except ValueError:
            # Catch the expected ValueError if hsplit operation fails
            pass
    # 定义一个测试函数,用于测试对一维数组进行水平分割的功能
    def test_1D_array(self):
        # 创建一个包含整数1到4的一维 NumPy 数组
        a = np.array([1, 2, 3, 4])
        # 调用 hsplit 函数对数组进行水平分割,分割成两部分
        res = hsplit(a, 2)
        # 预期的分割结果,分别是包含[1, 2]和[3, 4]的两个数组
        desired = [np.array([1, 2]), np.array([3, 4])]
        # 调用 compare_results 函数比较实际结果和预期结果
        compare_results(res, desired)
    
    # 定义一个测试函数,用于测试对二维数组进行水平分割的功能
    def test_2D_array(self):
        # 创建一个包含两行四列的二维 NumPy 数组
        a = np.array([[1, 2, 3, 4],
                      [1, 2, 3, 4]])
        # 调用 hsplit 函数对数组进行水平分割,分割成两部分
        res = hsplit(a, 2)
        # 预期的分割结果,分别是包含[[1, 2], [1, 2]]和[[3, 4], [3, 4]]的两个数组
        desired = [np.array([[1, 2], [1, 2]]), np.array([[3, 4], [3, 4]])]
        # 调用 compare_results 函数比较实际结果和预期结果
        compare_results(res, desired)
class TestVsplit:
    """Only testing for integer splits.
    
    Test class for verifying vsplit function behavior.
    """
    
    def test_non_iterable(self):
        # Assert that vsplit raises a ValueError when provided with non-iterable inputs
        assert_raises(ValueError, vsplit, 1, 1)
        
    def test_0D_array(self):
        # Create a 0-dimensional NumPy array
        a = np.array(1)
        # Assert that vsplit raises a ValueError when applied to a 0-dimensional array
        assert_raises(ValueError, vsplit, a, 2)
        
    def test_1D_array(self):
        # Create a 1-dimensional NumPy array
        a = np.array([1, 2, 3, 4])
        try:
            # Attempt to split the array into 2 parts using vsplit
            vsplit(a, 2)
            # Assert a failure because vsplit should raise a ValueError
            assert_(0)
        except ValueError:
            pass
    
    def test_2D_array(self):
        # Create a 2-dimensional NumPy array
        a = np.array([[1, 2, 3, 4],
                      [1, 2, 3, 4]])
        # Split the array into 2 vertical parts using vsplit
        res = vsplit(a, 2)
        # Define the expected result as a list of 2-dimensional arrays
        desired = [np.array([[1, 2, 3, 4]]), np.array([[1, 2, 3, 4]])]
        # Compare the actual result with the expected result
        compare_results(res, desired)


class TestDsplit:
    # Only testing for integer splits.
    """Test class for verifying dsplit function behavior."""

    def test_non_iterable(self):
        # Assert that dsplit raises a ValueError when provided with non-iterable inputs
        assert_raises(ValueError, dsplit, 1, 1)
        
    def test_0D_array(self):
        # Create a 0-dimensional NumPy array
        a = np.array(1)
        # Assert that dsplit raises a ValueError when applied to a 0-dimensional array
        assert_raises(ValueError, dsplit, a, 2)
        
    def test_1D_array(self):
        # Create a 1-dimensional NumPy array
        a = np.array([1, 2, 3, 4])
        # Assert that dsplit raises a ValueError when applied to a 1-dimensional array
        assert_raises(ValueError, dsplit, a, 2)
        
    def test_2D_array(self):
        # Create a 2-dimensional NumPy array
        a = np.array([[1, 2, 3, 4],
                      [1, 2, 3, 4]])
        try:
            # Attempt to split the array into 2 parts using dsplit
            dsplit(a, 2)
            # Assert a failure because dsplit should raise a ValueError
            assert_(0)
        except ValueError:
            pass
    
    def test_3D_array(self):
        # Create a 3-dimensional NumPy array
        a = np.array([[[1, 2, 3, 4],
                       [1, 2, 3, 4]],
                      [[1, 2, 3, 4],
                       [1, 2, 3, 4]]])
        # Split the array into 2 parts along the third dimension using dsplit
        res = dsplit(a, 2)
        # Define the expected result as a list of 3-dimensional arrays
        desired = [np.array([[[1, 2], [1, 2]], [[1, 2], [1, 2]]]),
                   np.array([[[3, 4], [3, 4]], [[3, 4], [3, 4]]])]
        # Compare the actual result with the expected result
        compare_results(res, desired)


class TestSqueeze:
    """Test class for verifying squeeze function behavior."""
    
    def test_basic(self):
        # Importing necessary function from numpy.random
        from numpy.random import rand
        
        # Generate random arrays with varied dimensions
        a = rand(20, 10, 10, 1, 1)
        b = rand(20, 1, 10, 1, 20)
        c = rand(1, 1, 20, 10)
        
        # Assert that squeezing these arrays gives expected reshaped results
        assert_array_equal(np.squeeze(a), np.reshape(a, (20, 10, 10)))
        assert_array_equal(np.squeeze(b), np.reshape(b, (20, 10, 20)))
        assert_array_equal(np.squeeze(c), np.reshape(c, (20, 10)))
        
        # Test squeezing a 0-dimensional array to ensure it remains an ndarray
        a = [[[1.5]]]
        res = np.squeeze(a)
        assert_equal(res, 1.5)
        assert_equal(res.ndim, 0)
        assert_equal(type(res), np.ndarray)
    def test_basic(self):
        # 使用0维的ndarray
        a = np.array(1)  # 创建一个0维的ndarray,值为1
        b = np.array([[1, 2], [3, 4]])  # 创建一个2x2的二维ndarray
        k = np.array([[1, 2], [3, 4]])  # k被赋值为和b相同的二维ndarray
        assert_array_equal(np.kron(a, b), k)  # 断言np.kron(a, b)和k相等
        a = np.array([[1, 2], [3, 4]])  # 更新a为一个2x2的二维ndarray
        b = np.array(1)  # 更新b为一个0维的ndarray,值为1
        assert_array_equal(np.kron(a, b), k)  # 再次断言np.kron(a, b)和k相等

        # 使用1维的ndarray
        a = np.array([3])  # 创建一个包含单个元素3的1维ndarray
        b = np.array([[1, 2], [3, 4]])  # 创建一个2x2的二维ndarray
        k = np.array([[3, 6], [9, 12]])  # k被赋值为相应的2x2的二维ndarray
        assert_array_equal(np.kron(a, b), k)  # 断言np.kron(a, b)和k相等
        a = np.array([[1, 2], [3, 4]])  # 更新a为一个2x2的二维ndarray
        b = np.array([3])  # 更新b为一个包含单个元素3的1维ndarray
        assert_array_equal(np.kron(a, b), k)  # 再次断言np.kron(a, b)和k相等

        # 使用3维的ndarray
        a = np.array([[[1]], [[2]]])  # 创建一个包含两个元素的3维ndarray
        b = np.array([[1, 2], [3, 4]])  # 创建一个2x2的二维ndarray
        k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])  # k被赋值为相应的3维ndarray
        assert_array_equal(np.kron(a, b), k)  # 断言np.kron(a, b)和k相等
        a = np.array([[1, 2], [3, 4]])  # 更新a为一个2x2的二维ndarray
        b = np.array([[[1]], [[2]]])  # 更新b为一个包含两个元素的3维ndarray
        k = np.array([[[1, 2], [3, 4]], [[2, 4], [6, 8]]])  # k被赋值为相应的3维ndarray
        assert_array_equal(np.kron(a, b), k)  # 再次断言np.kron(a, b)和k相等

    def test_return_type(self):
        class myarray(np.ndarray):
            __array_priority__ = 1.0

        a = np.ones([2, 2])  # 创建一个2x2的全1数组
        ma = myarray(a.shape, a.dtype, a.data)  # 使用自定义类创建一个数组
        assert_equal(type(kron(a, a)), np.ndarray)  # 断言kron(a, a)的返回类型是np.ndarray
        assert_equal(type(kron(ma, ma)), myarray)  # 断言kron(ma, ma)的返回类型是myarray
        assert_equal(type(kron(a, ma)), myarray)  # 断言kron(a, ma)的返回类型是myarray
        assert_equal(type(kron(ma, a)), myarray)  # 断言kron(ma, a)的返回类型是myarray

    @pytest.mark.parametrize(
        "array_class", [np.asarray, np.asmatrix]
    )
    def test_kron_smoke(self, array_class):
        a = array_class(np.ones([3, 3]))  # 根据array_class创建一个3x3的全1数组
        b = array_class(np.ones([3, 3]))  # 根据array_class创建另一个3x3的全1数组
        k = array_class(np.ones([9, 9]))  # 根据array_class创建一个9x9的全1数组

        assert_array_equal(np.kron(a, b), k)  # 断言np.kron(a, b)和k相等

    def test_kron_ma(self):
        x = np.ma.array([[1, 2], [3, 4]], mask=[[0, 1], [1, 0]])  # 创建一个带有掩码的掩盖数组
        k = np.ma.array(np.diag([1, 4, 4, 16]), mask=~np.array(np.identity(4), dtype=bool))  # 创建一个带有掩码的掩盖数组

        assert_array_equal(k, np.kron(x, x))  # 断言np.kron(x, x)和k相等

    @pytest.mark.parametrize(
        "shape_a,shape_b", [
            ((1, 1), (1, 1)),
            ((1, 2, 3), (4, 5, 6)),
            ((2, 2), (2, 2, 2)),
            ((1, 0), (1, 1)),
            ((2, 0, 2), (2, 2)),
            ((2, 0, 0, 2), (2, 0, 2)),
        ])
    def test_kron_shape(self, shape_a, shape_b):
        a = np.ones(shape_a)  # 创建一个形状为shape_a的全1数组
        b = np.ones(shape_b)  # 创建一个形状为shape_b的全1数组
        normalised_shape_a = (1,) * max(0, len(shape_b)-len(shape_a)) + shape_a  # 规范化shape_a的形状
        normalised_shape_b = (1,) * max(0, len(shape_a)-len(shape_b)) + shape_b  # 规范化shape_b的形状
        expected_shape = np.multiply(normalised_shape_a, normalised_shape_b)  # 计算期望的形状

        k = np.kron(a, b)  # 计算np.kron(a, b)
        assert np.array_equal(
                k.shape, expected_shape), "Unexpected shape from kron"  # 断言k的形状与期望的形状相等
class TestTile:
    def test_basic(self):
        # 创建一个包含元素 [0, 1, 2] 的 NumPy 数组
        a = np.array([0, 1, 2])
        # 创建一个包含列表 [[1, 2], [3, 4]] 的 Python 列表
        b = [[1, 2], [3, 4]]
        # 测试 tile 函数对数组 a 的重复操作是否正确
        assert_equal(tile(a, 2), [0, 1, 2, 0, 1, 2])
        # 测试 tile 函数对数组 a 的 (2, 2) 形状的重复操作是否正确
        assert_equal(tile(a, (2, 2)), [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]])
        # 测试 tile 函数对数组 a 的 (1, 2) 形状的重复操作是否正确
        assert_equal(tile(a, (1, 2)), [[0, 1, 2, 0, 1, 2]])
        # 测试 tile 函数对列表 b 的 2 次重复操作是否正确
        assert_equal(tile(b, 2), [[1, 2, 1, 2], [3, 4, 3, 4]])
        # 测试 tile 函数对列表 b 的 (2, 1) 形状的重复操作是否正确
        assert_equal(tile(b, (2, 1)), [[1, 2], [3, 4], [1, 2], [3, 4]])
        # 测试 tile 函数对列表 b 的 (2, 2) 形状的重复操作是否正确
        assert_equal(tile(b, (2, 2)), [[1, 2, 1, 2], [3, 4, 3, 4],
                                       [1, 2, 1, 2], [3, 4, 3, 4]])

    def test_tile_one_repetition_on_array_gh4679(self):
        # 创建一个从 0 到 4 的 NumPy 数组
        a = np.arange(5)
        # 对数组 a 进行 1 次重复操作,并将结果保存到数组 b
        b = tile(a, 1)
        # 将数组 b 中的所有元素加上 2
        b += 2
        # 断言数组 a 未被修改
        assert_equal(a, np.arange(5))

    def test_empty(self):
        # 创建一个包含空列表的 NumPy 数组
        a = np.array([[[]]])
        # 创建一个包含两个空列表的 NumPy 数组
        b = np.array([[], []])
        # 对数组 b 进行 2 次重复操作,并获取其形状
        c = tile(b, 2).shape
        # 对数组 a 进行 (3, 2, 5) 形状的重复操作,并获取其形状
        d = tile(a, (3, 2, 5)).shape
        # 断言 c 的形状为 (2, 0)
        assert_equal(c, (2, 0))
        # 断言 d 的形状为 (3, 2, 0)
        assert_equal(d, (3, 2, 0))

    def test_kroncompare(self):
        # 从 numpy.random 模块导入 randint 函数
        from numpy.random import randint

        # 定义重复操作的形状和数组的形状
        reps = [(2,), (1, 2), (2, 1), (2, 2), (2, 3, 2), (3, 2)]
        shape = [(3,), (2, 3), (3, 4, 3), (3, 2, 3), (4, 3, 2, 4), (2, 2)]
        # 遍历数组的形状
        for s in shape:
            # 生成指定形状的随机数组 b
            b = randint(0, 10, size=s)
            # 遍历重复操作的形状
            for r in reps:
                # 创建与数组 b 类型相同的全为 1 的数组 a
                a = np.ones(r, b.dtype)
                # 使用 tile 函数对数组 b 进行 r 形状的重复操作
                large = tile(b, r)
                # 使用 kron 函数对数组 a 和 b 进行 Kroncker 乘积
                klarge = kron(a, b)
                # 断言两者相等
                assert_equal(large, klarge)


class TestMayShareMemory:
    def test_basic(self):
        # 创建一个形状为 (50, 60) 的全为 1 的 NumPy 数组 d
        d = np.ones((50, 60))
        # 创建一个形状为 (30, 60, 6) 的全为 1 的 NumPy 数组 d2
        d2 = np.ones((30, 60, 6))
        # 断言数组 d 与自身共享内存
        assert_(np.may_share_memory(d, d))
        # 断言数组 d 与 d 的倒序视图不共享内存
        assert_(np.may_share_memory(d, d[::-1]))
        # 断言数组 d 与 d 的每隔一个元素视图不共享内存
        assert_(np.may_share_memory(d, d[::2]))
        # 断言数组 d 与 d 的行倒序视图不共享内存
        assert_(np.may_share_memory(d, d[1:, ::-1]))

        # 断言数组 d 的倒序视图与数组 d2 不共享内存
        assert_(not np.may_share_memory(d[::-1], d2))
        # 断言数组 d 的每隔一个元素视图与数组 d2 不共享内存
        assert_(not np.may_share_memory(d[::2], d2))
        # 断言数组 d 的行倒序视图与数组 d2 不共享内存
        assert_(not np.may_share_memory(d[1:, ::-1], d2))
        # 断言数组 d2 的行倒序视图与自身共享内存
        assert_(np.may_share_memory(d2[1:, ::-1], d2))


# Utility
def compare_results(res, desired):
    """Compare lists of arrays."""
    # 如果 res 和 desired 的长度不一致,则抛出 ValueError 异常
    if len(res) != len(desired):
        raise ValueError("Iterables have different lengths")
    # 遍历 res 和 desired 中的每对数组,并使用 assert_array_equal 断言它们相等
    for x, y in zip(res, desired):
        assert_array_equal(x, y)

.\numpy\numpy\lib\tests\test_stride_tricks.py

# 导入必要的库
import numpy as np
# 从 numpy._core._rational_tests 模块中导入 rational 函数
from numpy._core._rational_tests import rational
# 从 numpy.testing 模块中导入多个断言函数
from numpy.testing import (
    assert_equal, assert_array_equal, assert_raises, assert_,
    assert_raises_regex, assert_warns,
    )
# 从 numpy.lib._stride_tricks_impl 模块中导入多个函数
from numpy.lib._stride_tricks_impl import (
    as_strided, broadcast_arrays, _broadcast_shape, broadcast_to,
    broadcast_shapes, sliding_window_view,
    )
# 导入 pytest 模块
import pytest


def assert_shapes_correct(input_shapes, expected_shape):
    # 对给定的输入形状列表进行广播,检查广播后的输出形状是否与期望形状相同。

    # 创建一个由零数组组成的列表,每个数组使用给定的形状
    inarrays = [np.zeros(s) for s in input_shapes]
    # 对列表中的数组进行广播,得到广播后的数组列表
    outarrays = broadcast_arrays(*inarrays)
    # 获取广播后每个数组的形状
    outshapes = [a.shape for a in outarrays]
    # 创建一个期望的形状列表,长度与输入形状列表相同
    expected = [expected_shape] * len(inarrays)
    # 断言广播后的形状与期望形状列表相同
    assert_equal(outshapes, expected)


def assert_incompatible_shapes_raise(input_shapes):
    # 对给定的(不兼容的)输入形状列表进行广播,检查是否引发 ValueError 异常。

    # 创建一个由零数组组成的列表,每个数组使用给定的形状
    inarrays = [np.zeros(s) for s in input_shapes]
    # 断言对列表中的数组进行广播时会引发 ValueError 异常
    assert_raises(ValueError, broadcast_arrays, *inarrays)


def assert_same_as_ufunc(shape0, shape1, transposed=False, flipped=False):
    # 对两个形状进行广播,检查数据布局是否与 ufunc 执行广播时的相同。

    # 创建一个形状为 shape0 的零数组
    x0 = np.zeros(shape0, dtype=int)
    # 根据 shape1 创建一个一维数组,并将其重塑为 shape1 形状的二维数组
    n = int(np.multiply.reduce(shape1))
    x1 = np.arange(n).reshape(shape1)
    # 如果 transposed 为 True,则对 x0 和 x1 进行转置
    if transposed:
        x0 = x0.T
        x1 = x1.T
    # 如果 flipped 为 True,则对 x0 和 x1 进行反转
    if flipped:
        x0 = x0[::-1]
        x1 = x1[::-1]
    # 使用 add ufunc 进行广播操作。由于我们在 x1 上加的是零数组 x0,因此结果应该与 x1 的广播视图完全相同。
    y = x0 + x1
    # 对 x0 和 x1 进行广播操作,并断言广播后的结果与 y 相同
    b0, b1 = broadcast_arrays(x0, x1)
    assert_array_equal(y, b1)


def test_same():
    # 测试相同输入的广播

    # 创建一个长度为 10 的一维数组 x 和 y
    x = np.arange(10)
    y = np.arange(10)
    # 对 x 和 y 进行广播操作
    bx, by = broadcast_arrays(x, y)
    # 断言广播后的 x 和 y 与原始 x 和 y 相同
    assert_array_equal(x, bx)
    assert_array_equal(y, by)


def test_broadcast_kwargs():
    # 测试使用非 'subok' 关键字参数调用 np.broadcast_arrays() 是否引发 TypeError 异常

    # 创建一个长度为 10 的一维数组 x 和 y
    x = np.arange(10)
    y = np.arange(10)

    # 使用 assert_raises_regex 上下文管理器断言调用 broadcast_arrays() 时使用非 'subok' 关键字参数会引发 TypeError 异常
    with assert_raises_regex(TypeError, 'got an unexpected keyword'):
        broadcast_arrays(x, y, dtype='float64')


def test_one_off():
    # 测试特殊的广播情况

    # 创建一个形状为 (1, 3) 的二维数组 x 和一个形状为 (3, 1) 的二维数组 y
    x = np.array([[1, 2, 3]])
    y = np.array([[1], [2], [3]])
    # 对 x 和 y 进行广播操作
    bx, by = broadcast_arrays(x, y)
    # 创建一个期望的二维数组 bx0 和 by0
    bx0 = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])
    by0 = bx0.T
    # 断言广播后的 bx 和 by 与 bx0 和 by0 相同
    assert_array_equal(bx0, bx)
    assert_array_equal(by0, by)


def test_same_input_shapes():
    # 检查最终形状是否与输入形状相同

    # 定义一个包含各种输入形状的列表
    data = [
        (),
        (1,),
        (3,),
        (0, 1),
        (0, 3),
        (1, 0),
        (3, 0),
        (1, 3),
        (3, 1),
        (3, 3),
    ]
    # 遍历数据列表中的每个形状对象
    for shape in data:
        # 构建只包含当前形状对象的列表作为输入
        input_shapes = [shape]
        # 调用函数以确保输入的形状与给定的形状对象相匹配
        assert_shapes_correct(input_shapes, shape)
        # 构建包含两个当前形状对象的列表作为输入
        input_shapes2 = [shape, shape]
        # 调用函数以确保输入的形状与给定的形状对象相匹配
        assert_shapes_correct(input_shapes2, shape)
        # 构建包含三个当前形状对象的列表作为输入
        input_shapes3 = [shape, shape, shape]
        # 调用函数以确保输入的形状与给定的形状对象相匹配
        assert_shapes_correct(input_shapes3, shape)
def test_two_compatible_by_ones_input_shapes():
    # 检查两个不同输入形状,但长度相同且某些部分为1的情况,是否广播到正确的形状。

    data = [
        [[(1,), (3,)], (3,)],
        [[(1, 3), (3, 3)], (3, 3)],
        [[(3, 1), (3, 3)], (3, 3)],
        [[(1, 3), (3, 1)], (3, 3)],
        [[(1, 1), (3, 3)], (3, 3)],
        [[(1, 1), (1, 3)], (1, 3)],
        [[(1, 1), (3, 1)], (3, 1)],
        [[(1, 0), (0, 0)], (0, 0)],
        [[(0, 1), (0, 0)], (0, 0)],
        [[(1, 0), (0, 1)], (0, 0)],
        [[(1, 1), (0, 0)], (0, 0)],
        [[(1, 1), (1, 0)], (1, 0)],
        [[(1, 1), (0, 1)], (0, 1)],
    ]
    for input_shapes, expected_shape in data:
        assert_shapes_correct(input_shapes, expected_shape)
        # 反转输入形状,因为广播应该是对称的。
        assert_shapes_correct(input_shapes[::-1], expected_shape)


def test_two_compatible_by_prepending_ones_input_shapes():
    # 检查两个不同长度的输入形状,但以1开头的情况是否广播到正确的形状。

    data = [
        [[(), (3,)], (3,)],
        [[(3,), (3, 3)], (3, 3)],
        [[(3,), (3, 1)], (3, 3)],
        [[(1,), (3, 3)], (3, 3)],
        [[(), (3, 3)], (3, 3)],
        [[(1, 1), (3,)], (1, 3)],
        [[(1,), (3, 1)], (3, 1)],
        [[(1,), (1, 3)], (1, 3)],
        [[(), (1, 3)], (1, 3)],
        [[(), (3, 1)], (3, 1)],
        [[(), (0,)], (0,)],
        [[(0,), (0, 0)], (0, 0)],
        [[(0,), (0, 1)], (0, 0)],
        [[(1,), (0, 0)], (0, 0)],
        [[(), (0, 0)], (0, 0)],
        [[(1, 1), (0,)], (1, 0)],
        [[(1,), (0, 1)], (0, 1)],
        [[(1,), (1, 0)], (1, 0)],
        [[(), (1, 0)], (1, 0)],
        [[(), (0, 1)], (0, 1)],
    ]
    for input_shapes, expected_shape in data:
        assert_shapes_correct(input_shapes, expected_shape)
        # 反转输入形状,因为广播应该是对称的。
        assert_shapes_correct(input_shapes[::-1], expected_shape)


def test_incompatible_shapes_raise_valueerror():
    # 检查对于不兼容的形状是否引发 ValueError 异常。

    data = [
        [(3,), (4,)],
        [(2, 3), (2,)],
        [(3,), (3,), (4,)],
        [(1, 3, 4), (2, 3, 3)],
    ]
    for input_shapes in data:
        assert_incompatible_shapes_raise(input_shapes)
        # 反转输入形状,因为广播应该是对称的。
        assert_incompatible_shapes_raise(input_shapes[::-1])


def test_same_as_ufunc():
    # 检查数据布局是否与 ufunc 执行的操作相同。
    
    data = [
        [[(1,), (3,)], (3,)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3,)
        [[(1, 3), (3, 3)], (3, 3)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3, 3)
        [[(3, 1), (3, 3)], (3, 3)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3, 3)
        [[(1, 3), (3, 1)], (3, 3)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3, 3)
        [[(1, 1), (3, 3)], (3, 3)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3, 3)
        [[(1, 1), (1, 3)], (1, 3)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (1, 3)
        [[(1, 1), (3, 1)], (3, 1)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (3, 1)
        [[(1, 0), (0, 0)], (0, 0)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (0, 0)
        [[(0, 1), (0, 0)], (0, 0)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (0, 0)
        [[(1, 0), (0, 1)], (0, 0)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (0, 0)
        [[(1, 1), (0, 0)], (0, 0)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (0, 0)
        [[(1, 1), (1, 0)], (1, 0)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (1, 0)
        [[(1, 1), (0, 1)], (0, 1)],  # 示例数据:包含两个元组作为输入形状,期望输出形状为 (0, 1)
        [[(), (3,)], (3,)],  # 示例数据:包含一个元组和一个标量作为输入形状,期望输出形状为 (3,)
        [[(3,), (3, 3)], (3, 3)],  # 示例数据:包含一个元组和一个向量作为输入形状,期望输出形状为 (3, 3)
        [[(3,), (3, 1)], (3, 3)],  # 示例数据:包含一个元组和一个向量作为输入形状,期望输出形状为 (3, 3)
        [[(1,), (3, 3)], (3, 3)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (3, 3)
        [[(), (3, 3)], (3, 3)],  # 示例数据:包含一个标量和一个矩阵作为输入形状,期望输出形状为 (3, 3)
        [[(1, 1), (3,)], (1, 3)],  # 示例数据:包含一个元组和一个标量作为输入形状,期望输出形状为 (1, 3)
        [[(1,), (3, 1)], (3, 1)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (3, 1)
        [[(1,), (1, 3)], (1, 3)],  # 示例数据:包含一个标量和一个矩阵作为输入形状,期望输出形状为 (1, 3)
        [[(), (1, 3)], (1, 3)],  # 示例数据:包含一个标量和一个矩阵作为输入形状,期望输出形状为 (1, 3)
        [[(), (3, 1)], (3, 1)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (3, 1)
        [[(), (0,)], (0,)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (0,)
        [[(0,), (0, 0)], (0, 0)],  # 示例数据:包含一个向量和一个标量作为输入形状,期望输出形状为 (0, 0)
        [[(0,), (0, 1)], (0, 0)],  # 示例数据:包含一个向量和一个向量作为输入形状,期望输出形状为 (0, 0)
        [[(1,), (0, 0)], (0, 0)],  # 示例数据:包含一个标量和一个标量作为输入形状,期望输出形状为 (0, 0)
        [[(), (0, 0)], (0, 0)],  # 示例数据:包含一个标量和一个标量作为输入形状,期望输出形状为 (0, 0)
        [[(1, 1), (0,)], (1, 0)],  # 示例数据:包含一个元组和一个标量作为输入形状,期望输出形状为 (1, 0)
        [[(1,), (0, 1)], (0, 1)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (0, 1)
        [[(1,), (1, 0)], (1, 0)],  # 示例数据:包含一个标量和一个标量作为输入形状,期望输出形状为 (1, 0)
        [[(), (1, 0)], (1, 0)],  # 示例数据:包含一个标量和一个标量作为输入形状,期望输出形状为 (1, 0)
        [[(), (0, 1)], (0, 1)],  # 示例数据:包含一个标量和一个向量作为输入形状,期望输出形状为 (0, 1)
    ]
    for input_shapes, expected_shape in data:
        assert_same_as_ufunc(input_shapes[0], input_shapes[1],
                             "Shapes: %s %s" % (input_shapes[0], input_shapes[1]))
        # 反转输入形状,因为广播应该是对称的。
        assert_same_as_ufunc(input_shapes[1], input_shapes[0])
        # 也尝试转置它们。
        assert_same_as_ufunc(input_shapes[0], input_shapes[1], True)
        # ... 并且对于非秩为0的输入,也进行翻转以测试负步长。
        if () not in input_shapes:
            assert_same_as_ufunc(input_shapes[0], input_shapes[1], False, True)
            assert_same_as_ufunc(input_shapes[0], input_shapes[1], True, True)
# 定义函数用于测试 broadcast_to 函数的成功情况
def test_broadcast_to_succeeds():
    # 定义测试数据,每个元素包含输入数组、目标形状和期望的输出数组
    data = [
        [np.array(0), (0,), np.array(0)],  # 输入数组为0维,目标形状为(0,),期望输出为0
        [np.array(0), (1,), np.zeros(1)],  # 输入数组为0维,目标形状为(1,),期望输出为[0.]
        [np.array(0), (3,), np.zeros(3)],  # 输入数组为0维,目标形状为(3,),期望输出为[0. 0. 0.]
        [np.ones(1), (1,), np.ones(1)],    # 输入数组为1维[1.],目标形状为(1,),期望输出为[1.]
        [np.ones(1), (2,), np.ones(2)],    # 输入数组为1维[1.],目标形状为(2,),期望输出为[1. 1.]
        [np.ones(1), (1, 2, 3), np.ones((1, 2, 3))],  # 输入数组为1维[1.],目标形状为(1, 2, 3),期望输出为[[[1. 1. 1.] [1. 1. 1.]]]
        [np.arange(3), (3,), np.arange(3)],  # 输入数组为1维[0 1 2],目标形状为(3,),期望输出为[0 1 2]
        [np.arange(3), (1, 3), np.arange(3).reshape(1, -1)],  # 输入数组为1维[0 1 2],目标形状为(1, 3),期望输出为[[0 1 2]]
        [np.arange(3), (2, 3), np.array([[0, 1, 2], [0, 1, 2]])],  # 输入数组为1维[0 1 2],目标形状为(2, 3),期望输出为[[0 1 2] [0 1 2]]
        # 测试目标形状为非元组的情况
        [np.ones(0), 0, np.ones(0)],  # 输入数组为0维,目标形状为0,期望输出为空数组
        [np.ones(1), 1, np.ones(1)],  # 输入数组为1维[1.],目标形状为1,期望输出为[1.]
        [np.ones(1), 2, np.ones(2)],  # 输入数组为1维[1.],目标形状为2,期望输出为[1. 1.]
        # 下面这些大小为0的情况看似奇怪,但是复现了与 ufuncs 广播的行为(参见上面的 test_same_as_ufunc)
        [np.ones(1), (0,), np.ones(0)],  # 输入数组为1维[1.],目标形状为(0,),期望输出为空数组
        [np.ones((1, 2)), (0, 2), np.ones((0, 2))],  # 输入数组为2维[[1. 1.]],目标形状为(0, 2),期望输出为二维空数组
        [np.ones((2, 1)), (2, 0), np.ones((2, 0))],  # 输入数组为2维[[1.] [1.]],目标形状为(2, 0),期望输出为二维空数组
    ]
    # 遍历测试数据
    for input_array, shape, expected in data:
        # 调用 broadcast_to 函数得到实际输出
        actual = broadcast_to(input_array, shape)
        # 断言实际输出与期望输出相等
        assert_array_equal(expected, actual)


# 定义函数用于测试 broadcast_to 函数的异常情况
def test_broadcast_to_raises():
    # 定义测试数据,每个元素包含原始形状和目标形状
    data = [
        [(0,), ()],        # 原始形状为(0,),目标形状为空元组,期望抛出 ValueError 异常
        [(1,), ()],        # 原始形状为(1,),目标形状为空元组,期望抛出 ValueError 异常
        [(3,), ()],        # 原始形状为(3,),目标形状为空元组,期望抛出 ValueError 异常
        [(3,), (1,)],      # 原始形状为(3,),目标形状为(1,),期望抛出 ValueError 异常
        [(3,), (2,)],      # 原始形状为(3,),目标形状为(2,),期望抛出 ValueError 异常
        [(3,), (4,)],      # 原始形状为(3,),目标形状为(4,),期望抛出 ValueError 异常
        [(1, 2), (2, 1)],  # 原始形状为(1, 2),目标形状为(2, 1),期望抛出 ValueError 异常
        [(1, 1), (1,)],    # 原始形状为(1, 1),目标形状为(1,),期望抛出 ValueError 异常
        [(1,), -1],        # 原始形状为(1,),目标形状为-1,期望抛出 ValueError 异常
        [(1,), (-1,)],     # 原始形状为(1,),目标形状为(-1,),期望抛出 ValueError 异常
        [(1, 2), (-1, 2)],  # 原始形状为(1, 2),目标形状为(-1, 2),期望抛出 ValueError 异常
    ]
    # 遍历测试数据
    for orig_shape, target_shape in data:
        # 创建具有原始形状的零数组
        arr = np.zeros(orig_shape)
        # 使用 lambda 表达式和 assert_raises 断言抛出 ValueError 异常
        assert_raises(ValueError, lambda: broadcast_to(arr, target_shape))


# 定义函数用于测试 broadcast_shapes 函数
def test_broadcast_shape():
    # 测试 _broadcast_shape 内部函数
    # _broadcast_shape 已经通过 broadcast_arrays 间接测试
    # _broadcast_shape 也通过公共 broadcast_shapes 函数测试
    assert_equal(_broadcast_shape(), ())  # 确保没有参数时返回空元组
    assert_equal(_broadcast_shape([1, 2]), (2,))  # 测试一维数组[1, 2],期望返回元组(2,)
    assert_equal(_broadcast_shape(np.ones((1, 1))), (1, 1))  # 测试形状为(1, 1)的全一数组,期望返回元组(1, 1)
    assert_equal(_broadcast_shape(np.ones((1, 1)), np.ones((3, 4))), (3, 4))  # 测试两个不同形状的全一数组,期望返回元组(3, 4)
    assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 32)), (1, 2))  # 测试32个形状为(1, 2)的全一数组,期望返回元组(1, 2)
    assert_equal(_broadcast_shape(*([np.ones((1, 2))] * 100)), (1, 2))  # 测试100个形状为(1, 2)的全一数组,期望返回元组(1, 2)

    # gh-5862 的回归测试
    assert_equal(_broadcast_shape(*([np.ones(2)] * 32 + [1])), (2,))  # 测试32个形状为(2,)的全一数组和一个形状为(1,)的数组,期
    data = [
        [[], ()],                     # 空列表作为输入,预期输出是空元组
        [[()], ()],                   # 包含一个空元组作为输入,预期输出是空元组
        [[(7,)], (7,)],               # 包含一个包含一个元素的元组作为输入,预期输出是包含一个元素的元组
        [[(1, 2), (2,)], (1, 2)],     # 包含两个元组作为输入,预期输出是第一个元组
        [[(1, 1)], (1, 1)],           # 包含一个包含相同元素的元组作为输入,预期输出是包含相同元素的元组
        [[(1, 1), (3, 4)], (3, 4)],   # 包含两个不同元素的元组作为输入,预期输出是第二个元组
        [[(6, 7), (5, 6, 1), (7,), (5, 1, 7)], (5, 6, 7)],  # 多个不同长度元组的输入,预期输出是长度最长的元组
        [[(5, 6, 1)], (5, 6, 1)],     # 包含一个包含三个元素的元组作为输入,预期输出是包含三个元素的元组
        [[(1, 3), (3, 1)], (3, 3)],   # 包含两个不同顺序的元组作为输入,预期输出是第二个元组
        [[(1, 0), (0, 0)], (0, 0)],   # 包含两个不同元素的元组作为输入,预期输出是包含两个零的元组
        [[(0, 1), (0, 0)], (0, 0)],   # 包含两个不同元素的元组作为输入,预期输出是包含两个零的元组
        [[(1, 0), (0, 1)], (0, 0)],   # 包含两个不同元素的元组作为输入,预期输出是包含两个零的元组
        [[(1, 1), (0, 0)], (0, 0)],   # 包含两个不同元素的元组作为输入,预期输出是包含两个零的元组
        [[(1, 1), (1, 0)], (1, 0)],   # 包含两个不同元素的元组作为输入,预期输出是包含一个和零的元组
        [[(1, 1), (0, 1)], (0, 1)],   # 包含两个不同元素的元组作为输入,预期输出是包含一个和一的元组
        [[(), (0,)], (0,)],           # 包含一个空元组和一个非空元组作为输入,预期输出是非空元组
        [[(0,), (0, 0)], (0, 0)],     # 包含一个单元素元组和一个包含两个元素的元组作为输入,预期输出是包含两个零的元组
        [[(0,), (0, 1)], (0, 0)],     # 包含一个单元素元组和一个包含两个元素的元组作为输入,预期输出是包含两个零的元组
        [[(1,), (0, 0)], (0, 0)],     # 包含一个单元素元组和一个包含两个元素的元组作为输入,预期输出是包含两个零的元组
        [[(), (0, 0)], (0, 0)],       # 包含一个空元组和一个包含两个元素的元组作为输入,预期输出是包含两个零的元组
        [[(1, 1), (0,)], (1, 0)],     # 包含一个包含两个元素的元组和一个单元素元组作为输入,预期输出是包含一个和零的元组
        [[(1,), (0, 1)], (0, 1)],     # 包含一个单元素元组和一个包含两个元素的元组作为输入,预期输出是包含一个和一的元组
        [[(1,), (1, 0)], (1, 0)],     # 包含一个单元素元组和一个包含两个元素的元组作为输入,预期输出是包含一个和一的元组
        [[(), (1, 0)], (1, 0)],       # 包含一个空元组和一个包含两个元素的元组作为输入,预期输出是包含一个和一的元组
        [[(), (0, 1)], (0, 1)],       # 包含一个空元组和一个包含两个元素的元组作为输入,预期输出是包含一个和一的元组
        [[(1,), (3,)], (3,)],         # 包含一个单元素元组和一个包含一个元素的元组作为输入,预期输出是包含一个元素的元组
        [[2, (3, 2)], (3, 2)],        # 包含一个整数和一个包含两个元素的元组作为输入,预期输出是包含两个元素的元组
    ]

    for input_shapes, target_shape in data:
        assert_equal(broadcast_shapes(*input_shapes), target_shape)

    assert_equal(broadcast_shapes(*([(1, 2)] * 32)), (1, 2))      # 对包含32个相同元组的列表进行广播,预期输出是这个元组
    assert_equal(broadcast_shapes(*([(1, 2)] * 100)), (1, 2))    # 对包含100个相同元组的列表进行广播,预期输出是这个元组

    # 用于测试 gh-5862 的回归测试
    assert_equal(broadcast_shapes(*([(2,)] * 32)), (2,))         # 对包含32个包含一个元素的元组的列表进行广播,预期输出是包含一个元素的元组
def test_broadcast_shapes_raises():
    # 测试公共函数 broadcast_shapes
    # 定义测试数据,包含多组不同形状的输入
    data = [
        [(3,), (4,)],           # 输入形状为 (3,) 和 (4,)
        [(2, 3), (2,)],         # 输入形状为 (2, 3) 和 (2,)
        [(3,), (3,), (4,)],     # 输入形状为 (3,)、(3,) 和 (4,)
        [(1, 3, 4), (2, 3, 3)],  # 输入形状为 (1, 3, 4) 和 (2, 3, 3)
        [(1, 2), (3, 1), (3, 2), (10, 5)],  # 多组输入形状
        [2, (2, 3)],            # 输入形状为 2 和 (2, 3)
    ]
    
    # 遍历测试数据,对每组数据调用 broadcast_shapes 应当抛出 ValueError 异常
    for input_shapes in data:
        assert_raises(ValueError, lambda: broadcast_shapes(*input_shapes))
    
    # 构造一个包含多个相同形状的参数的列表,长度为 64
    bad_args = [(2,)] * 32 + [(3,)] * 32
    
    # 调用 broadcast_shapes 应当抛出 ValueError 异常
    assert_raises(ValueError, lambda: broadcast_shapes(*bad_args))


def test_as_strided():
    # 创建一个包含单个元素的 numpy 数组 a
    a = np.array([None])
    # 调用 as_strided 函数获取数组 a 的视图
    a_view = as_strided(a)
    # 预期的结果是包含单个 None 的 numpy 数组
    expected = np.array([None])
    # 断言 a_view 等于预期结果 expected
    assert_array_equal(a_view, expected)
    
    # 创建一个包含元素 [1, 2, 3, 4] 的 numpy 数组 a
    a = np.array([1, 2, 3, 4])
    # 使用 as_strided 函数创建形状为 (2,)、步长为 (2 * a.itemsize,) 的数组视图 a_view
    a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
    # 预期的结果是包含元素 [1, 3] 的 numpy 数组
    expected = np.array([1, 3])
    # 断言 a_view 等于预期结果 expected
    assert_array_equal(a_view, expected)
    
    # 创建一个包含元素 [1, 2, 3, 4] 的 numpy 数组 a
    a = np.array([1, 2, 3, 4])
    # 使用 as_strided 函数创建形状为 (3, 4)、步长为 (0, 1 * a.itemsize) 的数组视图 a_view
    a_view = as_strided(a, shape=(3, 4), strides=(0, 1 * a.itemsize))
    # 预期的结果是包含元素 [[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]] 的 numpy 数组
    expected = np.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]])
    # 断言 a_view 等于预期结果 expected
    assert_array_equal(a_view, expected)
    
    # 回归测试 gh-5081
    # 创建一个自定义结构的 dtype
    dt = np.dtype([('num', 'i4'), ('obj', 'O')])
    # 创建一个 dtype 为 dt、形状为 (4,) 的空数组 a
    a = np.empty((4,), dtype=dt)
    # 将 a['num'] 的值设为 [1, 2, 3, 4]
    a['num'] = np.arange(1, 5)
    # 使用 as_strided 函数创建形状为 (3, 4)、步长为 (0, a.itemsize) 的数组视图 a_view
    a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
    # 预期的 num 数组是 [[1, 2, 3, 4]] 的重复三次,obj 数组是包含四个 None 的列表,组成的列表
    expected_num = [[1, 2, 3, 4]] * 3
    expected_obj = [[None] * 4] * 3
    # 断言 a_view 的 dtype 等于预期的 dt
    assert_equal(a_view.dtype, dt)
    # 断言 a_view['num'] 等于预期结果 expected_num
    assert_array_equal(expected_num, a_view['num'])
    # 断言 a_view['obj'] 等于预期结果 expected_obj
    assert_array_equal(expected_obj, a_view['obj'])
    
    # 确保没有字段的空类型保持不变
    # 创建一个 dtype 为 'V4'、形状为 (4,) 的空数组 a
    a = np.empty((4,), dtype='V4')
    # 使用 as_strided 函数创建形状为 (3, 4)、步长为 (0, a.itemsize) 的数组视图 a_view
    a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
    # 断言 a 的 dtype 等于 a_view 的 dtype
    assert_equal(a.dtype, a_view.dtype)
    
    # 确保唯一可能失败的类型被正确处理
    # 创建一个自定义结构的 dtype,只包含一个 'V4' 类型的字段
    dt = np.dtype({'names': [''], 'formats': ['V4']})
    # 创建一个 dtype 为 dt、形状为 (4,) 的空数组 a
    a = np.empty((4,), dtype=dt)
    # 使用 as_strided 函数创建形状为 (3, 4)、步长为 (0, a.itemsize) 的数组视图 a_view
    a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
    # 断言 a 的 dtype 等于 a_view 的 dtype
    assert_equal(a.dtype, a_view.dtype)
    
    # 自定义 dtype 不应该丢失 (gh-9161)
    # 创建一个有理数类的实例列表 r,包含 [rational(0), rational(1), rational(2), rational(3)]
    r = [rational(i) for i in range(4)]
    # 创建一个 dtype 为 rational 的数组 a,包含 r 中的有理数实例
    a = np.array(r, dtype=rational)
    # 使用 as_strided 函数创建形状为 (3, 4)、步长为 (0, a.itemsize) 的数组视图 a_view
    a_view = as_strided(a, shape=(3, 4), strides=(0, a.itemsize))
    # 断言 a 的 dtype 等于 a_view 的 dtype
    assert_equal(a.dtype, a_view.dtype)
    # 断言 a_view 中包含 r 的重复三次的数组
    assert_array_equal([r] * 3, a_view)


class TestSlidingWindowView:
    def test_1d(self):
        # 创建一个包含元素 [0, 1, 2, 3, 4] 的 numpy 数组 arr
        arr = np.arange(5)
        # 使用 sliding_window_view 函数创建 arr 的滑动窗口视图 arr_view,窗口大小为 2
        arr_view = sliding_window_view(arr, 2)
        # 预期的结果是包含元素 [[0, 1], [1, 2], [2, 3], [3, 4]] 的 numpy 数组
        expected = np.array([[0, 1], [1, 2], [2, 3], [3, 4]])
        # 断言 arr_view 等于预期结果 expected
        assert_array_equal(arr_view, expected)
    # 定义一个测试方法,测试二维情况下的滑动窗口视图
    def test_2d(self):
        # 创建两个二维的坐标网格,i 为行,j 为列,形状为 (3, 4)
        i, j = np.ogrid[:3, :4]
        # 根据公式 arr = 10*i + j 创建一个二维数组 arr
        arr = 10*i + j
        # 定义窗口的形状为 (2, 2)
        shape = (2, 2)
        # 使用 sliding_window_view 函数创建 arr 的滑动窗口视图
        arr_view = sliding_window_view(arr, shape)
        # 期望的结果数组,包含所有可能的滑动窗口视图
        expected = np.array([[[[0, 1], [10, 11]],
                              [[1, 2], [11, 12]],
                              [[2, 3], [12, 13]]],
                             [[[10, 11], [20, 21]],
                              [[11, 12], [21, 22]],
                              [[12, 13], [22, 23]]]])
        # 断言 arr_view 与期望结果 expected 相等
        assert_array_equal(arr_view, expected)

    # 定义一个测试方法,测试带有指定轴的二维滑动窗口视图
    def test_2d_with_axis(self):
        # 创建两个二维的坐标网格,i 为行,j 为列,形状为 (3, 4)
        i, j = np.ogrid[:3, :4]
        # 根据公式 arr = 10*i + j 创建一个二维数组 arr
        arr = 10*i + j
        # 使用 sliding_window_view 函数创建 arr 的滑动窗口视图,指定窗口形状为 3,轴为 0
        arr_view = sliding_window_view(arr, 3, 0)
        # 期望的结果数组,包含按照指定轴生成的滑动窗口视图
        expected = np.array([[[0, 10, 20],
                              [1, 11, 21],
                              [2, 12, 22],
                              [3, 13, 23]]])
        # 断言 arr_view 与期望结果 expected 相等
        assert_array_equal(arr_view, expected)

    # 定义一个测试方法,测试带有重复轴的二维滑动窗口视图
    def test_2d_repeated_axis(self):
        # 创建两个二维的坐标网格,i 为行,j 为列,形状为 (3, 4)
        i, j = np.ogrid[:3, :4]
        # 根据公式 arr = 10*i + j 创建一个二维数组 arr
        arr = 10*i + j
        # 使用 sliding_window_view 函数创建 arr 的滑动窗口视图,指定窗口形状为 (2, 3),轴为 (1, 1)
        arr_view = sliding_window_view(arr, (2, 3), (1, 1))
        # 期望的结果数组,包含按照指定轴生成的滑动窗口视图
        expected = np.array([[[[0, 1, 2],
                               [1, 2, 3]]],
                             [[[10, 11, 12],
                               [11, 12, 13]]],
                             [[[20, 21, 22],
                               [21, 22, 23]]]])
        # 断言 arr_view 与期望结果 expected 相等
        assert_array_equal(arr_view, expected)

    # 定义一个测试方法,测试在不指定轴的情况下的二维滑动窗口视图
    def test_2d_without_axis(self):
        # 创建两个二维的坐标网格,i 为行,j 为列,形状为 (4, 4)
        i, j = np.ogrid[:4, :4]
        # 根据公式 arr = 10*i + j 创建一个二维数组 arr
        arr = 10*i + j
        # 定义窗口的形状为 (2, 3)
        shape = (2, 3)
        # 使用 sliding_window_view 函数创建 arr 的滑动窗口视图
        arr_view = sliding_window_view(arr, shape)
        # 期望的结果数组,包含所有可能的滑动窗口视图
        expected = np.array([[[[0, 1, 2], [10, 11, 12]],
                              [[1, 2, 3], [11, 12, 13]]],
                             [[[10, 11, 12], [20, 21, 22]],
                              [[11, 12, 13], [21, 22, 23]]],
                             [[[20, 21, 22], [30, 31, 32]],
                              [[21, 22, 23], [31, 32, 33]]]])
        # 断言 arr_view 与期望结果 expected 相等
        assert_array_equal(arr_view, expected)

    # 定义一个测试方法,测试各种错误情况下的异常处理
    def test_errors(self):
        # 创建两个二维的坐标网格,i 为行,j 为列,形状为 (4, 4)
        i, j = np.ogrid[:4, :4]
        # 根据公式 arr = 10*i + j 创建一个二维数组 arr
        arr = 10*i + j
        # 测试 ValueError 异常,窗口形状包含负值
        with pytest.raises(ValueError, match='cannot contain negative values'):
            sliding_window_view(arr, (-1, 3))
        # 测试 ValueError 异常,未为所有维度提供窗口形状
        with pytest.raises(
                ValueError,
                match='must provide window_shape for all dimensions of `x`'):
            sliding_window_view(arr, (1,))
        # 测试 ValueError 异常,窗口形状与轴的长度不匹配
        with pytest.raises(
                ValueError,
                match='Must provide matching length window_shape and axis'):
            sliding_window_view(arr, (1, 3, 4), axis=(0, 1))
        # 测试 ValueError 异常,窗口形状大于输入数组的尺寸
        with pytest.raises(
                ValueError,
                match='window shape cannot be larger than input array'):
            sliding_window_view(arr, (5, 5))
    # 定义一个名为 test_writeable 的测试方法,用于测试 sliding_window_view 函数的 writeable 参数
    def test_writeable(self):
        # 创建一个包含 0 到 4 的 NumPy 数组
        arr = np.arange(5)
        # 使用 sliding_window_view 函数创建一个窗口视图,窗口大小为 2,设置为不可写
        view = sliding_window_view(arr, 2, writeable=False)
        # 断言视图不可写
        assert_(not view.flags.writeable)
        # 使用 pytest 检查写入不可写视图的操作是否引发 ValueError 异常,并检查异常消息
        with pytest.raises(
                ValueError,
                match='assignment destination is read-only'):
            view[0, 0] = 3
        # 使用 sliding_window_view 函数创建一个窗口视图,窗口大小为 2,设置为可写
        view = sliding_window_view(arr, 2, writeable=True)
        # 断言视图可写
        assert_(view.flags.writeable)
        # 修改可写视图的元素
        view[0, 1] = 3
        # 断言原始数组被正确修改
        assert_array_equal(arr, np.array([0, 3, 2, 3, 4]))

    # 定义一个名为 test_subok 的测试方法,用于测试 sliding_window_view 函数的 subok 参数
    def test_subok(self):
        # 定义一个继承自 np.ndarray 的子类 MyArray
        class MyArray(np.ndarray):
            pass

        # 创建一个包含 0 到 4 的 NumPy 数组,并将其视图转换为 MyArray 类型
        arr = np.arange(5).view(MyArray)
        # 断言不使用 subok 参数时,生成的窗口视图不是 MyArray 类型
        assert_(not isinstance(sliding_window_view(arr, 2,
                                                   subok=False),
                               MyArray))
        # 断言使用 subok 参数时,生成的窗口视图是 MyArray 类型
        assert_(isinstance(sliding_window_view(arr, 2, subok=True), MyArray))
        # 默认行为下,断言生成的窗口视图不是 MyArray 类型
        assert_(not isinstance(sliding_window_view(arr, 2), MyArray))
def as_strided_writeable():
    # 创建一个长度为10的全为1的 NumPy 数组
    arr = np.ones(10)
    # 创建一个 arr 的视图,但设置为不可写
    view = as_strided(arr, writeable=False)
    # 断言视图的可写标志为 False
    assert_(not view.flags.writeable)

    # 检查可写性是否正常:
    view = as_strided(arr, writeable=True)
    # 断言视图的可写标志为 True
    assert_(view.flags.writeable)
    # 修改视图中的所有元素为3
    view[...] = 3
    # 断言原始数组 arr 的所有元素都被修改为3
    assert_array_equal(arr, np.full_like(arr, 3))

    # 测试只读模式下的情况:
    arr.flags.writeable = False
    view = as_strided(arr, writeable=False)
    # 创建一个 arr 的视图,此时设置为可写
    view = as_strided(arr, writeable=True)
    # 断言视图的可写标志为 False
    assert_(not view.flags.writeable)


class VerySimpleSubClass(np.ndarray):
    def __new__(cls, *args, **kwargs):
        # 创建一个 np.array,并将其视图化为当前类的实例
        return np.array(*args, subok=True, **kwargs).view(cls)


class SimpleSubClass(VerySimpleSubClass):
    def __new__(cls, *args, **kwargs):
        # 创建一个 np.array,并将其视图化为当前类的实例
        self = np.array(*args, subok=True, **kwargs).view(cls)
        # 为实例添加附加信息
        self.info = 'simple'
        return self

    def __array_finalize__(self, obj):
        # 如果 obj 中有 'info' 属性,则将其赋给当前实例的 'info' 属性
        self.info = getattr(obj, 'info', '') + ' finalized'


def test_subclasses():
    # 测试只有当 subok=True 时,子类才能被保留
    a = VerySimpleSubClass([1, 2, 3, 4])
    assert_(type(a) is VerySimpleSubClass)
    # 创建一个 a 的视图,但不指定类型,应返回基本的 ndarray 类型
    a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,))
    assert_(type(a_view) is np.ndarray)
    # 创建一个 a 的视图,并指定 subok=True,应返回子类 VerySimpleSubClass 的类型
    a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
    assert_(type(a_view) is VerySimpleSubClass)
    
    # 测试如果子类定义了 __array_finalize__ 方法,它会被调用
    a = SimpleSubClass([1, 2, 3, 4])
    # 创建一个 a 的视图,并指定 subok=True,应返回子类 SimpleSubClass 的类型
    a_view = as_strided(a, shape=(2,), strides=(2 * a.itemsize,), subok=True)
    assert_(type(a_view) is SimpleSubClass)
    assert_(a_view.info == 'simple finalized')

    # 类似的测试用于 broadcast_arrays
    b = np.arange(len(a)).reshape(-1, 1)
    # 对 a 和 b 进行 broadcast,预期返回的是普通的 ndarray 类型
    a_view, b_view = broadcast_arrays(a, b)
    assert_(type(a_view) is np.ndarray)
    assert_(type(b_view) is np.ndarray)
    assert_(a_view.shape == b_view.shape)
    # 对 a 和 b 进行 broadcast,指定 subok=True,预期返回的是子类 SimpleSubClass 类型
    a_view, b_view = broadcast_arrays(a, b, subok=True)
    assert_(type(a_view) is SimpleSubClass)
    assert_(a_view.info == 'simple finalized')
    assert_(type(b_view) is np.ndarray)
    assert_(a_view.shape == b_view.shape)

    # 还有对 broadcast_to 的测试
    shape = (2, 4)
    # 对 a 进行 broadcast 到指定 shape,预期返回的是普通的 ndarray 类型
    a_view = broadcast_to(a, shape)
    assert_(type(a_view) is np.ndarray)
    assert_(a_view.shape == shape)
    # 对 a 进行 broadcast 到指定 shape,指定 subok=True,预期返回的是子类 SimpleSubClass 类型
    a_view = broadcast_to(a, shape, subok=True)
    assert_(type(a_view) is SimpleSubClass)
    assert_(a_view.info == 'simple finalized')
    assert_(a_view.shape == shape)


def test_writeable():
    # broadcast_to 应返回一个只读数组
    original = np.array([1, 2, 3])
    result = broadcast_to(original, (2, 3))
    assert_equal(result.flags.writeable, False)
    assert_raises(ValueError, result.__setitem__, slice(None), 0)

    # 但 broadcast_arrays 的结果需要是可写的,以保持向后兼容性
    test_cases = [((False,), broadcast_arrays(original,)),
                  ((True, False), broadcast_arrays(0, original))]
    for is_broadcast, results in test_cases:
        for array_is_broadcast, result in zip(is_broadcast, results):
            # 遍历测试用例,每个测试用例包括是否广播和结果列表
            # 如果数组被广播,则执行以下操作:
            if array_is_broadcast:
                # 发出未来版本警告,即将改为 False
                with assert_warns(FutureWarning):
                    # 检查结果的可写标志是否为 True
                    assert_equal(result.flags.writeable, True)
                # 发出弃用警告
                with assert_warns(DeprecationWarning):
                    # 将结果数组全部设为 0
                    result[:] = 0
                # 警告未发出,写入数组会重置可写状态为 True
                assert_equal(result.flags.writeable, True)
            else:
                # 没有警告:
                # 检查结果的可写标志是否为 True
                assert_equal(result.flags.writeable, True)

    for results in [broadcast_arrays(original),
                    broadcast_arrays(0, original)]:
        for result in results:
            # 重置 warn_on_write 弃用警告
            result.flags.writeable = True
            # 检查:没有警告被发出
            assert_equal(result.flags.writeable, True)
            # 将结果数组全部设为 0
            result[:] = 0

    # 保持只读输入的只读状态
    original.flags.writeable = False
    # 使用广播数组函数返回的结果
    _, result = broadcast_arrays(0, original)
    # 检查结果的可写标志是否为 False
    assert_equal(result.flags.writeable, False)

    # GH6491 的回归测试
    # 创建一个形状为 (2,),步幅为 [0] 的数组
    shape = (2,)
    strides = [0]
    tricky_array = as_strided(np.array(0), shape, strides)
    # 创建一个形状为 (1,) 的零数组
    other = np.zeros((1,))
    # 执行广播数组操作
    first, second = broadcast_arrays(tricky_array, other)
    # 检查第一个和第二个数组的形状是否相同
    assert_(first.shape == second.shape)
def test_writeable_memoryview():
    # 创建一个原始的 NumPy 数组,包含整数 [1, 2, 3]
    original = np.array([1, 2, 3])

    # 定义测试用例列表,每个元素是一个元组,包含是否进行广播的布尔值和广播后的结果
    test_cases = [((False, ), broadcast_arrays(original,)),
                  ((True, False), broadcast_arrays(0, original))]
    # 遍历测试用例
    for is_broadcast, results in test_cases:
        # 遍历每个结果和对应的是否广播标志
        for array_is_broadcast, result in zip(is_broadcast, results):
            # 如果数组是广播的,断言结果的 memoryview 是只读的
            if array_is_broadcast:
                # 此处在未来版本中会更改为 False
                assert memoryview(result).readonly
            else:
                # 如果数组不是广播的,断言结果的 memoryview 不是只读的
                assert not memoryview(result).readonly


def test_reference_types():
    # 创建一个包含单个字符 'a' 的对象数组
    input_array = np.array('a', dtype=object)
    # 创建预期的对象数组,包含三个 'a' 字符串
    expected = np.array(['a'] * 3, dtype=object)
    # 使用 broadcast_to 函数将输入数组广播到形状为 (3,) 的数组
    actual = broadcast_to(input_array, (3,))
    # 断言广播后的数组与预期结果相等
    assert_array_equal(expected, actual)

    # 使用 broadcast_arrays 函数广播输入数组和一个包含三个 1 的数组
    actual, _ = broadcast_arrays(input_array, np.ones(3))
    # 断言广播后的数组与预期结果相等
    assert_array_equal(expected, actual)

.\numpy\numpy\lib\tests\test_twodim_base.py

"""Test functions for matrix module

"""
# 导入必要的模块和函数
from numpy.testing import (
    assert_equal, assert_array_equal, assert_array_max_ulp,
    assert_array_almost_equal, assert_raises, assert_
)
from numpy import (
    arange, add, fliplr, flipud, zeros, ones, eye, array, diag, histogram2d,
    tri, mask_indices, triu_indices, triu_indices_from, tril_indices,
    tril_indices_from, vander,
)
import numpy as np  # 导入 NumPy 库

import pytest  # 导入 pytest 模块


def get_mat(n):
    # 创建一个 n x n 的矩阵,其中每个元素是其行和列索引的和
    data = arange(n)
    data = add.outer(data, data)
    return data


class TestEye:
    def test_basic(self):
        # 测试生成单位矩阵
        assert_equal(eye(4),
                     array([[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]]))

        # 测试生成指定数据类型的单位矩阵
        assert_equal(eye(4, dtype='f'),
                     array([[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]], 'f'))

        # 测试生成布尔类型的单位矩阵
        assert_equal(eye(3) == 1,
                     eye(3, dtype=bool))

    def test_uint64(self):
        # 测试对于 uint64 类型的单位矩阵生成
        # gh-9982 的回归测试
        assert_equal(eye(np.uint64(2), dtype=int), array([[1, 0], [0, 1]]))
        assert_equal(eye(np.uint64(2), M=np.uint64(4), k=np.uint64(1)),
                     array([[0, 1, 0, 0], [0, 0, 1, 0]]))

    def test_diag(self):
        # 测试生成具有偏移量 k 的单位对角矩阵
        assert_equal(eye(4, k=1),
                     array([[0, 1, 0, 0],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1],
                            [0, 0, 0, 0]]))

        assert_equal(eye(4, k=-1),
                     array([[0, 0, 0, 0],
                            [1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0]]))

    def test_2d(self):
        # 测试生成二维单位矩阵
        assert_equal(eye(4, 3),
                     array([[1, 0, 0],
                            [0, 1, 0],
                            [0, 0, 1],
                            [0, 0, 0]]))

        assert_equal(eye(3, 4),
                     array([[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 0]]))

    def test_diag2d(self):
        # 测试生成具有偏移量 k 的二维单位对角矩阵
        assert_equal(eye(3, 4, k=2),
                     array([[0, 0, 1, 0],
                            [0, 0, 0, 1],
                            [0, 0, 0, 0]]))

        assert_equal(eye(4, 3, k=-2),
                     array([[0, 0, 0],
                            [0, 0, 0],
                            [1, 0, 0],
                            [0, 1, 0]]))
    # 测试函数,验证 `eye` 函数的边界条件
    def test_eye_bounds(self):
        # 断言:生成 2x2 的单位矩阵,主对角线向上偏移1个位置
        assert_equal(eye(2, 2, 1), [[0, 1], [0, 0]])
        # 断言:生成 2x2 的单位矩阵,主对角线向下偏移1个位置
        assert_equal(eye(2, 2, -1), [[0, 0], [1, 0]])
        # 断言:生成 2x2 的单位矩阵,但超过矩阵维度,返回全0矩阵
        assert_equal(eye(2, 2, 2), [[0, 0], [0, 0]])
        # 断言:生成 2x2 的单位矩阵,主对角线向下偏移超过矩阵维度,返回全0矩阵
        assert_equal(eye(2, 2, -2), [[0, 0], [0, 0]])
        # 断言:生成 3x2 的单位矩阵,但超过矩阵维度,返回全0矩阵
        assert_equal(eye(3, 2, 2), [[0, 0], [0, 0], [0, 0]])
        # 断言:生成 3x2 的单位矩阵,主对角线向上偏移1个位置
        assert_equal(eye(3, 2, 1), [[0, 1], [0, 0], [0, 0]])
        # 断言:生成 3x2 的单位矩阵,主对角线向下偏移1个位置,且超过矩阵维度,返回全0矩阵
        assert_equal(eye(3, 2, -1), [[0, 0], [1, 0], [0, 1]])
        # 断言:生成 3x2 的单位矩阵,主对角线向下偏移2个位置,且超过矩阵维度,返回全0矩阵
        assert_equal(eye(3, 2, -2), [[0, 0], [0, 0], [1, 0]])
        # 断言:生成 3x2 的单位矩阵,主对角线向下偏移3个位置,超过矩阵维度,返回全0矩阵
        assert_equal(eye(3, 2, -3), [[0, 0], [0, 0], [0, 0]])

    # 测试函数,验证 `eye` 函数对字符串类型的处理
    def test_strings(self):
        # 断言:生成 2x2 的单位矩阵,数据类型为字节串,字符串长度为3
        assert_equal(eye(2, 2, dtype='S3'),
                     [[b'1', b''], [b'', b'1']])

    # 测试函数,验证 `eye` 函数对布尔类型的处理
    def test_bool(self):
        # 断言:生成 2x2 的单位矩阵,数据类型为布尔型
        assert_equal(eye(2, 2, dtype=bool), [[True, False], [False, True]])

    # 测试函数,验证 `eye` 函数对矩阵顺序的处理
    def test_order(self):
        # 生成 4x3 的单位矩阵,主对角线向上偏移1个位置,以 C 顺序存储
        mat_c = eye(4, 3, k=-1)
        # 生成 4x3 的单位矩阵,主对角线向上偏移1个位置,以 Fortran(F)顺序存储
        mat_f = eye(4, 3, k=-1, order='F')
        # 断言:验证两种顺序下生成的矩阵内容相同
        assert_equal(mat_c, mat_f)
        # 断言:验证以 C 顺序存储的矩阵标志
        assert mat_c.flags.c_contiguous
        # 断言:验证以 C 顺序存储的矩阵不是以 Fortran 顺序存储
        assert not mat_c.flags.f_contiguous
        # 断言:验证以 Fortran 顺序存储的矩阵不是以 C 顺序存储
        assert not mat_f.flags.c_contiguous
        # 断言:验证以 Fortran 顺序存储的矩阵标志
        assert mat_f.flags.f_contiguous
class TestDiag:
    # 测试对角向量情况
    def test_vector(self):
        # 创建一个整型的数组,包含 0 到 400 之间的数,步长为 100
        vals = (100 * arange(5)).astype('l')
        # 创建一个5x5的全零数组
        b = zeros((5, 5))
        # 遍历范围为5的循环
        for k in range(5):
            # 在对角线上分别赋值
            b[k, k] = vals[k]
        # 断言对角线函数对vals和b的值相等
        assert_equal(diag(vals), b)
        # 创建一个7x7的全零数组
        b = zeros((7, 7))
        # 复制b数组
        c = b.copy()
        # 遍历范围为5的循环
        for k in range(5):
            # 在对角线上的不同位置赋值
            b[k, k + 2] = vals[k]
            c[k + 2, k] = vals[k]
        # 断言对角线函数对vals和b的值相等
        assert_equal(diag(vals, k=2), b)
        # 断言对角线函数对vals和c的值相等
        assert_equal(diag(vals, k=-2), c)

    # 测试对角矩阵情况
    def test_matrix(self, vals=None):
        # 如果vals为空,则获取一个5x5的矩阵并转换为整型数组
        if vals is None:
            vals = (100 * get_mat(5) + 1).astype('l')
        # 创建一个包含5个元素的全零数组
        b = zeros((5,))
        # 遍历范围为5的循环
        for k in range(5):
            # 在对角线上分别赋值
            b[k] = vals[k, k]
        # 断言对角线函数对vals和b的值相等
        assert_equal(diag(vals), b)
        # 将b数组元素全部置零
        b = b * 0
        # 遍历范围为3的循环
        for k in range(3):
            # 在对角线上的不同位置赋值
            b[k] = vals[k, k + 2]
        # 断言对角线函数对vals和b的前三个值相等
        assert_equal(diag(vals, 2), b[:3])
        # 遍历范围为3的循环
        for k in range(3):
            # 在对角线上的不同位置赋值
            b[k] = vals[k + 2, k]
        # 断言对角线函数对vals和b的前三个值相等
        assert_equal(diag(vals, -2), b[:3])

    # 测试Fortran顺序的情况
    def test_fortran_order(self):
        # 创建一个Fortran顺序的5x5矩阵,并转换为整型数组
        vals = array((100 * get_mat(5) + 1), order='F', dtype='l')
        # 调用test_matrix函数进行测试
        self.test_matrix(vals)

    # 测试对角线边界情况
    def test_diag_bounds(self):
        # 创建一个列表A
        A = [[1, 2], [3, 4], [5, 6]]
        # 断言对角线函数应返回空列表
        assert_equal(diag(A, k=2), [])
        # 断言对角线函数应返回包含值2的列表
        assert_equal(diag(A, k=1), [2])
        # 断言对角线函数应返回包含值1和4的列表
        assert_equal(diag(A, k=0), [1, 4])
        # 断言对角线函数应返回包含值3和6的列表
        assert_equal(diag(A, k=-1), [3, 6])
        # 断言对角线函数应返回包含值5的列表
        assert_equal(diag(A, k=-2), [5])
        # 断言对角线函数应返回空列表
        assert_equal(diag(A, k=-3), [])

    # 测试失败的情况
    def test_failure(self):
        # 断言当传入一个三重嵌套的列表时,应引发ValueError异常
        assert_raises(ValueError, diag, [[[1]]])


class TestFliplr:
    # 测试基本情况
    def test_basic(self):
        # 断言当传入一个全1的4x4数组时,应引发ValueError异常
        assert_raises(ValueError, fliplr, ones(4))
        # 获取一个4x4矩阵
        a = get_mat(4)
        # 创建一个在水平方向翻转后的矩阵b
        b = a[:, ::-1]
        # 断言翻转函数对a和b的值相等
        assert_equal(fliplr(a), b)
        # 创建一个列表a
        a = [[0, 1, 2],
             [3, 4, 5]]
        # 创建一个在水平方向翻转后的列表b
        b = [[2, 1, 0],
             [5, 4, 3]]
        # 断言翻转函数对a和b的值相等
        assert_equal(fliplr(a), b)


class TestFlipud:
    # 测试基本情况
    def test_basic(self):
        # 获取一个4x4矩阵
        a = get_mat(4)
        # 创建一个在垂直方向翻转后的矩阵b
        b = a[::-1, :]
        # 断言翻转函数对a和b的值相等
        assert_equal(flipud(a), b)
        # 创建一个列表a
        a = [[0, 1, 2],
             [3, 4, 5]]
        # 创建一个在垂直方向翻转后的列表b
        b = [[3, 4, 5],
             [0, 1, 2]]
        # 断言翻转函数对a和b的值相等
        assert_equal(flipud(a), b)


class TestHistogram2d:
    pass
    # 测试简单的二维直方图计算,用于测试基本功能
    def test_simple(self):
        x = array(
            [0.41702200, 0.72032449, 1.1437481e-4, 0.302332573, 0.146755891])
        y = array(
            [0.09233859, 0.18626021, 0.34556073, 0.39676747, 0.53881673])
        # 在指定区间内生成 x 和 y 的边界
        xedges = np.linspace(0, 1, 10)
        yedges = np.linspace(0, 1, 10)
        # 计算二维直方图,并获取直方图数据
        H = histogram2d(x, y, (xedges, yedges))[0]
        # 预期的正确结果
        answer = array(
            [[0, 0, 0, 1, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 1, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0],
             [1, 0, 1, 0, 0, 0, 0, 0, 0],
             [0, 1, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0]])
        # 断言直方图数据与预期结果的相等性
        assert_array_equal(H.T, answer)
        # 重新计算二维直方图,使用不同的参数形式
        H = histogram2d(x, y, xedges)[0]
        # 再次断言直方图数据与预期结果的相等性
        assert_array_equal(H.T, answer)
        # 执行直方图计算,并获取直方图数据、以及新的边界
        H, xedges, yedges = histogram2d(list(range(10)), list(range(10)))
        # 断言直方图数据与单位矩阵的相等性
        assert_array_equal(H, eye(10, 10))
        # 断言 x 边界与预期结果的相等性
        assert_array_equal(xedges, np.linspace(0, 9, 11))
        # 断言 y 边界与预期结果的相等性
        assert_array_equal(yedges, np.linspace(0, 9, 11))

    # 测试不对称数据的二维直方图计算,验证不对称情况下的功能
    def test_asym(self):
        x = array([1, 1, 2, 3, 4, 4, 4, 5])
        y = array([1, 3, 2, 0, 1, 2, 3, 4])
        # 计算带有额外参数的二维直方图,包括范围、密度等设置
        H, xed, yed = histogram2d(
            x, y, (6, 5), range=[[0, 6], [0, 5]], density=True)
        # 预期的正确结果
        answer = array(
            [[0., 0, 0, 0, 0],
             [0, 1, 0, 1, 0],
             [0, 0, 1, 0, 0],
             [1, 0, 0, 0, 0],
             [0, 1, 1, 1, 0],
             [0, 0, 0, 0, 1]])
        # 断言直方图数据与预期结果的近似相等性
        assert_array_almost_equal(H, answer/8., 3)
        # 断言 x 边界与预期结果的相等性
        assert_array_equal(xed, np.linspace(0, 6, 7))
        # 断言 y 边界与预期结果的相等性
        assert_array_equal(yed, np.linspace(0, 5, 6))

    # 测试密度计算的二维直方图功能
    def test_density(self):
        x = array([1, 2, 3, 1, 2, 3, 1, 2, 3])
        y = array([1, 1, 1, 2, 2, 2, 3, 3, 3])
        # 计算密度归一化的二维直方图,使用自定义的边界
        H, xed, yed = histogram2d(
            x, y, [[1, 2, 3, 5], [1, 2, 3, 5]], density=True)
        # 预期的正确结果
        answer = array([[1, 1, .5],
                        [1, 1, .5],
                        [.5, .5, .25]])/9.
        # 断言直方图数据与预期结果的近似相等性
        assert_array_almost_equal(H, answer, 3)

    # 测试全部数据为异常值的情况
    def test_all_outliers(self):
        r = np.random.rand(100) + 1. + 1e6  # histogramdd rounds by decimal=6
        # 计算异常值情况下的二维直方图,期望所有元素为零
        H, xed, yed = histogram2d(r, r, (4, 5), range=([0, 1], [0, 1]))
        # 断言直方图数据全为零
        assert_array_equal(H, 0)

    # 测试空数据集的情况
    def test_empty(self):
        # 测试空数据的二维直方图计算,使用自定义的边界
        a, edge1, edge2 = histogram2d([], [], bins=([0, 1], [0, 1]))
        # 断言直方图数据与预期结果的最大误差不超过单位最小浮点数
        assert_array_max_ulp(a, array([[0.]]))

        # 再次测试空数据的二维直方图计算,使用相同的边界数量
        a, edge1, edge2 = histogram2d([], [], bins=4)
        # 断言直方图数据与全零矩阵的最大误差不超过单位最小浮点数
        assert_array_max_ulp(a, np.zeros((4, 4)))
    # 定义测试方法,用于验证不同二进制参数组合的直方图计算
    def test_binparameter_combination(self):
        # 定义输入的两个一维数组 x 和 y
        x = array(
            [0, 0.09207008, 0.64575234, 0.12875982, 0.47390599,
             0.59944483, 1])
        y = array(
            [0, 0.14344267, 0.48988575, 0.30558665, 0.44700682,
             0.15886423, 1])
        # 定义直方图的边界
        edges = (0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1)
        # 调用 histogram2d 函数计算二维直方图,返回直方图 H 和边界 xe, ye
        H, xe, ye = histogram2d(x, y, (edges, 4))
        # 预期的二维数组结果
        answer = array(
            [[2., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 1., 0., 0.],
             [1., 0., 0., 0.],
             [0., 1., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 1.]])
        # 使用 assert_array_equal 断言函数检查 H 是否与预期结果 answer 相等
        assert_array_equal(H, answer)
        # 使用 assert_array_equal 断言函数检查 ye 是否与预期结果相等
        assert_array_equal(ye, array([0., 0.25, 0.5, 0.75, 1]))
        
        # 重新计算直方图,但是交换了 bins 的顺序
        H, xe, ye = histogram2d(x, y, (4, edges))
        # 更新预期的二维数组结果
        answer = array(
            [[1., 1., 0., 1., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
             [0., 1., 0., 0., 1., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]])
        # 使用 assert_array_equal 断言函数检查 H 是否与更新后的预期结果 answer 相等
        assert_array_equal(H, answer)
        # 使用 assert_array_equal 断言函数检查 xe 是否与预期结果相等
        assert_array_equal(xe, array([0., 0.25, 0.5, 0.75, 1]))

    # 定义测试方法,用于验证直方图函数的分发机制
    def test_dispatch(self):
        # 定义一个类 ShouldDispatch,实现了 __array_function__ 方法
        class ShouldDispatch:
            def __array_function__(self, function, types, args, kwargs):
                # 返回传入的 types, args 和 kwargs
                return types, args, kwargs

        # 初始化一个列表 xy
        xy = [1, 2]
        # 创建 ShouldDispatch 的实例 s_d
        s_d = ShouldDispatch()
        
        # 调用 histogram2d 函数,验证对 s_d 和 xy 的直方图计算调用是否被分发
        r = histogram2d(s_d, xy)
        # 使用 assert_ 函数断言 r 是否等于预期的结果元组
        assert_(r == ((ShouldDispatch,), (s_d, xy), {}))
        
        # 再次调用 histogram2d 函数,验证对 xy 和 s_d 的直方图计算调用是否被分发
        r = histogram2d(xy, s_d)
        # 使用 assert_ 函数断言 r 是否等于预期的结果元组
        assert_(r == ((ShouldDispatch,), (xy, s_d), {}))
        
        # 再次调用 histogram2d 函数,验证对 xy 和 xy(使用 bins=s_d 参数)的直方图计算调用是否被分发
        r = histogram2d(xy, xy, bins=s_d)
        # 使用 assert_ 函数断言 r 是否等于预期的结果元组
        assert_(r, ((ShouldDispatch,), (xy, xy), dict(bins=s_d)))
        
        # 再次调用 histogram2d 函数,验证对 xy 和 xy(使用 bins=[s_d, 5] 参数)的直方图计算调用是否被分发
        r = histogram2d(xy, xy, bins=[s_d, 5])
        # 使用 assert_ 函数断言 r 是否等于预期的结果元组
        assert_(r, ((ShouldDispatch,), (xy, xy), dict(bins=[s_d, 5])))
        
        # 使用 assert_raises 函数验证调用 histogram2d 函数时,传入 bins=[s_d] 参数是否会引发异常
        assert_raises(Exception, histogram2d, xy, xy, bins=[s_d])
        
        # 再次调用 histogram2d 函数,验证对 xy 和 xy(使用 weights=s_d 参数)的直方图计算调用是否被分发
        r = histogram2d(xy, xy, weights=s_d)
        # 使用 assert_ 函数断言 r 是否等于预期的结果元组
        assert_(r, ((ShouldDispatch,), (xy, xy), dict(weights=s_d)))

    # 使用 pytest 的参数化装饰器,定义测试方法,用于验证不同长度的输入数组引发的 ValueError 异常
    @pytest.mark.parametrize(("x_len", "y_len"), [(10, 11), (20, 19)])
    def test_bad_length(self, x_len, y_len):
        # 创建两个长度不同的全 1 数组 x 和 y
        x, y = np.ones(x_len), np.ones(y_len)
        # 使用 assertRaises 函数验证调用 histogram2d 函数时,传入不同长度的 x 和 y 是否会引发 ValueError 异常
        with pytest.raises(ValueError,
                           match='x and y must have the same length.'):
            histogram2d(x, y)
# 定义一个测试类 TestTri
class TestTri:
    # 测试函数,测试 np.tri 函数的返回值是否正确
    def test_dtype(self):
        # 创建一个三阶单位矩阵,用于与 np.tri(3) 的结果比较
        out = array([[1, 0, 0],
                     [1, 1, 0],
                     [1, 1, 1]])
        # 断言 np.tri(3) 返回的结果与 out 相等
        assert_array_equal(tri(3), out)
        # 断言使用 bool 类型调用 np.tri(3) 的结果与 out.astype(bool) 相等
        assert_array_equal(tri(3, dtype=bool), out.astype(bool))


# 测试函数,测试 np.tril 和 np.triu 在二维数组上的行为
def test_tril_triu_ndim2():
    # 遍历所有浮点数和整数类型
    for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        # 创建一个全为 1 的二维数组,数据类型为当前循环的类型 dtype
        a = np.ones((2, 2), dtype=dtype)
        # 使用 np.tril 函数生成 a 的下三角矩阵
        b = np.tril(a)
        # 使用 np.triu 函数生成 a 的上三角矩阵
        c = np.triu(a)
        # 断言 np.tril(a) 的结果与预期的下三角矩阵 [[1, 0], [1, 1]] 相等
        assert_array_equal(b, [[1, 0], [1, 1]])
        # 断言 np.triu(a) 的结果与 np.tril(a) 的转置相等
        assert_array_equal(c, b.T)
        # 断言 np.tril(a) 和 np.triu(a) 的数据类型与 a 相同
        assert_equal(b.dtype, a.dtype)
        assert_equal(c.dtype, a.dtype)


# 测试函数,测试 np.tril 和 np.triu 在三维数组上的行为
def test_tril_triu_ndim3():
    # 遍历所有浮点数和整数类型
    for dtype in np.typecodes['AllFloat'] + np.typecodes['AllInteger']:
        # 创建一个三维数组 a,数据类型为当前循环的类型 dtype
        a = np.array([
            [[1, 1], [1, 1]],
            [[1, 1], [1, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        # 预期的 a 的下三角矩阵
        a_tril_desired = np.array([
            [[1, 0], [1, 1]],
            [[1, 0], [1, 0]],
            [[1, 0], [0, 0]],
            ], dtype=dtype)
        # 预期的 a 的上三角矩阵
        a_triu_desired = np.array([
            [[1, 1], [0, 1]],
            [[1, 1], [0, 0]],
            [[1, 1], [0, 0]],
            ], dtype=dtype)
        # 使用 np.triu 函数生成 a 的上三角矩阵
        a_triu_observed = np.triu(a)
        # 使用 np.tril 函数生成 a 的下三角矩阵
        a_tril_observed = np.tril(a)
        # 断言 np.triu(a) 的结果与预期的上三角矩阵相等
        assert_array_equal(a_triu_observed, a_triu_desired)
        # 断言 np.tril(a) 的结果与预期的下三角矩阵相等
        assert_array_equal(a_tril_observed, a_tril_desired)
        # 断言 np.triu(a) 和 np.tril(a) 的数据类型与 a 相同
        assert_equal(a_triu_observed.dtype, a.dtype)
        assert_equal(a_tril_observed.dtype, a.dtype)


# 测试函数,测试 np.tril 和 np.triu 处理含有无穷大值的数组时的行为
def test_tril_triu_with_inf():
    # 创建一个包含无穷大值的数组 arr
    arr = np.array([[1, 1, np.inf],
                    [1, 1, 1],
                    [np.inf, 1, 1]])
    # 预期的 arr 的下三角矩阵
    out_tril = np.array([[1, 0, 0],
                         [1, 1, 0],
                         [np.inf, 1, 1]])
    # 预期的 arr 的上三角矩阵
    out_triu = out_tril.T
    # 断言 np.triu(arr) 的结果与预期的上三角矩阵相等
    assert_array_equal(np.triu(arr), out_triu)
    # 断言 np.tril(arr) 的结果与预期的下三角矩阵相等
    assert_array_equal(np.tril(arr), out_tril)


# 测试函数,测试 np.tril 和 np.triu 返回值的数据类型与输入数组相同
def test_tril_triu_dtype():
    # 遍历所有数据类型
    for c in np.typecodes['All']:
        # 跳过 'V' 类型
        if c == 'V':
            continue
        # 创建一个全为 0 的 3x3 数组 arr,数据类型为当前循环的类型 c
        arr = np.zeros((3, 3), dtype=c)
        # 断言 np.triu(arr) 的数据类型与 arr 相同
        assert_equal(np.triu(arr).dtype, arr.dtype)
        # 断言 np.tril(arr) 的数据类型与 arr 相同
        assert_equal(np.tril(arr).dtype, arr.dtype)

    # 检查特殊情况
    # 创建一个 datetime64 类型的数组 arr
    arr = np.array([['2001-01-01T12:00', '2002-02-03T13:56'],
                    ['2004-01-01T12:00', '2003-01-03T13:45']],
                   dtype='datetime64')
    # 断言 np.triu(arr) 的数据类型与 arr 相同
    assert_equal(np.triu(arr).dtype, arr.dtype)
    # 断言 np.tril(arr) 的数据类型与 arr 相同
    assert_equal(np.tril(arr).dtype, arr.dtype)

    # 创建一个结构化数据类型为 'f4,f4' 的全为 0 的 3x3 数组 arr
    arr = np.zeros((3, 3), dtype='f4,f4')
    # 断言 np.triu(arr) 的数据类型与 arr 相同
    assert_equal(np.triu(arr).dtype, arr.dtype)
    # 断言 np.tril(arr) 的数据类型与 arr 相同
    assert_equal(np.tril(arr).dtype, arr.dtype)


# 测试函数,测试 mask_indices 函数的行为
def test_mask_indices():
    # 简单测试,无偏移量
    # 调用 mask_indices(3, np.triu) 函数,返回上三角矩阵的非零元素索引
    iu = mask_indices(3, np.triu)
    # 创建一个 3x3 的数组 a
    a = np.arange(9).reshape(3, 3)
    # 断言 a[iu] 的结果与预期的非零元素索引数组相等
    assert_array_equal(a[iu], array([0, 1, 2, 4, 5, 8]))
    
    # 带偏移量的测试
    # 调用 mask_indices(3, np.triu, 1) 函数,返回上三角矩阵的非零元素索引(带偏移量)
    iu1 = mask_indices(3, np.triu, 1)
    # 断言 a[iu1] 的结果与预期的非零元素索引数组相等
    assert_array_equal(a[iu1], array([1, 2, 5]))


# 测试函数,测试 tril_indices 函数
def test_tril_indices():
    # 创建一个表示不带偏移的下三角矩阵的索引
    il1 = tril_indices(4)
    # 创建一个表示带有偏移的下三角矩阵的索引
    il2 = tril_indices(4, k=2)
    # 创建一个表示带有更大行数限制的下三角矩阵的索引
    il3 = tril_indices(4, m=5)
    # 创建一个表示带有偏移和更大行数限制的下三角矩阵的索引
    il4 = tril_indices(4, k=2, m=5)

    # 创建一个4x4的NumPy数组
    a = np.array([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12],
                  [13, 14, 15, 16]])
    # 创建一个4x5的NumPy数组,元素为1到20
    b = np.arange(1, 21).reshape(4, 5)

    # 对数组a进行索引操作,并使用assert_array_equal进行断言
    assert_array_equal(a[il1],
                       np.array([1, 5, 6, 9, 10, 11, 13, 14, 15, 16]))
    # 对数组b进行索引操作,并使用assert_array_equal进行断言
    assert_array_equal(b[il3],
                       np.array([1, 6, 7, 11, 12, 13, 16, 17, 18, 19]))

    # 对数组a进行赋值操作,并使用assert_array_equal进行断言
    a[il1] = -1
    assert_array_equal(a,
                       np.array([[-1, 2, 3, 4],
                                 [-1, -1, 7, 8],
                                 [-1, -1, -1, 12],
                                 [-1, -1, -1, -1]]))
    # 对数组b进行赋值操作,并使用assert_array_equal进行断言
    b[il3] = -1
    assert_array_equal(b,
                       np.array([[-1, 2, 3, 4, 5],
                                 [-1, -1, 8, 9, 10],
                                 [-1, -1, -1, 14, 15],
                                 [-1, -1, -1, -1, 20]]))
    # 对数组a进行赋值操作,覆盖几乎整个数组(主对角线右侧的两个对角线)
    a[il2] = -10
    assert_array_equal(a,
                       np.array([[-10, -10, -10, 4],
                                 [-10, -10, -10, -10],
                                 [-10, -10, -10, -10],
                                 [-10, -10, -10, -10]]))
    # 对数组b进行赋值操作,覆盖几乎整个数组(带有偏移和更大行数限制的两个对角线)
    b[il4] = -10
    assert_array_equal(b,
                       np.array([[-10, -10, -10, 4, 5],
                                 [-10, -10, -10, -10, 10],
                                 [-10, -10, -10, -10, -10],
                                 [-10, -10, -10, -10, -10]]))
class TestTriuIndices:
    def test_triu_indices(self):
        # 生成一个包含主对角线及其以上部分的上三角索引数组
        iu1 = triu_indices(4)
        # 生成一个包含主对角线及其以上部分、从主对角线向右偏移2的上三角索引数组
        iu2 = triu_indices(4, k=2)
        # 生成一个包含主对角线及其以上部分、矩阵行数为5的上三角索引数组
        iu3 = triu_indices(4, m=5)
        # 生成一个包含主对角线及其以上部分、从主对角线向右偏移2、矩阵行数为5的上三角索引数组

        a = np.array([[1, 2, 3, 4],
                      [5, 6, 7, 8],
                      [9, 10, 11, 12],
                      [13, 14, 15, 16]])
        b = np.arange(1, 21).reshape(4, 5)

        # 用iu1索引数组获取a的元素,与指定数组比较并断言相等
        assert_array_equal(a[iu1],
                           np.array([1, 2, 3, 4, 6, 7, 8, 11, 12, 16]))
        # 用iu3索引数组获取b的元素,与指定数组比较并断言相等
        assert_array_equal(b[iu3],
                           np.array([1, 2, 3, 4, 5, 7, 8, 9,
                                     10, 13, 14, 15, 19, 20]))

        # 使用iu1索引数组修改a的元素为-1,并与指定数组比较并断言相等
        a[iu1] = -1
        assert_array_equal(a,
                           np.array([[-1, -1, -1, -1],
                                     [5, -1, -1, -1],
                                     [9, 10, -1, -1],
                                     [13, 14, 15, -1]]))
        # 使用iu3索引数组修改b的元素为-1,并与指定数组比较并断言相等
        b[iu3] = -1
        assert_array_equal(b,
                           np.array([[-1, -1, -1, -1, -1],
                                     [6, -1, -1, -1, -1],
                                     [11, 12, -1, -1, -1],
                                     [16, 17, 18, -1, -1]]))

        # 使用iu2索引数组修改a的元素为-10,并与指定数组比较并断言相等
        a[iu2] = -10
        assert_array_equal(a,
                           np.array([[-1, -1, -10, -10],
                                     [5, -1, -1, -10],
                                     [9, 10, -1, -1],
                                     [13, 14, 15, -1]]))
        # 使用iu4索引数组修改b的元素为-10,并与指定数组比较并断言相等
        b[iu4] = -10
        assert_array_equal(b,
                           np.array([[-1, -1, -10, -10, -10],
                                     [6, -1, -1, -10, -10],
                                     [11, 12, -1, -1, -10],
                                     [16, 17, 18, -1, -1]]))


class TestTrilIndicesFrom:
    def test_exceptions(self):
        # 对于维度不为2的矩阵,引发值错误异常
        assert_raises(ValueError, tril_indices_from, np.ones((2,)))
        # 对于维度不为2的3维矩阵,引发值错误异常
        assert_raises(ValueError, tril_indices_from, np.ones((2, 2, 2)))
        # 对于维度为(2, 3)的矩阵,可能引发值错误异常,但已注释掉


class TestTriuIndicesFrom:
    def test_exceptions(self):
        # 对于维度不为2的矩阵,引发值错误异常
        assert_raises(ValueError, triu_indices_from, np.ones((2,)))
        # 对于维度不为2的3维矩阵,引发值错误异常
        assert_raises(ValueError, triu_indices_from, np.ones((2, 2, 2)))
        # 对于维度为(2, 3)的矩阵,可能引发值错误异常,但已注释掉
    # 定义一个测试方法,用于测试 `vander` 函数的基本功能
    def test_basic(self):
        # 创建一个包含整数的 NumPy 数组
        c = np.array([0, 1, -2, 3])
        # 调用 `vander` 函数生成一个 Vandermonde 矩阵
        v = vander(c)
        # 创建一个预期的幂次矩阵,用于与 `vander` 函数生成的结果进行比较
        powers = np.array([[0, 0, 0, 0, 1],
                           [1, 1, 1, 1, 1],
                           [16, -8, 4, -2, 1],
                           [81, 27, 9, 3, 1]])
        # 检查默认的 N 值是否符合预期
        assert_array_equal(v, powers[:, 1:])
        # 检查一系列不同的 N 值,包括 0 和 5(大于默认值)
        m = powers.shape[1]
        for n in range(6):
            # 调用 `vander` 函数,指定不同的 N 值
            v = vander(c, N=n)
            # 检查生成的 Vandermonde 矩阵是否符合预期
            assert_array_equal(v, powers[:, m-n:m])

    # 定义一个测试方法,用于测试 `vander` 函数处理不同数据类型的情况
    def test_dtypes(self):
        # 创建一个包含 int8 类型数据的 NumPy 数组
        c = np.array([11, -12, 13], dtype=np.int8)
        # 调用 `vander` 函数生成一个 Vandermonde 矩阵
        v = vander(c)
        # 创建一个预期的 Vandermonde 矩阵,用于与 `vander` 函数生成的结果进行比较
        expected = np.array([[121, 11, 1],
                             [144, -12, 1],
                             [169, 13, 1]])
        # 检查生成的 Vandermonde 矩阵是否符合预期
        assert_array_equal(v, expected)

        # 创建一个包含复数的 NumPy 数组
        c = np.array([1.0+1j, 1.0-1j])
        # 调用 `vander` 函数生成一个 Vandermonde 矩阵,指定 N 值为 3
        v = vander(c, N=3)
        # 创建一个预期的 Vandermonde 矩阵,用于与 `vander` 函数生成的结果进行比较
        expected = np.array([[2j, 1+1j, 1],
                             [-2j, 1-1j, 1]])
        # 由于数据是浮点型,但值是小整数,使用 `assert_array_equal` 进行比较应该是安全的
        # (而不是使用 `assert_array_almost_equal`)
        assert_array_equal(v, expected)

.\numpy\numpy\lib\tests\test_type_check.py

# 导入 NumPy 库,并从中导入一些函数和类
import numpy as np
from numpy import (
    common_type, mintypecode, isreal, iscomplex, isposinf, isneginf,
    nan_to_num, isrealobj, iscomplexobj, real_if_close
    )
from numpy.testing import (
    assert_, assert_equal, assert_array_equal, assert_raises
    )

# 定义一个断言函数,用于验证条件是否为真
def assert_all(x):
    assert_(np.all(x), x)

# 定义一个测试类 TestCommonType,用于测试 common_type 函数
class TestCommonType:
    # 定义测试方法 test_basic,测试常见数据类型的 common_type 结果
    def test_basic(self):
        # 创建不同数据类型的 NumPy 数组
        ai32 = np.array([[1, 2], [3, 4]], dtype=np.int32)
        af16 = np.array([[1, 2], [3, 4]], dtype=np.float16)
        af32 = np.array([[1, 2], [3, 4]], dtype=np.float32)
        af64 = np.array([[1, 2], [3, 4]], dtype=np.float64)
        acs = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.complex64)
        acd = np.array([[1+5j, 2+6j], [3+7j, 4+8j]], dtype=np.complex128)
        
        # 断言不同数据类型的 common_type 结果是否符合预期
        assert_(common_type(ai32) == np.float64)
        assert_(common_type(af16) == np.float16)
        assert_(common_type(af32) == np.float32)
        assert_(common_type(af64) == np.float64)
        assert_(common_type(acs) == np.complex64)
        assert_(common_type(acd) == np.complex128)

# 定义一个测试类 TestMintypecode,用于测试 mintypecode 函数
class TestMintypecode:
    
    # 定义测试方法 test_default_1,测试默认情况下的 mintypecode 结果
    def test_default_1(self):
        # 遍历不同输入类型,验证 mintypecode 返回值是否符合预期
        for itype in '1bcsuwil':
            assert_equal(mintypecode(itype), 'd')
        assert_equal(mintypecode('f'), 'f')
        assert_equal(mintypecode('d'), 'd')
        assert_equal(mintypecode('F'), 'F')
        assert_equal(mintypecode('D'), 'D')

    # 定义测试方法 test_default_2,测试另一种情况下的 mintypecode 结果
    def test_default_2(self):
        # 遍历不同输入类型组合,验证 mintypecode 返回值是否符合预期
        for itype in '1bcsuwil':
            assert_equal(mintypecode(itype+'f'), 'f')
            assert_equal(mintypecode(itype+'d'), 'd')
            assert_equal(mintypecode(itype+'F'), 'F')
            assert_equal(mintypecode(itype+'D'), 'D')
        assert_equal(mintypecode('ff'), 'f')
        assert_equal(mintypecode('fd'), 'd')
        assert_equal(mintypecode('fF'), 'F')
        assert_equal(mintypecode('fD'), 'D')
        assert_equal(mintypecode('df'), 'd')
        assert_equal(mintypecode('dd'), 'd')
        #assert_equal(mintypecode('dF',savespace=1),'F')
        assert_equal(mintypecode('dF'), 'D')
        assert_equal(mintypecode('dD'), 'D')
        assert_equal(mintypecode('Ff'), 'F')
        #assert_equal(mintypecode('Fd',savespace=1),'F')
        assert_equal(mintypecode('Fd'), 'D')
        assert_equal(mintypecode('FF'), 'F')
        assert_equal(mintypecode('FD'), 'D')
        assert_equal(mintypecode('Df'), 'D')
        assert_equal(mintypecode('Dd'), 'D')
        assert_equal(mintypecode('DF'), 'D')
        assert_equal(mintypecode('DD'), 'D')
    # 定义一个测试方法,用于测试 mintypecode 函数的默认行为
    def test_default_3(self):
        # 断言 mintypecode('fdF') 的返回结果是否等于 'D'
        assert_equal(mintypecode('fdF'), 'D')
        #assert_equal(mintypecode('fdF',savespace=1),'F')  # 这行代码被注释掉了,不参与测试
        # 断言 mintypecode('fdD') 的返回结果是否等于 'D'
        assert_equal(mintypecode('fdD'), 'D')
        # 断言 mintypecode('fFD') 的返回结果是否等于 'D'
        assert_equal(mintypecode('fFD'), 'D')
        # 断言 mintypecode('dFD') 的返回结果是否等于 'D'
        assert_equal(mintypecode('dFD'), 'D')

        # 断言 mintypecode('ifd') 的返回结果是否等于 'd'
        assert_equal(mintypecode('ifd'), 'd')
        # 断言 mintypecode('ifF') 的返回结果是否等于 'F'
        assert_equal(mintypecode('ifF'), 'F')
        # 断言 mintypecode('ifD') 的返回结果是否等于 'D'
        assert_equal(mintypecode('ifD'), 'D')
        # 断言 mintypecode('idF') 的返回结果是否等于 'D'
        assert_equal(mintypecode('idF'), 'D')
        #assert_equal(mintypecode('idF',savespace=1),'F')  # 这行代码被注释掉了,不参与测试
        # 断言 mintypecode('idD') 的返回结果是否等于 'D'
        assert_equal(mintypecode('idD'), 'D')
class TestIsscalar:

    def test_basic(self):
        # 检查是否为标量(单个数值)
        assert_(np.isscalar(3))
        # 检查是否不是标量(列表不是标量)
        assert_(not np.isscalar([3]))
        # 检查是否不是标量(元组不是标量)
        assert_(not np.isscalar((3,)))
        # 检查是否为标量(复数是标量)
        assert_(np.isscalar(3j))
        # 检查是否为标量(浮点数是标量)
        assert_(np.isscalar(4.0))


class TestReal:

    def test_real(self):
        # 生成一个包含随机数的数组
        y = np.random.rand(10,)
        # 断言数组和其实部相等
        assert_array_equal(y, np.real(y))

        # 创建一个包含单个元素的数组
        y = np.array(1)
        # 获取数组的实部
        out = np.real(y)
        # 断言输入和输出数组相等
        assert_array_equal(y, out)
        # 断言输出为 ndarray 类型
        assert_(isinstance(out, np.ndarray))

        # 创建一个标量
        y = 1
        # 获取实部
        out = np.real(y)
        # 断言输入和输出相等
        assert_equal(y, out)
        # 断言输出不是 ndarray 类型
        assert_(not isinstance(out, np.ndarray))

    def test_cmplx(self):
        # 生成一个包含随机复数的数组
        y = np.random.rand(10,)+1j*np.random.rand(10,)
        # 断言实部数组和输入数组的实部相等
        assert_array_equal(y.real, np.real(y))

        # 创建一个包含单个复数的数组
        y = np.array(1 + 1j)
        # 获取数组的实部
        out = np.real(y)
        # 断言输入数组的实部和输出数组相等
        assert_array_equal(y.real, out)
        # 断言输出为 ndarray 类型
        assert_(isinstance(out, np.ndarray))

        # 创建一个复数标量
        y = 1 + 1j
        # 获取实部
        out = np.real(y)
        # 断言实部为 1.0
        assert_equal(1.0, out)
        # 断言输出不是 ndarray 类型
        assert_(not isinstance(out, np.ndarray))


class TestImag:

    def test_real(self):
        # 生成一个包含随机数的数组
        y = np.random.rand(10,)
        # 断言虚部数组为 0
        assert_array_equal(0, np.imag(y))

        # 创建一个包含单个元素的数组
        y = np.array(1)
        # 获取数组的虚部
        out = np.imag(y)
        # 断言虚部为 0
        assert_array_equal(0, out)
        # 断言输出为 ndarray 类型
        assert_(isinstance(out, np.ndarray))

        # 创建一个标量
        y = 1
        # 获取虚部
        out = np.imag(y)
        # 断言虚部为 0
        assert_equal(0, out)
        # 断言输出不是 ndarray 类型
        assert_(not isinstance(out, np.ndarray))

    def test_cmplx(self):
        # 生成一个包含随机复数的数组
        y = np.random.rand(10,)+1j*np.random.rand(10,)
        # 断言虚部数组和输入数组的虚部相等
        assert_array_equal(y.imag, np.imag(y))

        # 创建一个包含单个复数的数组
        y = np.array(1 + 1j)
        # 获取数组的虚部
        out = np.imag(y)
        # 断言输入数组的虚部和输出数组相等
        assert_array_equal(y.imag, out)
        # 断言输出为 ndarray 类型
        assert_(isinstance(out, np.ndarray))

        # 创建一个复数标量
        y = 1 + 1j
        # 获取虚部
        out = np.imag(y)
        # 断言虚部为 1.0
        assert_equal(1.0, out)
        # 断言输出不是 ndarray 类型
        assert_(not isinstance(out, np.ndarray))


class TestIscomplex:

    def test_fail(self):
        # 创建一个数组
        z = np.array([-1, 0, 1])
        # 检查数组中是否没有复数
        res = iscomplex(z)
        assert_(not np.any(res, axis=0))

    def test_pass(self):
        # 创建一个数组
        z = np.array([-1j, 1, 0])
        # 检查数组中每个元素是否为复数
        res = iscomplex(z)
        assert_array_equal(res, [1, 0, 0])


class TestIsreal:

    def test_pass(self):
        # 创建一个数组
        z = np.array([-1, 0, 1j])
        # 检查数组中每个元素是否为实数
        res = isreal(z)
        assert_array_equal(res, [1, 1, 0])

    def test_fail(self):
        # 创建一个数组
        z = np.array([-1j, 1, 0])
        # 检查数组中每个元素是否为实数
        res = isreal(z)
        assert_array_equal(res, [0, 1, 1])


class TestIscomplexobj:

    def test_basic(self):
        # 创建一个数组
        z = np.array([-1, 0, 1])
        # 检查数组是否包含复数对象
        assert_(not iscomplexobj(z))
        # 创建一个包含复数的数组
        z = np.array([-1j, 0, -1])
        # 检查数组是否包含复数对象
        assert_(iscomplexobj(z))

    def test_scalar(self):
        # 检查标量是否为复数对象
        assert_(not iscomplexobj(1.0))
        assert_(iscomplexobj(1+0j))

    def test_list(self):
        # 检查列表中是否包含复数对象
        assert_(iscomplexobj([3, 1+0j, True]))
        assert_(not iscomplexobj([3, 1, True]))

    def test_duck(self):
        # 创建一个虚拟的复数数组类
        class DummyComplexArray:
            @property
            def dtype(self):
                return np.dtype(complex)
        dummy = DummyComplexArray()
        # 检查虚拟复数数组是否为复数对象
        assert_(iscomplexobj(dummy))
    # 定义一个测试方法,用于验证自定义的 np.dtype 鸭子类型类,比如 pandas 使用的类(pandas.core.dtypes)
    def test_pandas_duck(self):
        # 定义一个继承自 np.complex128 的 pandas 复杂类型类
        class PdComplex(np.complex128):
            pass
        
        # 定义一个模拟的 pandas 数据类型类
        class PdDtype:
            name = 'category'  # 数据类型名称为 'category'
            names = None       # 名称列表为空
            type = PdComplex   # 数据类型为 PdComplex 类型
            kind = 'c'         # 类别标识为 'c'
            str = '<c16'       # 字符串描述为 '<c16'
            base = np.dtype('complex128')  # 基础数据类型为 np.complex128
        
        # 定义一个虚拟的 DummyPd 类,具有 dtype 属性,返回 PdDtype 类
        class DummyPd:
            @property
            def dtype(self):
                return PdDtype
        
        # 创建 DummyPd 类的实例 dummy
        dummy = DummyPd()
        
        # 断言 dummy 对象是否是复数对象
        assert_(iscomplexobj(dummy))

    # 定义另一个测试方法,用于验证自定义数据类型鸭子类型
    def test_custom_dtype_duck(self):
        # 定义一个继承自 list 的自定义数组类 MyArray
        class MyArray(list):
            # 定义 dtype 属性,返回复数类型
            @property
            def dtype(self):
                return complex
        
        # 创建 MyArray 类的实例 a,包含三个复数
        a = MyArray([1+0j, 2+0j, 3+0j])
        
        # 断言 a 对象是否是复数对象
        assert_(iscomplexobj(a))
class TestIsrealobj:
    def test_basic(self):
        # 创建一个包含三个元素的 numpy 数组,用于测试是否为实数对象
        z = np.array([-1, 0, 1])
        # 断言 z 是实数对象
        assert_(isrealobj(z))
        
        # 创建另一个 numpy 数组,包含复数元素,用于测试非实数对象
        z = np.array([-1j, 0, -1])
        # 断言 z 不是实数对象
        assert_(not isrealobj(z))


class TestIsnan:

    def test_goodvalues(self):
        # 创建一个包含三个浮点数的 numpy 数组,测试它们不是 NaN
        z = np.array((-1., 0., 1.))
        # 生成一个布尔数组,检查 z 中的元素是否不是 NaN
        res = np.isnan(z) == 0
        # 断言 res 中所有元素都为 True
        assert_all(np.all(res, axis=0))

    def test_posinf(self):
        # 使用 np.errstate 忽略除以零带来的警告
        with np.errstate(divide='ignore'):
            # 创建包含正无穷的 numpy 数组,测试它们不是 NaN
            assert_all(np.isnan(np.array((1.,))/0.) == 0)

    def test_neginf(self):
        # 使用 np.errstate 忽略除以零带来的警告
        with np.errstate(divide='ignore'):
            # 创建包含负无穷的 numpy 数组,测试它们不是 NaN
            assert_all(np.isnan(np.array((-1.,))/0.) == 0)

    def test_ind(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含非法操作结果的 numpy 数组,测试它们是 NaN
            assert_all(np.isnan(np.array((0.,))/0.) == 1)

    def test_integer(self):
        # 创建整数值,测试它不是 NaN
        assert_all(np.isnan(1) == 0)

    def test_complex(self):
        # 创建复数值,测试它不是 NaN
        assert_all(np.isnan(1+1j) == 0)

    def test_complex1(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建复数值进行除法运算,测试它是 NaN
            assert_all(np.isnan(np.array(0+0j)/0.) == 1)


class TestIsfinite:
    # Fixme, wrong place, isfinite now ufunc

    def test_goodvalues(self):
        # 创建一个包含三个浮点数的 numpy 数组,测试它们是有限数
        z = np.array((-1., 0., 1.))
        # 生成一个布尔数组,检查 z 中的元素是否是有限数
        res = np.isfinite(z) == 1
        # 断言 res 中所有元素都为 True
        assert_all(np.all(res, axis=0))

    def test_posinf(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含正无穷的 numpy 数组,测试它们不是有限数
            assert_all(np.isfinite(np.array((1.,))/0.) == 0)

    def test_neginf(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含负无穷的 numpy 数组,测试它们不是有限数
            assert_all(np.isfinite(np.array((-1.,))/0.) == 0)

    def test_ind(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含非法操作结果的 numpy 数组,测试它们不是有限数
            assert_all(np.isfinite(np.array((0.,))/0.) == 0)

    def test_integer(self):
        # 创建整数值,测试它是有限数
        assert_all(np.isfinite(1) == 1)

    def test_complex(self):
        # 创建复数值,测试它是有限数
        assert_all(np.isfinite(1+1j) == 1)

    def test_complex1(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建复数值进行除法运算,测试它不是有限数
            assert_all(np.isfinite(np.array(1+1j)/0.) == 0)


class TestIsinf:
    # Fixme, wrong place, isinf now ufunc

    def test_goodvalues(self):
        # 创建一个包含三个浮点数的 numpy 数组,测试它们不是无穷数
        z = np.array((-1., 0., 1.))
        # 生成一个布尔数组,检查 z 中的元素是否不是无穷数
        res = np.isinf(z) == 0
        # 断言 res 中所有元素都为 True
        assert_all(np.all(res, axis=0))

    def test_posinf(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含正无穷的 numpy 数组,测试它们是正无穷数
            assert_all(np.isinf(np.array((1.,))/0.) == 1)

    def test_posinf_scalar(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建单个正无穷数,测试它是正无穷数
            assert_all(np.isinf(np.array(1.,)/0.) == 1)

    def test_neginf(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含负无穷的 numpy 数组,测试它们是负无穷数
            assert_all(np.isinf(np.array((-1.,))/0.) == 1)

    def test_neginf_scalar(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建单个负无穷数,测试它是负无穷数
            assert_all(np.isinf(np.array(-1.)/0.) == 1)

    def test_ind(self):
        # 使用 np.errstate 忽略除以零和无效操作带来的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建包含非法操作结果的 numpy 数组,测试它们不是无穷数
            assert_all(np.isinf(np.array((0.,))/0.) == 0)


class TestIsposinf:
    # Fixme, wrong place, isposinf not a ufunc yet
    # 定义一个测试函数,用于测试通用情况
    def test_generic(self):
        # 在计算过程中忽略除以零的警告和无效操作的警告
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个包含负无穷、零和正无穷的数组,并计算其是否为正无穷
            vals = isposinf(np.array((-1., 0, 1))/0.)
        # 断言:检查计算结果中第一个元素是否为零
        assert_(vals[0] == 0)
        # 断言:检查计算结果中第二个元素是否为零
        assert_(vals[1] == 0)
        # 断言:检查计算结果中第三个元素是否为正无穷
        assert_(vals[2] == 1)
class TestIsneginf:

    def test_generic(self):
        # 忽略除法和无效值错误,执行下面的代码块
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个包含负无穷值的数组
            vals = isneginf(np.array((-1., 0, 1))/0.)
        # 断言第一个值为1
        assert_(vals[0] == 1)
        # 断言第二个值为0
        assert_(vals[1] == 0)
        # 断言第三个值为0
        assert_(vals[2] == 0)


class TestNanToNum:

    def test_generic(self):
        # 忽略除法和无效值错误,执行下面的代码块
        with np.errstate(divide='ignore', invalid='ignore'):
            # 使用 nan_to_num 将包含无限值的数组处理为特定的值
            vals = nan_to_num(np.array((-1., 0, 1))/0.)
        # 断言第一个值小于 -1e10,并且是有限的
        assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0]))
        # 断言第二个值为0
        assert_(vals[1] == 0)
        # 断言第三个值大于 1e10,并且是有限的
        assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)
        
        # 使用 nan=10, posinf=20, neginf=30 参数再次进行相同的测试
        with np.errstate(divide='ignore', invalid='ignore'):
            # 使用 nan_to_num 处理数组,将 nan 替换为 10,将正无穷替换为 20,将负无穷替换为 30
            vals = nan_to_num(np.array((-1., 0, 1))/0., 
                              nan=10, posinf=20, neginf=30)
        # 断言结果数组与期望的数组相等
        assert_equal(vals, [30, 10, 20])
        # 断言结果数组的第一个和第三个元素是有限的
        assert_all(np.isfinite(vals[[0, 2]]))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)

        # 在原地进行相同的测试
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个包含无限值的数组
            vals = np.array((-1., 0, 1))/0.
        # 使用 nan_to_num 在原地处理数组
        result = nan_to_num(vals, copy=False)

        # 断言处理结果与原数组是同一个对象
        assert_(result is vals)
        # 断言原数组的第一个值小于 -1e10,并且是有限的
        assert_all(vals[0] < -1e10) and assert_all(np.isfinite(vals[0]))
        # 断言原数组的第二个值为0
        assert_(vals[1] == 0)
        # 断言原数组的第三个值大于 1e10,并且是有限的
        assert_all(vals[2] > 1e10) and assert_all(np.isfinite(vals[2]))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)
        
        # 在原地进行相同的测试,但使用 nan=10, posinf=20, neginf=30 参数
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个包含无限值的数组
            vals = np.array((-1., 0, 1))/0.
        # 使用 nan_to_num 在原地处理数组,并将 nan 替换为 10,将正无穷替换为 20,将负无穷替换为 30
        result = nan_to_num(vals, copy=False, nan=10, posinf=20, neginf=30)

        # 断言处理结果与原数组是同一个对象
        assert_(result is vals)
        # 断言原数组等于期望的数组 [30, 10, 20]
        assert_equal(vals, [30, 10, 20])
        # 断言原数组的第一个和第三个元素是有限的
        assert_all(np.isfinite(vals[[0, 2]]))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)

    def test_array(self):
        # 使用 nan_to_num 处理数组 [1]
        vals = nan_to_num([1])
        # 断言处理结果与期望的数组相等
        assert_array_equal(vals, np.array([1], int))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)
        # 使用 nan=10, posinf=20, neginf=30 参数再次处理数组 [1]
        vals = nan_to_num([1], nan=10, posinf=20, neginf=30)
        # 断言处理结果与期望的数组相等
        assert_array_equal(vals, np.array([1], int))
        # 断言结果的类型为 numpy 数组
        assert_equal(type(vals), np.ndarray)

    def test_integer(self):
        # 使用 nan_to_num 处理整数 1
        vals = nan_to_num(1)
        # 断言处理结果等于 1
        assert_all(vals == 1)
        # 断言结果的类型为 np.int_
        assert_equal(type(vals), np.int_)
        # 使用 nan=10, posinf=20, neginf=30 参数再次处理整数 1
        vals = nan_to_num(1, nan=10, posinf=20, neginf=30)
        # 断言处理结果等于 1
        assert_all(vals == 1)
        # 断言结果的类型为 np.int_

    def test_float(self):
        # 使用 nan_to_num 处理浮点数 1.0
        vals = nan_to_num(1.0)
        # 断言处理结果等于 1.0
        assert_all(vals == 1.0)
        # 断言结果的类型为 np.float64
        assert_equal(type(vals), np.float64)
        # 使用 nan=10, posinf=20, neginf=30 参数再次处理浮点数 1.1
        vals = nan_to_num(1.1, nan=10, posinf=20, neginf=30)
        # 断言处理结果等于 1.1
        assert_all(vals == 1.1)
        # 断言结果的类型为 np.float64
        assert_equal(type(vals), np.float64)
    # 定义一个测试函数,用于测试处理复数的情况(正常情况)
    def test_complex_good(self):
        # 将复数中的 NaN 替换为 0,并返回处理后的值
        vals = nan_to_num(1+1j)
        # 断言所有处理后的值等于原始复数值
        assert_all(vals == 1+1j)
        # 断言处理后的值的数据类型为 np.complex128
        assert_equal(type(vals), np.complex128)
        
        # 将复数中的 NaN 替换为 10,正无穷替换为 20,负无穷替换为 30,并返回处理后的值
        vals = nan_to_num(1+1j, nan=10, posinf=20, neginf=30)
        # 断言所有处理后的值等于原始复数值
        assert_all(vals == 1+1j)
        # 断言处理后的值的数据类型为 np.complex128
        assert_equal(type(vals), np.complex128)

    # 定义一个测试函数,用于测试处理复数的情况(异常情况1)
    def test_complex_bad(self):
        # 在忽略除法错误和无效值的错误状态下执行以下操作
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个复数 v = 1 + 1j
            v = 1 + 1j
            # 将 v 增加一个除以零得到的复数数组的结果
            v += np.array(0+1.j)/0.
        # 将复数 v 中的 NaN 替换为 0,并返回处理后的值
        vals = nan_to_num(v)
        # 断言所有处理后的值是有限的(即不是 NaN 或 inf)
        # !! 这实际上是 (意外地) 得到了零
        assert_all(np.isfinite(vals))
        # 断言处理后的值的数据类型为 np.complex128
        assert_equal(type(vals), np.complex128)

    # 定义一个测试函数,用于测试处理复数的情况(异常情况2)
    def test_complex_bad2(self):
        # 在忽略除法错误和无效值的错误状态下执行以下操作
        with np.errstate(divide='ignore', invalid='ignore'):
            # 创建一个复数 v = 1 + 1j
            v = 1 + 1j
            # 将 v 增加一个除以零得到的复数数组的结果
            v += np.array(-1+1.j)/0.
        # 将复数 v 中的 NaN 替换为 0,并返回处理后的值
        vals = nan_to_num(v)
        # 断言所有处理后的值是有限的(即不是 NaN 或 inf)
        assert_all(np.isfinite(vals))
        # 断言处理后的值的数据类型为 np.complex128
        assert_equal(type(vals), np.complex128)
        # Fixme
        #assert_all(vals.imag > 1e10)  and assert_all(np.isfinite(vals))
        # !! 这实际上是 (意外地) 是正数
        # !! inf。暂时注释掉,并观察是否有变化
        #assert_all(vals.real < -1e10) and assert_all(np.isfinite(vals))

    # 定义一个测试函数,用于测试 nan_to_num 函数不会重写先前关键字的情况
    def test_do_not_rewrite_previous_keyword(self):
        # 在忽略除法错误和无效值的错误状态下执行以下操作
        with np.errstate(divide='ignore', invalid='ignore'):
            # 对数组中的 (-1., 0, 1)/0. 进行 NaN 替换为 np.inf,posinf 替换为 999
            vals = nan_to_num(np.array((-1., 0, 1))/0., nan=np.inf, posinf=999)
        # 断言数组中指定位置的值是有限的(不是 NaN 或 inf)
        assert_all(np.isfinite(vals[[0, 2]]))
        # 断言数组中第一个位置的值小于 -1e10
        assert_all(vals[0] < -1e10)
        # 断言数组中第二个和第三个位置的值分别等于 np.inf 和 999
        assert_equal(vals[[1, 2]], [np.inf, 999])
        # 断言处理后的值的数据类型为 np.ndarray
        assert_equal(type(vals), np.ndarray)
# 定义一个名为 TestRealIfClose 的测试类
class TestRealIfClose:

    # 定义一个测试方法 test_basic,用于测试 real_if_close 函数的基本功能
    def test_basic(self):
        # 生成一个包含 10 个随机数的数组 a
        a = np.random.rand(10)
        # 调用 real_if_close 函数,处理 a 中每个元素加上虚数部分 1e-15j
        b = real_if_close(a+1e-15j)
        # 断言处理后的数组 b 中所有元素都是实数
        assert_all(isrealobj(b))
        # 断言处理后的数组 b 与原始数组 a 相等
        assert_array_equal(a, b)
        # 再次调用 real_if_close 函数,处理 a 中每个元素加上虚数部分 1e-7j
        b = real_if_close(a+1e-7j)
        # 断言处理后的数组 b 中所有元素都是复数
        assert_all(iscomplexobj(b))
        # 第三次调用 real_if_close 函数,设置容差参数 tol 为 1e-6,处理 a 中每个元素加上虚数部分 1e-7j
        b = real_if_close(a+1e-7j, tol=1e-6)
        # 断言处理后的数组 b 中所有元素都是实数
        assert_all(isrealobj(b))

.\numpy\numpy\lib\tests\test_ufunclike.py

import numpy as np  # 导入NumPy库

from numpy import fix, isposinf, isneginf  # 从NumPy中导入fix, isposinf, isneginf函数
from numpy.testing import (  # 从NumPy的testing模块导入以下函数
    assert_, assert_equal, assert_array_equal, assert_raises
)


class TestUfunclike:  # 定义测试类TestUfunclike

    def test_isposinf(self):  # 定义测试isposinf函数的方法
        a = np.array([np.inf, -np.inf, np.nan, 0.0, 3.0, -3.0])  # 创建NumPy数组a
        out = np.zeros(a.shape, bool)  # 创建与a形状相同的布尔型数组out
        tgt = np.array([True, False, False, False, False, False])  # 创建目标数组tgt

        res = isposinf(a)  # 调用isposinf函数
        assert_equal(res, tgt)  # 断言res与tgt相等
        res = isposinf(a, out)  # 调用isposinf函数,将结果存入out
        assert_equal(res, tgt)  # 断言res与tgt相等
        assert_equal(out, tgt)  # 断言out与tgt相等

        a = a.astype(np.complex128)  # 将数组a的数据类型转换为复数型
        with assert_raises(TypeError):  # 检查是否抛出TypeError异常
            isposinf(a)  # 调用isposinf函数

    def test_isneginf(self):  # 定义测试isneginf函数的方法
        a = np.array([np.inf, -np.inf, np.nan, 0.0, 3.0, -3.0])  # 创建NumPy数组a
        out = np.zeros(a.shape, bool)  # 创建与a形状相同的布尔型数组out
        tgt = np.array([False, True, False, False, False, False])  # 创建目标数组tgt

        res = isneginf(a)  # 调用isneginf函数
        assert_equal(res, tgt)  # 断言res与tgt相等
        res = isneginf(a, out)  # 调用isneginf函数,将结果存入out
        assert_equal(res, tgt)  # 断言res与tgt相等
        assert_equal(out, tgt)  # 断言out与tgt相等

        a = a.astype(np.complex128)  # 将数组a的数据类型转换为复数型
        with assert_raises(TypeError):  # 检查是否抛出TypeError异常
            isneginf(a)  # 调用isneginf函数

    def test_fix(self):  # 定义测试fix函数的方法
        a = np.array([[1.0, 1.1, 1.5, 1.8], [-1.0, -1.1, -1.5, -1.8]])  # 创建NumPy数组a
        out = np.zeros(a.shape, float)  # 创建与a形状相同的浮点型数组out
        tgt = np.array([[1., 1., 1., 1.], [-1., -1., -1., -1.]])  # 创建目标数组tgt

        res = fix(a)  # 调用fix函数
        assert_equal(res, tgt)  # 断言res与tgt相等
        res = fix(a, out)  # 调用fix函数,将结果存入out
        assert_equal(res, tgt)  # 断言res与tgt相等
        assert_equal(out, tgt)  # 断言out与tgt相等
        assert_equal(fix(3.14), 3)  # 断言fix函数对标量输入的正确性

    def test_fix_with_subclass(self):  # 定义测试带子类的fix函数的方法
        class MyArray(np.ndarray):  # 定义名为MyArray的子类,继承自np.ndarray
            def __new__(cls, data, metadata=None):  # 定义__new__方法
                res = np.array(data, copy=True).view(cls)  # 创建一个数据的副本,并视图为当前类的实例
                res.metadata = metadata  # 将metadata赋值给实例属性metadata
                return res

            def __array_wrap__(self, obj, context=None, return_scalar=False):  # 定义__array_wrap__方法
                if not isinstance(obj, MyArray):  # 如果obj不是MyArray的实例
                    obj = obj.view(MyArray)  # 将obj视图为MyArray类的实例
                if obj.metadata is None:  # 如果obj的metadata属性为None
                    obj.metadata = self.metadata  # 将self.metadata赋值给obj.metadata
                return obj  # 返回obj

            def __array_finalize__(self, obj):  # 定义__array_finalize__方法
                self.metadata = getattr(obj, 'metadata', None)  # 获取obj的metadata属性,若不存在则为None
                return self  # 返回self

        a = np.array([1.1, -1.1])  # 创建NumPy数组a
        m = MyArray(a, metadata='foo')  # 创建MyArray类的实例m,传入a和metadata参数
        f = fix(m)  # 调用fix函数,传入m

        assert_array_equal(f, np.array([1, -1]))  # 断言f与期望的数组相等
        assert_(isinstance(f, MyArray))  # 断言f是MyArray类的实例
        assert_equal(f.metadata, 'foo')  # 断言f的metadata属性值为'foo'

        # 检查0维数组不会退化为标量
        m0d = m[0,...]  # 创建m的0维视图m0d
        m0d.metadata = 'bar'  # 给m0d的metadata属性赋值'bar'
        f0d = fix(m0d)  # 调用fix函数,传入m0d

        assert_(isinstance(f0d, MyArray))  # 断言f0d是MyArray类的实例
        assert_equal(f0d.metadata, 'bar')  # 断言f0d的metadata属性值为'bar'
    # 定义一个测试方法,用于测试numpy库中的数值处理函数
    def test_scalar(self):
        # 设置一个无穷大的浮点数
        x = np.inf
        # 调用numpy的函数,判断x是否为正无穷
        actual = np.isposinf(x)
        # 预期值是True
        expected = np.True_
        # 使用assert_equal函数断言实际结果与预期结果相等
        assert_equal(actual, expected)
        # 再次使用assert_equal函数断言实际结果的类型与预期结果的类型相等
        assert_equal(type(actual), type(expected))

        # 设置一个浮点数
        x = -3.4
        # 调用numpy的函数,返回小于或等于x的最大整数
        actual = np.fix(x)
        # 预期值是一个numpy float64类型的数值 -3.0
        expected = np.float64(-3.0)
        # 使用assert_equal函数断言实际结果与预期结果相等
        assert_equal(actual, expected)
        # 再次使用assert_equal函数断言实际结果的类型与预期结果的类型相等
        assert_equal(type(actual), type(expected))

        # 创建一个浮点数类型的numpy数组,初始值为0.0
        out = np.array(0.0)
        # 调用numpy的函数,返回小于或等于x的最大整数,并将结果存储到指定的输出数组中
        actual = np.fix(x, out=out)
        # 使用assert_函数断言实际返回的数组与指定的输出数组是同一个对象
        assert_(actual is out)

.\numpy\numpy\lib\tests\test_utils.py

import pytest  # 导入 pytest 库

import numpy as np  # 导入 NumPy 库,并使用 np 别名
from numpy.testing import assert_raises_regex  # 导入 assert_raises_regex 函数
import numpy.lib._utils_impl as _utils_impl  # 导入 _utils_impl 模块

from io import StringIO  # 从 io 模块导入 StringIO 类


def test_assert_raises_regex_context_manager():
    with assert_raises_regex(ValueError, 'no deprecation warning'):  # 使用 assert_raises_regex 上下文管理器检查是否抛出 ValueError 异常,并检查错误消息是否包含特定文本
        raise ValueError('no deprecation warning')


def test_info_method_heading():
    # info(class) should only print "Methods:" heading if methods exist
    # 定义两个类来测试:NoPublicMethods 没有公共方法,WithPublicMethods 有一个公共方法
    class NoPublicMethods:
        pass

    class WithPublicMethods:
        def first_method():
            pass

    def _has_method_heading(cls):
        out = StringIO()  # 创建一个 StringIO 对象用于捕获输出
        np.info(cls, output=out)  # 使用 np.info 函数打印类信息到 StringIO 对象
        return 'Methods:' in out.getvalue()  # 检查输出字符串中是否包含 "Methods:" 标题

    assert _has_method_heading(WithPublicMethods)  # 测试 WithPublicMethods 类应该打印出 "Methods:" 标题
    assert not _has_method_heading(NoPublicMethods)  # 测试 NoPublicMethods 类不应该打印出 "Methods:" 标题


def test_drop_metadata():
    def _compare_dtypes(dt1, dt2):
        return np.can_cast(dt1, dt2, casting='no')  # 比较两个 dtype 是否可以强制转换

    # structured dtype
    dt = np.dtype([('l1', [('l2', np.dtype('S8', metadata={'msg': 'toto'}))])],
                  metadata={'msg': 'titi'})  # 创建一个结构化 dtype,并设置 metadata
    dt_m = _utils_impl.drop_metadata(dt)  # 调用 _utils_impl.drop_metadata 函数去除 metadata
    assert _compare_dtypes(dt, dt_m) is True  # 检查去除 metadata 后 dtype 是否保持可转换性
    assert dt_m.metadata is None  # 检查去除 metadata 后顶层 dtype 的 metadata 是否为 None
    assert dt_m['l1'].metadata is None  # 检查去除 metadata 后子 dtype 的 metadata 是否为 None

    # alignment
    dt = np.dtype([('x', '<f8'), ('y', '<i4')],
                  align=True,
                  metadata={'msg': 'toto'})  # 创建一个带 alignment 和 metadata 的 dtype
    dt_m = _utils_impl.drop_metadata(dt)  # 调用 _utils_impl.drop_metadata 函数去除 metadata
    assert _compare_dtypes(dt, dt_m) is True  # 检查去除 metadata 后 dtype 是否保持可转换性
    assert dt_m.metadata is None  # 检查去除 metadata 后顶层 dtype 的 metadata 是否为 None

    # subdtype
    dt = np.dtype('8f',
                  metadata={'msg': 'toto'})  # 创建一个带 metadata 的 subdtype
    dt_m = _utils_impl.drop_metadata(dt)  # 调用 _utils_impl.drop_metadata 函数去除 metadata
    assert _compare_dtypes(dt, dt_m) is True  # 检查去除 metadata 后 dtype 是否保持可转换性
    assert dt_m.metadata is None  # 检查去除 metadata 后 dtype 的 metadata 是否为 None

    # scalar
    dt = np.dtype('uint32',
                  metadata={'msg': 'toto'})  # 创建一个带 metadata 的标量 dtype
    dt_m = _utils_impl.drop_metadata(dt)  # 调用 _utils_impl.drop_metadata 函数去除 metadata
    assert _compare_dtypes(dt, dt_m) is True  # 检查去除 metadata 后 dtype 是否保持可转换性
    assert dt_m.metadata is None  # 检查去除 metadata 后 dtype 的 metadata 是否为 None


@pytest.mark.parametrize("dtype",
        [np.dtype("i,i,i,i")[["f1", "f3"]],
        np.dtype("f8"),
        np.dtype("10i")])
def test_drop_metadata_identity_and_copy(dtype):
    # If there is no metadata, the identity is preserved:
    assert _utils_impl.drop_metadata(dtype) is dtype  # 如果没有 metadata,则保持 dtype 的身份不变

    # If there is any, it is dropped (subforms are checked above)
    dtype = np.dtype(dtype, metadata={1: 2})  # 给 dtype 添加一个 metadata
    assert _utils_impl.drop_metadata(dtype).metadata is None  # 检查去除 metadata 后是否为 None

.\numpy\numpy\lib\tests\test__datasource.py

import os  # 导入操作系统模块
import pytest  # 导入 pytest 测试框架
from tempfile import mkdtemp, mkstemp, NamedTemporaryFile  # 导入临时文件和目录创建相关的函数
from shutil import rmtree  # 导入删除目录的函数

import numpy.lib._datasource as datasource  # 导入 numpy 的数据源模块
from numpy.testing import assert_, assert_equal, assert_raises  # 导入 numpy 测试相关的断言函数

import urllib.request as urllib_request  # 导入 urllib 的请求模块
from urllib.parse import urlparse  # 导入解析 URL 的函数
from urllib.error import URLError  # 导入处理 URL 错误的异常类


def urlopen_stub(url, data=None):
    '''Stub to replace urlopen for testing.'''
    if url == valid_httpurl():
        tmpfile = NamedTemporaryFile(prefix='urltmp_')
        return tmpfile
    else:
        raise URLError('Name or service not known')

# setup and teardown
old_urlopen = None


def setup_module():
    global old_urlopen

    old_urlopen = urllib_request.urlopen
    urllib_request.urlopen = urlopen_stub


def teardown_module():
    urllib_request.urlopen = old_urlopen

# A valid website for more robust testing
http_path = 'http://www.google.com/'  # 定义一个有效的 HTTP 网址
http_file = 'index.html'  # 定义一个 HTTP 文件名

http_fakepath = 'http://fake.abc.web/site/'  # 定义一个无效的 HTTP 网址
http_fakefile = 'fake.txt'  # 定义一个无效的 HTTP 文件名

malicious_files = ['/etc/shadow', '../../shadow',
                   '..\\system.dat', 'c:\\windows\\system.dat']  # 定义一些恶意文件路径

magic_line = b'three is the magic number'  # 定义一个神奇的字节序列


# Utility functions used by many tests
def valid_textfile(filedir):
    # Generate and return a valid temporary file.
    fd, path = mkstemp(suffix='.txt', prefix='dstmp_', dir=filedir, text=True)
    os.close(fd)
    return path


def invalid_textfile(filedir):
    # Generate and return an invalid filename.
    fd, path = mkstemp(suffix='.txt', prefix='dstmp_', dir=filedir)
    os.close(fd)
    os.remove(path)
    return path


def valid_httpurl():
    return http_path+http_file


def invalid_httpurl():
    return http_fakepath+http_fakefile


def valid_baseurl():
    return http_path


def invalid_baseurl():
    return http_fakepath


def valid_httpfile():
    return http_file


def invalid_httpfile():
    return http_fakefile


class TestDataSourceOpen:
    def setup_method(self):
        self.tmpdir = mkdtemp()  # 创建临时目录
        self.ds = datasource.DataSource(self.tmpdir)  # 初始化数据源对象

    def teardown_method(self):
        rmtree(self.tmpdir)  # 删除临时目录
        del self.ds  # 删除数据源对象

    def test_ValidHTTP(self):
        fh = self.ds.open(valid_httpurl())  # 打开一个有效的 HTTP 资源
        assert_(fh)  # 断言文件句柄有效
        fh.close()  # 关闭文件句柄

    def test_InvalidHTTP(self):
        url = invalid_httpurl()
        assert_raises(OSError, self.ds.open, url)  # 断言打开无效 HTTP 资源时抛出 OSError 异常
        try:
            self.ds.open(url)
        except OSError as e:
            # Regression test for bug fixed in r4342.
            assert_(e.errno is None)  # 断言异常中的错误号为 None

    def test_InvalidHTTPCacheURLError(self):
        assert_raises(URLError, self.ds._cache, invalid_httpurl())  # 断言在缓存无效 HTTP 资源时抛出 URLError 异常

    def test_ValidFile(self):
        local_file = valid_textfile(self.tmpdir)  # 创建一个有效的本地文本文件
        fh = self.ds.open(local_file)  # 打开该本地文件
        assert_(fh)  # 断言文件句柄有效
        fh.close()  # 关闭文件句柄

    def test_InvalidFile(self):
        invalid_file = invalid_textfile(self.tmpdir)  # 创建一个无效的本地文件名
        assert_raises(OSError, self.ds.open, invalid_file)  # 断言打开无效文件时抛出 OSError 异常
    # 定义测试函数,用于验证处理有效的 Gzip 文件的功能
    def test_ValidGzipFile(self):
        try:
            import gzip
        except ImportError:
            # 如果导入 gzip 失败,跳过测试
            pytest.skip()
        # 设置测试文件路径为临时目录下的 foobar.txt.gz
        filepath = os.path.join(self.tmpdir, 'foobar.txt.gz')
        # 打开文件准备写入 gzip 数据
        fp = gzip.open(filepath, 'w')
        # 写入预设的 magic_line 到 gzip 文件中
        fp.write(magic_line)
        # 关闭文件
        fp.close()
        # 调用被测试的数据源对象的 open 方法打开文件
        fp = self.ds.open(filepath)
        # 从文件中读取一行数据
        result = fp.readline()
        # 关闭文件
        fp.close()
        # 断言读取的结果与预设的 magic_line 相等
        assert_equal(magic_line, result)

    # 定义测试函数,用于验证处理有效的 BZip2 文件的功能
    def test_ValidBz2File(self):
        try:
            import bz2
        except ImportError:
            # 如果导入 bz2 失败,跳过测试
            pytest.skip()
        # 设置测试文件路径为临时目录下的 foobar.txt.bz2
        filepath = os.path.join(self.tmpdir, 'foobar.txt.bz2')
        # 打开文件准备写入 BZip2 数据
        fp = bz2.BZ2File(filepath, 'w')
        # 写入预设的 magic_line 到 BZip2 文件中
        fp.write(magic_line)
        # 关闭文件
        fp.close()
        # 调用被测试的数据源对象的 open 方法打开文件
        fp = self.ds.open(filepath)
        # 从文件中读取一行数据
        result = fp.readline()
        # 关闭文件
        fp.close()
        # 断言读取的结果与预设的 magic_line 相等
        assert_equal(magic_line, result)
class TestDataSourceExists:
    # 测试数据源存在性的测试类

    def setup_method(self):
        # 每个测试方法执行前的设置方法
        self.tmpdir = mkdtemp()
        self.ds = datasource.DataSource(self.tmpdir)

    def teardown_method(self):
        # 每个测试方法执行后的清理方法
        rmtree(self.tmpdir)
        del self.ds

    def test_ValidHTTP(self):
        # 测试有效的 HTTP 路径
        assert_(self.ds.exists(valid_httpurl()))

    def test_InvalidHTTP(self):
        # 测试无效的 HTTP 路径
        assert_equal(self.ds.exists(invalid_httpurl()), False)

    def test_ValidFile(self):
        # 测试存在于目标路径中的有效文件
        tmpfile = valid_textfile(self.tmpdir)
        assert_(self.ds.exists(tmpfile))
        # 测试不在目标路径中的本地有效文件
        localdir = mkdtemp()
        tmpfile = valid_textfile(localdir)
        assert_(self.ds.exists(tmpfile))
        rmtree(localdir)

    def test_InvalidFile(self):
        # 测试无效的文件
        tmpfile = invalid_textfile(self.tmpdir)
        assert_equal(self.ds.exists(tmpfile), False)


class TestDataSourceAbspath:
    # 测试数据源绝对路径的测试类

    def setup_method(self):
        # 每个测试方法执行前的设置方法
        self.tmpdir = os.path.abspath(mkdtemp())
        self.ds = datasource.DataSource(self.tmpdir)

    def teardown_method(self):
        # 每个测试方法执行后的清理方法
        rmtree(self.tmpdir)
        del self.ds

    def test_ValidHTTP(self):
        # 测试有效的 HTTP 路径
        scheme, netloc, upath, pms, qry, frg = urlparse(valid_httpurl())
        local_path = os.path.join(self.tmpdir, netloc,
                                  upath.strip(os.sep).strip('/'))
        assert_equal(local_path, self.ds.abspath(valid_httpurl()))

    def test_ValidFile(self):
        # 测试有效的文件路径
        tmpfile = valid_textfile(self.tmpdir)
        tmpfilename = os.path.split(tmpfile)[-1]
        # 测试仅使用文件名的情况
        assert_equal(tmpfile, self.ds.abspath(tmpfilename))
        # 测试包含完整路径的文件名
        assert_equal(tmpfile, self.ds.abspath(tmpfile))

    def test_InvalidHTTP(self):
        # 测试无效的 HTTP 路径
        scheme, netloc, upath, pms, qry, frg = urlparse(invalid_httpurl())
        invalidhttp = os.path.join(self.tmpdir, netloc,
                                   upath.strip(os.sep).strip('/'))
        assert_(invalidhttp != self.ds.abspath(valid_httpurl()))

    def test_InvalidFile(self):
        # 测试无效的文件路径
        invalidfile = valid_textfile(self.tmpdir)
        tmpfile = valid_textfile(self.tmpdir)
        tmpfilename = os.path.split(tmpfile)[-1]
        # 测试仅使用文件名的情况
        assert_(invalidfile != self.ds.abspath(tmpfilename))
        # 测试包含完整路径的文件名
        assert_(invalidfile != self.ds.abspath(tmpfile))
    # 测试函数:测试沙盒环境限制

    # 创建一个有效的文本文件并返回其路径
    tmpfile = valid_textfile(self.tmpdir)
    # 获取临时文件名
    tmpfilename = os.path.split(tmpfile)[-1]

    # 定义一个临时路径的lambda函数,将输入路径转换为绝对路径并返回
    tmp_path = lambda x: os.path.abspath(self.ds.abspath(x))

    # 断言:验证有效 HTTP URL 转换后的路径是否以 self.tmpdir 开头
    assert_(tmp_path(valid_httpurl()).startswith(self.tmpdir))
    # 断言:验证无效 HTTP URL 转换后的路径是否以 self.tmpdir 开头
    assert_(tmp_path(invalid_httpurl()).startswith(self.tmpdir))
    # 断言:验证临时文件路径转换后是否以 self.tmpdir 开头
    assert_(tmp_path(tmpfile).startswith(self.tmpdir))
    # 断言:验证临时文件名路径转换后是否以 self.tmpdir 开头
    assert_(tmp_path(tmpfilename).startswith(self.tmpdir))

    # 遍历恶意文件列表,验证连接到 HTTP 路径的恶意文件是否以 self.tmpdir 开头
    for fn in malicious_files:
        assert_(tmp_path(http_path+fn).startswith(self.tmpdir))
        # 断言:验证恶意文件路径是否以 self.tmpdir 开头
        assert_(tmp_path(fn).startswith(self.tmpdir))


    # 测试函数:测试 Windows 系统路径分隔符

    # 保存原始的系统路径分隔符
    orig_os_sep = os.sep
    try:
        # 设置系统路径分隔符为反斜杠
        os.sep = '\\'
        # 执行以下测试函数
        self.test_ValidHTTP()
        self.test_ValidFile()
        self.test_InvalidHTTP()
        self.test_InvalidFile()
        self.test_sandboxing()
    finally:
        # 恢复原始的系统路径分隔符
        os.sep = orig_os_sep
# 定义一个名为TestRepositoryAbspath的类
class TestRepositoryAbspath:
    # 定义初始化方法
    def setup_method(self):
        # 创建临时目录,并获取其绝对路径
        self.tmpdir = os.path.abspath(mkdtemp())
        # 创建数据源,使用有效的基本URL和临时目录
        self.repos = datasource.Repository(valid_baseurl(), self.tmpdir)

    # 定义清理方法
    def teardown_method(self):
        # 递归删除临时目录及其内容
        rmtree(self.tmpdir)
        # 删除数据源
        del self.repos

    # 定义测试方法,测试有效的HTTP地址
    def test_ValidHTTP(self):
        # 解析有效的HTTP URL,获取各部分信息
        scheme, netloc, upath, pms, qry, frg = urlparse(valid_httpurl())
        # 拼接本地路径
        local_path = os.path.join(self.repos._destpath, netloc, upath.strip(os.sep).strip('/'))
        # 获取HTTP文件的绝对路径
        filepath = self.repos.abspath(valid_httpfile())
        # 断言本地路径和文件路径相等
        assert_equal(local_path, filepath)

    # 定义测试方法,测试沙盒功能
    def test_sandboxing(self):
        # 使用lambda函数获取临时路径,并断言其以临时目录开头
        tmp_path = lambda x: os.path.abspath(self.repos.abspath(x))
        assert_(tmp_path(valid_httpfile()).startswith(self.tmpdir))
        # 遍历恶意文件列表,断言其绝对路径以临时目录开头
        for fn in malicious_files:
            assert_(tmp_path(http_path+fn).startswith(self.tmpdir))
            assert_(tmp_path(fn).startswith(self.tmpdir))

    # 定义测试方法,测试Windows系统下路径分隔符
    def test_windows_os_sep(self):
        # 保存原始路径分隔符
        orig_os_sep = os.sep
        try:
            # 修改路径分隔符为反斜杠
            os.sep = '\\'
            # 执行ValidHTTP测试
            self.test_ValidHTTP()
            # 执行sandboxing测试
            self.test_sandboxing()
        finally:
            # 恢复原始路径分隔符
            os.sep = orig_os_sep


# 定义一个名为TestRepositoryExists的类
class TestRepositoryExists:
    # 定义初始化方法
    def setup_method(self):
        # 创建临时目录
        self.tmpdir = mkdtemp()
        # 创建数据源,使用有效的基本URL和临时目录
        self.repos = datasource.Repository(valid_baseurl(), self.tmpdir)

    # 定义清理方法
    def teardown_method(self):
        # 递归删除临时目录及其内容
        rmtree(self.tmpdir)
        # 删除数据源
        del self.repos

    # 定义测试方法,测试有效文件是否存在
    def test_ValidFile(self):
        # 创建本地临时文件
        tmpfile = valid_textfile(self.tmpdir)
        # 断言数据源中存在该文件
        assert_(self.repos.exists(tmpfile))

    # 定义测试方法,测试无效文件是否存在
    def test_InvalidFile(self):
        # 创建无效的本地临时文件
        tmpfile = invalid_textfile(self.tmpdir)
        # 断言数据源中不存在该文件
        assert_equal(self.repos.exists(tmpfile), False)

    # 定义测试方法,测试移除HTTP文件
    def test_RemoveHTTPFile(self):
        # 断言数据源中存在有效的HTTP文件
        assert_(self.repos.exists(valid_httpurl()))

    # 定义测试方法,测试缓存的HTTP文件是否存在
    def test_CachedHTTPFile(self):
        # 获取有效的HTTP URL
        localfile = valid_httpurl()
        # 创建一个具有URL目录结构的本地缓存临时文件,类似于Repository.open的操作
        scheme, netloc, upath, pms, qry, frg = urlparse(localfile)
        local_path = os.path.join(self.repos._destpath, netloc)
        os.mkdir(local_path, 0o0700)
        tmpfile = valid_textfile(local_path)
        # 断言数据源中存在该缓存文件
        assert_(self.repos.exists(tmpfile))


# 定义一个名为TestOpenFunc的类
class TestOpenFunc:
    # 定义初始化方法
    def setup_method(self):
        # 创建临时目录
        self.tmpdir = mkdtemp()

    # 定义清理方法
    def teardown_method(self):
        # 递归删除临时目录及其内容
        rmtree(self.tmpdir)

    # 定义测试方法,测试DataSource的打开操作
    def test_DataSourceOpen(self):
        # 创建本地临时文件
        local_file = valid_textfile(self.tmpdir)
        # 测试传入目标路径的情况
        fp = datasource.open(local_file, destpath=self.tmpdir)
        assert_(fp)
        fp.close()
        # 测试使用默认目标路径的情况
        fp = datasource.open(local_file)
        assert_(fp)
        fp.close()

# 定义测试删除属性处理的函数
def test_del_attr_handling():
    # 数据源__del__可能会被调用,即使在初始化失败时(被调用的异常对象被捕获)
    # 就像在refguide_check的is_deprecated()函数中发生的情况一样
    ds = datasource.DataSource()
    # 创建一个名为 ds 的 DataSource 对象实例
    
    # 模拟由于删除 __init__ 中产生的关键属性而导致的初始化失败
    del ds._istmpdest
    
    # 调用 __del__ 方法来确保在初始化失败时也不会触发 AttributeError
    ds.__del__()
    # 执行对象的析构函数,清理资源或执行必要的清理操作

.\numpy\numpy\lib\tests\test__iotools.py

# 导入时间模块
import time
# 从日期模块中导入日期类
from datetime import date

# 导入NumPy库并重命名为np
import numpy as np
# 从NumPy测试模块中导入断言函数
from numpy.testing import (
    assert_, assert_equal, assert_allclose, assert_raises,
    )
# 从NumPy输入输出工具模块中导入特定功能
from numpy.lib._iotools import (
    LineSplitter, NameValidator, StringConverter,
    has_nested_fields, easy_dtype, flatten_dtype
    )

# 定义测试类 TestLineSplitter
class TestLineSplitter:
    "Tests the LineSplitter class."

    # 定义测试方法 test_no_delimiter,测试无分隔符情况
    def test_no_delimiter(self):
        "Test LineSplitter w/o delimiter"
        # 测试字符串
        strg = " 1 2 3 4  5 # test"
        # 创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter()(strg)
        assert_equal(test, ['1', '2', '3', '4', '5'])
        # 使用空字符串作为分隔符,再次调用 LineSplitter 实例,进行断言
        test = LineSplitter('')(strg)
        assert_equal(test, ['1', '2', '3', '4', '5'])

    # 定义测试方法 test_space_delimiter,测试空格分隔符情况
    def test_space_delimiter(self):
        "Test space delimiter"
        # 测试字符串
        strg = " 1 2 3 4  5 # test"
        # 使用空格作为分隔符,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter(' ')(strg)
        assert_equal(test, ['1', '2', '3', '4', '', '5'])
        # 使用两个空格作为分隔符,再次调用 LineSplitter 实例,进行断言
        test = LineSplitter('  ')(strg)
        assert_equal(test, ['1 2 3 4', '5'])

    # 定义测试方法 test_tab_delimiter,测试制表符分隔符情况
    def test_tab_delimiter(self):
        "Test tab delimiter"
        # 测试字符串
        strg = " 1\t 2\t 3\t 4\t 5  6"
        # 使用制表符作为分隔符,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter('\t')(strg)
        assert_equal(test, ['1', '2', '3', '4', '5  6'])
        # 测试字符串
        strg = " 1  2\t 3  4\t 5  6"
        # 使用制表符作为分隔符,再次调用 LineSplitter 实例,返回结果进行断言
        test = LineSplitter('\t')(strg)
        assert_equal(test, ['1  2', '3  4', '5  6'])

    # 定义测试方法 test_other_delimiter,测试其他自定义分隔符情况
    def test_other_delimiter(self):
        "Test LineSplitter on delimiter"
        # 测试字符串
        strg = "1,2,3,4,,5"
        # 使用逗号作为分隔符,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter(',')(strg)
        assert_equal(test, ['1', '2', '3', '4', '', '5'])
        #
        # 测试字符串
        strg = " 1,2,3,4,,5 # test"
        # 使用逗号作为分隔符,再次调用 LineSplitter 实例,返回结果进行断言
        test = LineSplitter(',')(strg)
        assert_equal(test, ['1', '2', '3', '4', '', '5'])

        # gh-11028 bytes comment/delimiters should get encoded
        # 测试字节字符串
        strg = b" 1,2,3,4,,5 % test"
        # 使用逗号和百分号作为分隔符和注释,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter(delimiter=b',', comments=b'%')(strg)
        assert_equal(test, ['1', '2', '3', '4', '', '5'])

    # 定义测试方法 test_constant_fixed_width,测试固定宽度字段情况
    def test_constant_fixed_width(self):
        "Test LineSplitter w/ fixed-width fields"
        # 测试字符串
        strg = "  1  2  3  4     5   # test"
        # 使用固定宽度为3的字段,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter(3)(strg)
        assert_equal(test, ['1', '2', '3', '4', '', '5', ''])
        #
        # 测试字符串
        strg = "  1     3  4  5  6# test"
        # 使用固定宽度为20的字段,再次调用 LineSplitter 实例,返回结果进行断言
        test = LineSplitter(20)(strg)
        assert_equal(test, ['1     3  4  5  6'])
        #
        # 测试字符串
        strg = "  1     3  4  5  6# test"
        # 使用固定宽度为30的字段,再次调用 LineSplitter 实例,返回结果进行断言
        test = LineSplitter(30)(strg)
        assert_equal(test, ['1     3  4  5  6'])

    # 定义测试方法 test_variable_fixed_width,测试变量宽度字段情况
    def test_variable_fixed_width(self):
        # 测试字符串
        strg = "  1     3  4  5  6# test"
        # 使用不同宽度(3, 6, 6, 3)的字段,创建 LineSplitter 实例并调用,返回结果进行断言
        test = LineSplitter((3, 6, 6, 3))(strg)
        assert_equal(test, ['1', '3', '4  5', '6'])
        #
        # 测试字符串
        strg = "  1     3  4  5  6# test"
        # 使用不同宽度(6, 6, 9)的字段,再次调用 LineSplitter 实例,返回结果进行断言
        test = LineSplitter((6, 6, 9))(strg)
        assert_equal(test, ['1', '3  4', '5  6'])

# -----------------------------------------------------------------------------
    def test_case_sensitivity(self):
        "Test case sensitivity"
        # 定义测试用例,包含大小写敏感和不敏感的情况
        names = ['A', 'a', 'b', 'c']
        # 使用默认设置进行名称验证
        test = NameValidator().validate(names)
        # 断言结果与预期相同
        assert_equal(test, ['A', 'a', 'b', 'c'])
        # 使用不区分大小写的设置进行名称验证
        test = NameValidator(case_sensitive=False).validate(names)
        # 断言结果与预期相同,会自动修正冲突名称
        assert_equal(test, ['A', 'A_1', 'B', 'C'])
        # 使用大写字母形式进行名称验证
        test = NameValidator(case_sensitive='upper').validate(names)
        # 断言结果与预期相同,会自动修正冲突名称
        assert_equal(test, ['A', 'A_1', 'B', 'C'])
        # 使用小写字母形式进行名称验证
        test = NameValidator(case_sensitive='lower').validate(names)
        # 断言结果与预期相同,会自动修正冲突名称
        assert_equal(test, ['a', 'a_1', 'b', 'c'])

        # 检查异常情况,应该引发 ValueError 异常
        assert_raises(ValueError, NameValidator, case_sensitive='foobar')

    def test_excludelist(self):
        "Test excludelist"
        # 定义测试用例,包含排除列表的情况
        names = ['dates', 'data', 'Other Data', 'mask']
        # 创建排除特定名称的验证器
        validator = NameValidator(excludelist=['dates', 'data', 'mask'])
        # 对名称列表进行验证
        test = validator.validate(names)
        # 断言结果与预期相同,会自动修正冲突名称
        assert_equal(test, ['dates_', 'data_', 'Other_Data', 'mask_'])

    def test_missing_names(self):
        "Test validate missing names"
        # 定义测试用例,包含缺失名称的情况
        namelist = ('a', 'b', 'c')
        # 创建默认验证器
        validator = NameValidator()
        # 对名称列表进行验证
        assert_equal(validator(namelist), ['a', 'b', 'c'])
        namelist = ('', 'b', 'c')
        # 对包含空字符串的名称列表进行验证
        assert_equal(validator(namelist), ['f0', 'b', 'c'])
        namelist = ('a', 'b', '')
        # 对包含空字符串的名称列表进行验证
        assert_equal(validator(namelist), ['a', 'b', 'f0'])
        namelist = ('', 'f0', '')
        # 对包含多个空字符串的名称列表进行验证
        assert_equal(validator(namelist), ['f1', 'f0', 'f2'])

    def test_validate_nb_names(self):
        "Test validate nb names"
        # 定义测试用例,包含限制字段数量的情况
        namelist = ('a', 'b', 'c')
        # 创建默认验证器
        validator = NameValidator()
        # 对名称列表进行验证,限制为 1 个字段
        assert_equal(validator(namelist, nbfields=1), ('a',))
        # 对名称列表进行验证,增加到 5 个字段,并指定默认格式
        assert_equal(validator(namelist, nbfields=5, defaultfmt="g%i"),
                     ['a', 'b', 'c', 'g0', 'g1'])

    def test_validate_wo_names(self):
        "Test validate no names"
        # 定义测试用例,包含空名称列表的情况
        namelist = None
        # 创建默认验证器
        validator = NameValidator()
        # 验证空名称列表
        assert_(validator(namelist) is None)
        # 验证空名称列表,限制为 3 个字段
        assert_equal(validator(namelist, nbfields=3), ['f0', 'f1', 'f2'])
# -----------------------------------------------------------------------------
# 将字节字符串转换为日期对象
def _bytes_to_date(s):
    return date(*time.strptime(s, "%Y-%m-%d")[:3])


class TestStringConverter:
    "Test StringConverter"

    def test_creation(self):
        "Test creation of a StringConverter"
        # 创建一个整数型的 StringConverter,设置默认值为 -99999
        converter = StringConverter(int, -99999)
        # 断言状态为 1
        assert_equal(converter._status, 1)
        # 断言默认值为 -99999
        assert_equal(converter.default, -99999)

    def test_upgrade(self):
        "Tests the upgrade method."

        # 创建一个默认的 StringConverter
        converter = StringConverter()
        # 断言状态为 0
        assert_equal(converter._status, 0)

        # 测试整数类型
        assert_equal(converter.upgrade('0'), 0)
        # 断言状态为 1
        assert_equal(converter._status, 1)

        # 在 long 类型默认为 32 位系统上,状态将会有一个偏移量,因此我们在此处检查这一点
        import numpy._core.numeric as nx
        status_offset = int(nx.dtype(nx.int_).itemsize < nx.dtype(nx.int64).itemsize)

        # 测试大于 2**32 的整数
        assert_equal(converter.upgrade('17179869184'), 17179869184)
        # 断言状态为 1 + status_offset
        assert_equal(converter._status, 1 + status_offset)

        # 测试浮点数类型
        assert_allclose(converter.upgrade('0.'), 0.0)
        # 断言状态为 2 + status_offset
        assert_equal(converter._status, 2 + status_offset)

        # 测试复数类型
        assert_equal(converter.upgrade('0j'), complex('0j'))
        # 断言状态为 3 + status_offset
        assert_equal(converter._status, 3 + status_offset)

        # 测试字符串类型
        # 注意长双精度类型已被跳过,因此状态增加 2。所有的 unicode 转换应该都成功(8)。
        for s in ['a', b'a']:
            res = converter.upgrade(s)
            assert_(type(res) is str)
            assert_equal(res, 'a')
            assert_equal(converter._status, 8 + status_offset)

    def test_missing(self):
        "Tests the use of missing values."
        # 创建一个带有自定义缺失值的 StringConverter
        converter = StringConverter(missing_values=('missing', 'missed'))
        converter.upgrade('0')
        assert_equal(converter('0'), 0)
        assert_equal(converter(''), converter.default)
        assert_equal(converter('missing'), converter.default)
        assert_equal(converter('missed'), converter.default)
        # 测试不存在的值是否会引发 ValueError 异常
        try:
            converter('miss')
        except ValueError:
            pass

    def test_upgrademapper(self):
        "Tests updatemapper"
        # 创建一个日期解析器函数
        dateparser = _bytes_to_date
        # 保存原始的 mapper 列表
        _original_mapper = StringConverter._mapper[:]
        try:
            # 更新 mapper 使用日期解析器和指定的日期作为默认值
            StringConverter.upgrade_mapper(dateparser, date(2000, 1, 1))
            # 创建一个使用新 mapper 的 StringConverter
            convert = StringConverter(dateparser, date(2000, 1, 1))
            # 测试日期转换是否正确
            test = convert('2001-01-01')
            assert_equal(test, date(2001, 1, 1))
            test = convert('2009-01-01')
            assert_equal(test, date(2009, 1, 1))
            # 测试空字符串是否使用默认日期
            test = convert('')
            assert_equal(test, date(2000, 1, 1))
        finally:
            # 恢复原始的 mapper 列表
            StringConverter._mapper = _original_mapper
    # 定义测试方法,验证字符串转对象函数是否被正确识别
    def test_string_to_object(self):
        # 备份 StringConverter._mapper 列表
        old_mapper = StringConverter._mapper[:]  # copy of list
        # 创建 StringConverter 实例,使用 _bytes_to_date 作为转换函数
        conv = StringConverter(_bytes_to_date)
        # 断言新的 StringConverter._mapper 与备份的列表相等
        assert_equal(conv._mapper, old_mapper)
        # 断言 conv 实例具有属性 'default'
        assert_(hasattr(conv, 'default'))

    # 定义测试方法,验证不会丢失显式默认值
    def test_keep_default(self):
        # 创建 StringConverter 实例,设定默认值为 -999,且不会丢失缺失值
        converter = StringConverter(None, missing_values='',
                                    default=-999)
        # 使用 upgrade 方法更新转换器
        converter.upgrade('3.14159265')
        # 断言转换器的默认值为 -999
        assert_equal(converter.default, -999)
        # 断言转换器的类型为 float 的 NumPy 数据类型
        assert_equal(converter.type, np.dtype(float))
        #
        # 创建 StringConverter 实例,设定默认值为 0,且不会丢失缺失值
        converter = StringConverter(
            None, missing_values='', default=0)
        # 使用 upgrade 方法更新转换器
        converter.upgrade('3.14159265')
        # 断言转换器的默认值为 0
        assert_equal(converter.default, 0)
        # 断言转换器的类型为 float 的 NumPy 数据类型

    # 定义测试方法,验证不会丢失默认值为零
    def test_keep_default_zero(self):
        # 创建 StringConverter 实例,设定类型为 int,默认值为 0,不会丢失缺失值
        converter = StringConverter(int, default=0,
                                    missing_values="N/A")
        # 断言转换器的默认值为 0
        assert_equal(converter.default, 0)

    # 定义测试方法,验证不会丢失缺失值设定
    def test_keep_missing_values(self):
        # 创建 StringConverter 实例,设定类型为 int,默认值为 0,不会丢失缺失值
        converter = StringConverter(int, default=0,
                                    missing_values="N/A")
        # 断言转换器的缺失值设定包括空字符串和 'N/A'
        assert_equal(
            converter.missing_values, {'', 'N/A'})

    # 定义测试方法,验证可以指定 int64 类型整数
    def test_int64_dtype(self):
        # 创建 StringConverter 实例,设定类型为 int64,默认值为 0
        converter = StringConverter(np.int64, default=0)
        # 设定待转换的字符串值
        val = "-9223372036854775807"
        # 断言转换后的值为 -9223372036854775807
        assert_(converter(val) == -9223372036854775807)
        # 设定另一个待转换的字符串值
        val = "9223372036854775807"
        # 断言转换后的值为 9223372036854775807
        assert_(converter(val) == 9223372036854775807)

    # 定义测试方法,验证可以指定 uint64 类型整数
    def test_uint64_dtype(self):
        # 创建 StringConverter 实例,设定类型为 uint64,默认值为 0
        converter = StringConverter(np.uint64, default=0)
        # 设定待转换的字符串值
        val = "9223372043271415339"
        # 断言转换后的值为 9223372043271415339
        assert_(converter(val) == 9223372043271415339)
class TestMiscFunctions:

    def test_has_nested_dtype():
        "Test has_nested_dtype"
        # 创建一个浮点类型的 NumPy 数据类型对象 ndtype
        ndtype = np.dtype(float)
        # 调用函数检查 ndtype 是否有嵌套字段,断言结果应为 False
        assert_equal(has_nested_fields(ndtype), False)
        
        # 创建一个复合结构的 NumPy 数据类型对象 ndtype
        ndtype = np.dtype([('A', '|S3'), ('B', float)])
        # 调用函数检查 ndtype 是否有嵌套字段,断言结果应为 False
        assert_equal(has_nested_fields(ndtype), False)
        
        # 创建一个更复杂的复合结构的 NumPy 数据类型对象 ndtype
        ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
        # 调用函数检查 ndtype 是否有嵌套字段,断言结果应为 True
        assert_equal(has_nested_fields(ndtype), True)
    def test_easy_dtype(self):
        "Test ndtype on dtypes"
        # 定义一个简单的数据类型
        ndtype = float
        # 断言函数返回的数据类型与预期的 numpy 数据类型相等
        assert_equal(easy_dtype(ndtype), np.dtype(float))
        
        # 使用字符串定义数据类型,并且不指定字段名
        ndtype = "i4, f8"
        # 断言函数返回的数据类型与预期的 numpy 结构化数据类型相等,自动分配字段名
        assert_equal(easy_dtype(ndtype),
                     np.dtype([('f0', "i4"), ('f1', "f8")]))
        
        # 使用字符串定义数据类型,不指定字段名,并且设置了不同的默认字段格式
        assert_equal(easy_dtype(ndtype, defaultfmt="field_%03i"),
                     np.dtype([('field_000', "i4"), ('field_001', "f8")]))
        
        # 使用字符串定义数据类型,并且指定字段名
        ndtype = "i4, f8"
        # 断言函数返回的数据类型与预期的 numpy 结构化数据类型相等,使用指定的字段名
        assert_equal(easy_dtype(ndtype, names="a, b"),
                     np.dtype([('a', "i4"), ('b', "f8")]))
        
        # 使用字符串定义数据类型,并且指定了过多的字段名
        ndtype = "i4, f8"
        assert_equal(easy_dtype(ndtype, names="a, b, c"),
                     np.dtype([('a', "i4"), ('b', "f8")]))
        
        # 使用字符串定义数据类型,并且指定了不足的字段名
        ndtype = "i4, f8"
        assert_equal(easy_dtype(ndtype, names=", b"),
                     np.dtype([('f0', "i4"), ('b', "f8")]))
        
        # 使用字符串定义数据类型,指定字段名,并且设置了不同的默认字段格式
        assert_equal(easy_dtype(ndtype, names="a", defaultfmt="f%02i"),
                     np.dtype([('a', "i4"), ('f00', "f8")]))
        
        # 使用元组列表定义数据类型,不指定字段名
        ndtype = [('A', int), ('B', float)]
        assert_equal(easy_dtype(ndtype), np.dtype([('A', int), ('B', float)]))
        
        # 使用元组列表定义数据类型,并且指定字段名
        assert_equal(easy_dtype(ndtype, names="a,b"),
                     np.dtype([('a', int), ('b', float)]))
        
        # 使用元组列表定义数据类型,并且指定了不足的字段名
        assert_equal(easy_dtype(ndtype, names="a"),
                     np.dtype([('a', int), ('f0', float)]))
        
        # 使用元组列表定义数据类型,并且指定了过多的字段名
        assert_equal(easy_dtype(ndtype, names="a,b,c"),
                     np.dtype([('a', int), ('b', float)]))
        
        # 使用类型列表定义数据类型,不指定字段名
        ndtype = (int, float, float)
        assert_equal(easy_dtype(ndtype),
                     np.dtype([('f0', int), ('f1', float), ('f2', float)]))
        
        # 使用类型列表定义数据类型,并且指定字段名
        ndtype = (int, float, float)
        assert_equal(easy_dtype(ndtype, names="a, b, c"),
                     np.dtype([('a', int), ('b', float), ('c', float)]))
        
        # 使用简单的 numpy 数据类型,并且指定字段名
        ndtype = np.dtype(float)
        assert_equal(easy_dtype(ndtype, names="a, b, c"),
                     np.dtype([(_, float) for _ in ('a', 'b', 'c')]))
        
        # 使用简单的 numpy 数据类型,不指定字段名但有多个字段
        ndtype = np.dtype(float)
        assert_equal(
            easy_dtype(ndtype, names=['', '', ''], defaultfmt="f%02i"),
            np.dtype([(_, float) for _ in ('f00', 'f01', 'f02')]))
    # 定义测试方法:测试 flatten_dtype 函数的各种情况
    def test_flatten_dtype(self):
        # 测试标准的数据类型 dt
        dt = np.dtype([("a", "f8"), ("b", "f8")])
        # 调用 flatten_dtype 函数对数据类型进行扁平化处理
        dt_flat = flatten_dtype(dt)
        # 断言扁平化后的结果是否符合预期,应为 [float, float]
        assert_equal(dt_flat, [float, float])

        # 测试递归数据类型 dt
        dt = np.dtype([("a", [("aa", '|S1'), ("ab", '|S2')]), ("b", int)])
        # 再次调用 flatten_dtype 函数对数据类型进行扁平化处理
        dt_flat = flatten_dtype(dt)
        # 断言扁平化后的结果是否符合预期,应为 [np.dtype('|S1'), np.dtype('|S2'), int]
        assert_equal(dt_flat, [np.dtype('|S1'), np.dtype('|S2'), int])

        # 测试带有形状字段的数据类型 dt
        dt = np.dtype([("a", (float, 2)), ("b", (int, 3))])
        # 再次调用 flatten_dtype 函数对数据类型进行扁平化处理
        dt_flat = flatten_dtype(dt)
        # 断言扁平化后的结果是否符合预期,应为 [float, int]
        assert_equal(dt_flat, [float, int])

        # 继续测试带有形状字段的数据类型 dt,并且保留形状信息
        dt_flat = flatten_dtype(dt, True)
        # 断言扁平化后的结果是否符合预期,应为 [float, float, int, int, int]
        assert_equal(dt_flat, [float] * 2 + [int] * 3)

        # 测试带有标题的数据类型 dt
        dt = np.dtype([(("a", "A"), "f8"), (("b", "B"), "f8")])
        # 再次调用 flatten_dtype 函数对数据类型进行扁平化处理
        dt_flat = flatten_dtype(dt)
        # 断言扁平化后的结果是否符合预期,应为 [float, float]
        assert_equal(dt_flat, [float, float])

.\numpy\numpy\lib\tests\test__version.py

"""Tests for the NumpyVersion class.

"""
# 从numpy.testing中导入assert_和assert_raises函数
from numpy.testing import assert_, assert_raises
# 从numpy.lib中导入NumpyVersion类
from numpy.lib import NumpyVersion


# 定义测试函数test_main_versions
def test_main_versions():
    # 断言NumpyVersion('1.8.0')等于字符串'1.8.0'
    assert_(NumpyVersion('1.8.0') == '1.8.0')
    # 遍历列表,断言NumpyVersion('1.8.0')小于列表中的版本字符串
    for ver in ['1.9.0', '2.0.0', '1.8.1', '10.0.1']:
        assert_(NumpyVersion('1.8.0') < ver)
    # 遍历列表,断言NumpyVersion('1.8.0')大于列表中的版本字符串
    for ver in ['1.7.0', '1.7.1', '0.9.9']:
        assert_(NumpyVersion('1.8.0') > ver)


# 定义测试函数test_version_1_point_10
def test_version_1_point_10():
    # regression test for gh-2998.
    # 断言NumpyVersion('1.9.0')小于字符串'1.10.0'
    assert_(NumpyVersion('1.9.0') < '1.10.0')
    # 断言NumpyVersion('1.11.0')小于字符串'1.11.1'
    assert_(NumpyVersion('1.11.0') < '1.11.1')
    # 断言NumpyVersion('1.11.0')等于字符串'1.11.0'
    assert_(NumpyVersion('1.11.0') == '1.11.0')
    # 断言NumpyVersion('1.99.11')小于字符串'1.99.12'
    assert_(NumpyVersion('1.99.11') < '1.99.12')


# 定义测试函数test_alpha_beta_rc
def test_alpha_beta_rc():
    # 断言NumpyVersion('1.8.0rc1')等于字符串'1.8.0rc1'
    assert_(NumpyVersion('1.8.0rc1') == '1.8.0rc1')
    # 遍历列表,断言NumpyVersion('1.8.0rc1')小于列表中的版本字符串
    for ver in ['1.8.0', '1.8.0rc2']:
        assert_(NumpyVersion('1.8.0rc1') < ver)
    # 遍历列表,断言NumpyVersion('1.8.0rc1')大于列表中的版本字符串
    for ver in ['1.8.0a2', '1.8.0b3', '1.7.2rc4']:
        assert_(NumpyVersion('1.8.0rc1') > ver)
    # 断言NumpyVersion('1.8.0b1')大于字符串'1.8.0a2'
    assert_(NumpyVersion('1.8.0b1') > '1.8.0a2')


# 定义测试函数test_dev_version
def test_dev_version():
    # 断言NumpyVersion('1.9.0.dev-Unknown')小于字符串'1.9.0'
    assert_(NumpyVersion('1.9.0.dev-Unknown') < '1.9.0')
    # 遍历列表,断言NumpyVersion('1.9.0.dev-f16acvda')小于列表中的版本字符串
    for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev-ffffffff']:
        assert_(NumpyVersion('1.9.0.dev-f16acvda') < ver)
    # 断言NumpyVersion('1.9.0.dev-f16acvda')等于字符串'1.9.0.dev-11111111'
    assert_(NumpyVersion('1.9.0.dev-f16acvda') == '1.9.0.dev-11111111')


# 定义测试函数test_dev_a_b_rc_mixed
def test_dev_a_b_rc_mixed():
    # 断言NumpyVersion('1.9.0a2.dev-f16acvda')等于字符串'1.9.0a2.dev-11111111'
    assert_(NumpyVersion('1.9.0a2.dev-f16acvda') == '1.9.0a2.dev-11111111')
    # 断言NumpyVersion('1.9.0a2.dev-6acvda54')小于字符串'1.9.0a2'
    assert_(NumpyVersion('1.9.0a2.dev-6acvda54') < '1.9.0a2')


# 定义测试函数test_dev0_version
def test_dev0_version():
    # 断言NumpyVersion('1.9.0.dev0+Unknown')小于字符串'1.9.0'
    assert_(NumpyVersion('1.9.0.dev0+Unknown') < '1.9.0')
    # 遍历列表,断言NumpyVersion('1.9.0.dev0+f16acvda')小于列表中的版本字符串
    for ver in ['1.9.0', '1.9.0a1', '1.9.0b2', '1.9.0b2.dev0+ffffffff']:
        assert_(NumpyVersion('1.9.0.dev0+f16acvda') < ver)
    # 断言NumpyVersion('1.9.0.dev0+f16acvda')等于字符串'1.9.0.dev0+11111111'
    assert_(NumpyVersion('1.9.0.dev0+f16acvda') == '1.9.0.dev0+11111111')


# 定义测试函数test_dev0_a_b_rc_mixed
def test_dev0_a_b_rc_mixed():
    # 断言NumpyVersion('1.9.0a2.dev0+f16acvda')等于字符串'1.9.0a2.dev0+11111111'
    assert_(NumpyVersion('1.9.0a2.dev0+f16acvda') == '1.9.0a2.dev0+11111111')
    # 断言NumpyVersion('1.9.0a2.dev0+6acvda54')小于字符串'1.9.0a2'
    assert_(NumpyVersion('1.9.0a2.dev0+6acvda54') < '1.9.0a2')


# 定义测试函数test_raises
def test_raises():
    # 遍历列表,对于每个版本字符串,断言调用NumpyVersion(ver)会引发ValueError异常
    for ver in ['1.9', '1,9.0', '1.7.x']:
        assert_raises(ValueError, NumpyVersion, ver)

.\numpy\numpy\lib\tests\__init__.py

# 定义一个名为 `calculate_total` 的函数,接受一个参数 `items`
def calculate_total(items):
    # 初始化变量 `total` 为 0
    total = 0
    # 对于 `item` 中的每个元素,执行以下操作:
    for item in items:
        # 将 `item` 的值加到 `total` 上
        total += item
    # 返回累加后的结果 `total`
    return total

.\numpy\numpy\lib\user_array.py

# 从 _user_array_impl 模块中导入 __doc__ 和 container 变量
from ._user_array_impl import __doc__, container

.\numpy\numpy\lib\_arraypad_impl.py

"""
The arraypad module contains a group of functions to pad values onto the edges
of an n-dimensional array.

"""
import numpy as np
from numpy._core.overrides import array_function_dispatch
from numpy.lib._index_tricks_impl import ndindex


__all__ = ['pad']


###############################################################################
# Private utility functions.


def _round_if_needed(arr, dtype):
    """
    Rounds arr inplace if destination dtype is integer.

    Parameters
    ----------
    arr : ndarray
        Input array.
    dtype : dtype
        The dtype of the destination array.
    """
    if np.issubdtype(dtype, np.integer):
        arr.round(out=arr)


def _slice_at_axis(sl, axis):
    """
    Construct tuple of slices to slice an array in the given dimension.

    Parameters
    ----------
    sl : slice
        The slice for the given dimension.
    axis : int
        The axis to which `sl` is applied. All other dimensions are left
        "unsliced".

    Returns
    -------
    sl : tuple of slices
        A tuple with slices matching `shape` in length.

    Examples
    --------
    >>> _slice_at_axis(slice(None, 3, -1), 1)
    (slice(None, None, None), slice(None, 3, -1), (...,))
    """
    return (slice(None),) * axis + (sl,) + (...,)


def _view_roi(array, original_area_slice, axis):
    """
    Get a view of the current region of interest during iterative padding.

    When padding multiple dimensions iteratively corner values are
    unnecessarily overwritten multiple times. This function reduces the
    working area for the first dimensions so that corners are excluded.

    Parameters
    ----------
    array : ndarray
        The array with the region of interest.
    original_area_slice : tuple of slices
        Denotes the area with original values of the unpadded array.
    axis : int
        The currently padded dimension assuming that `axis` is padded before
        `axis` + 1.

    Returns
    -------
    roi : ndarray
        The region of interest of the original `array`.
    """
    axis += 1
    sl = (slice(None),) * axis + original_area_slice[axis:]
    return array[sl]


def _pad_simple(array, pad_width, fill_value=None):
    """
    Pad array on all sides with either a single value or undefined values.

    Parameters
    ----------
    array : ndarray
        Array to grow.
    pad_width : sequence of tuple[int, int]
        Pad width on both sides for each dimension in `arr`.
    fill_value : scalar, optional
        If provided the padded area is filled with this value, otherwise
        the pad area left undefined.

    Returns
    -------
    padded : ndarray
        The padded array with the same dtype as`array`. Its order will default
        to C-style if `array` is not F-contiguous.
    original_area_slice : tuple
        A tuple of slices pointing to the area of the original array.
    """
    # Allocate grown array
    pad_width = np.asarray(pad_width)
    new_shape = tuple((np.array(array.shape) + pad_width.sum(axis=1)).tolist())
    padded = np.empty(new_shape, dtype=array.dtype, order='C' if not array.flags.f_contiguous else 'F')

    # Copy original array into padded array
    original_area_slice = tuple(slice(pad_width[i, 0], pad_width[i, 0] + array.shape[i]) for i in range(array.ndim))
    padded[original_area_slice] = array

    # Fill padded areas if fill_value is provided
    if fill_value is not None:
        padded[...] = fill_value

    return padded, original_area_slice
    # 计算扩展后的数组形状,将左边界、数组原始大小、右边界相加
    new_shape = tuple(
        left + size + right
        for size, (left, right) in zip(array.shape, pad_width)
    )
    # 确定数组的存储顺序为 'F'(Fortran顺序)或 'C'(C顺序)
    order = 'F' if array.flags.fnc else 'C'  # Fortran and not also C-order
    # 创建一个空数组,使用指定的数据类型和存储顺序
    padded = np.empty(new_shape, dtype=array.dtype, order=order)

    # 如果指定了填充值,用填充值填充扩展后的数组
    if fill_value is not None:
        padded.fill(fill_value)

    # 将原始数组复制到扩展后的数组的正确位置
    # 计算原始数组在扩展后数组中的切片范围
    original_area_slice = tuple(
        slice(left, left + size)
        for size, (left, right) in zip(array.shape, pad_width)
    )
    # 将原始数组复制到扩展后数组的指定切片位置
    padded[original_area_slice] = array

    # 返回扩展后的数组和原始数组在扩展后数组中的切片范围
    return padded, original_area_slice
# 设置给定维度中的空白填充区域。
def _set_pad_area(padded, axis, width_pair, value_pair):
    # 创建左侧切片,以指定维度上的前width_pair[0]个位置
    left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
    # 将左侧填充区域设置为value_pair[0]
    padded[left_slice] = value_pair[0]

    # 创建右侧切片,以指定维度上从padded.shape[axis] - width_pair[1]到结尾的位置
    right_slice = _slice_at_axis(
        slice(padded.shape[axis] - width_pair[1], None), axis)
    # 将右侧填充区域设置为value_pair[1]
    padded[right_slice] = value_pair[1]


# 从给定维度的空白填充数组中检索边缘值。
def _get_edges(padded, axis, width_pair):
    # 左侧边缘的索引为width_pair[0]
    left_index = width_pair[0]
    # 创建左侧切片,指定维度上的[left_index, left_index + 1)范围
    left_slice = _slice_at_axis(slice(left_index, left_index + 1), axis)
    # 获取左侧边缘值
    left_edge = padded[left_slice]

    # 右侧边缘的索引为padded.shape[axis] - width_pair[1]
    right_index = padded.shape[axis] - width_pair[1]
    # 创建右侧切片,指定维度上的[right_index - 1, right_index)范围
    right_slice = _slice_at_axis(slice(right_index - 1, right_index), axis)
    # 获取右侧边缘值
    right_edge = padded[right_slice]

    return left_edge, right_edge


# 在给定维度的空白填充数组中构造线性斜坡。
def _get_linear_ramps(padded, axis, width_pair, end_value_pair):
    # 获取边缘值对
    edge_pair = _get_edges(padded, axis, width_pair)

    # 生成左侧和右侧线性斜坡
    left_ramp, right_ramp = (
        np.linspace(
            start=end_value,
            stop=edge.squeeze(axis),  # 使用edge的指定维度上的值
            num=width,
            endpoint=False,
            dtype=padded.dtype,
            axis=axis
        )
        for end_value, edge, width in zip(
            end_value_pair, edge_pair, width_pair
        )
    )
    # 在指定维度上反转线性空间
    right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]
    # 返回更新后的左线性空间和右线性空间
    return left_ramp, right_ramp
# 计算空填充数组在给定维度上的统计量。

def _get_stats(padded, axis, width_pair, length_pair, stat_func):
    # 计算包含原始值区域的边界索引
    left_index = width_pair[0]
    right_index = padded.shape[axis] - width_pair[1]
    # 计算有效区域的长度
    max_length = right_index - left_index

    # 限制统计长度不超过 max_length
    left_length, right_length = length_pair
    if left_length is None or max_length < left_length:
        left_length = max_length
    if right_length is None or max_length < right_length:
        right_length = max_length

    if (left_length == 0 or right_length == 0) \
            and stat_func in {np.amax, np.amin}:
        # 如果左右统计长度为 0,并且统计函数是 np.amax 或 np.amin,
        # 抛出更具描述性的异常信息
        raise ValueError("stat_length of 0 yields no value for padding")

    # 计算左侧的统计量
    left_slice = _slice_at_axis(
        slice(left_index, left_index + left_length), axis)
    left_chunk = padded[left_slice]
    left_stat = stat_func(left_chunk, axis=axis, keepdims=True)
    _round_if_needed(left_stat, padded.dtype)

    if left_length == right_length == max_length:
        # 如果左右统计长度相等且等于 max_length,则右侧的统计量必须与左侧相同,
        # 直接返回左侧统计量
        return left_stat, left_stat

    # 计算右侧的统计量
    right_slice = _slice_at_axis(
        slice(right_index - right_length, right_index), axis)
    right_chunk = padded[right_slice]
    right_stat = stat_func(right_chunk, axis=axis, keepdims=True)
    _round_if_needed(right_stat, padded.dtype)

    return left_stat, right_stat


def _set_reflect_both(padded, axis, width_pair, method, 
                      original_period, include_edge=False):
    """
    用反射方式填充数组 `padded` 在指定轴 `axis` 上。

    Parameters
    ----------
    padded : ndarray
        任意形状的输入数组。
    axis : int
        要填充的 `padded` 的轴。
    width_pair : (int, int)
        在给定维度上标记填充区域两侧的宽度对。
    method : str
        # 控制反射方法的选择;选项为 'even' 或 'odd'。
        Controls method of reflection; options are 'even' or 'odd'.
    original_period : int
        # `arr` 的 `axis` 上数据的原始长度。
        Original length of data on `axis` of `arr`.
    include_edge : bool
        # 如果为真,则在反射中包含边缘值;否则,边缘值形成对称轴的一部分。
        If true, edge value is included in reflection, otherwise the edge
        value forms the symmetric axis to the reflection.

    Returns
    -------
    pad_amt : tuple of ints, length 2
        # 沿 `axis` 要进行填充的新索引位置。如果这两个值都为0,则在此维度上进行填充。
        New index positions of padding to do along the `axis`. If these are
        both 0, padding is done in this dimension.
    """
    left_pad, right_pad = width_pair
    # 计算未填充数组的有效长度
    old_length = padded.shape[axis] - right_pad - left_pad
    
    if include_edge:
        # 如果要包含边缘值,在计算反射量时需要进行偏移
        # 避免只使用原始区域的子集来进行包装
        old_length = old_length // original_period * original_period
        # 边缘值被包含,需要将填充量偏移1
        edge_offset = 1
    else:
        # 如果不包含边缘值,在计算反射量时也需要进行偏移
        # 避免只使用原始区域的子集来进行包装
        old_length = ((old_length - 1) // (original_period - 1)
            * (original_period - 1) + 1)
        edge_offset = 0  # 不包含边缘值,填充量无需偏移
        old_length -= 1  # 但必须从块中省略

    if left_pad > 0:
        # 在左侧使用反射值进行填充:
        # 首先限制块的大小,不能大于填充区域
        chunk_length = min(old_length, left_pad)
        # 从右到左切片,停在或靠近边缘,相对于停止开始
        stop = left_pad - edge_offset
        start = stop + chunk_length
        left_slice = _slice_at_axis(slice(start, stop, -1), axis)
        left_chunk = padded[left_slice]

        if method == "odd":
            # 反转块并与边缘对齐,如果方法为 'odd'
            edge_slice = _slice_at_axis(slice(left_pad, left_pad + 1), axis)
            left_chunk = 2 * padded[edge_slice] - left_chunk

        # 将块插入填充区域
        start = left_pad - chunk_length
        stop = left_pad
        pad_area = _slice_at_axis(slice(start, stop), axis)
        padded[pad_area] = left_chunk
        # 调整指向下一次迭代的左边缘
        left_pad -= chunk_length
    if right_pad > 0:
        # 如果右侧需要填充
        # 首先限制填充区域的长度,不能大于需要填充的长度
        chunk_length = min(old_length, right_pad)
        # 从右向左切片,从边缘或其相邻处开始,向左切片的终点相对于开始的位置
        start = -right_pad + edge_offset - 2
        stop = start - chunk_length
        # 在指定的轴上创建切片对象,用于获取右侧的数据块
        right_slice = _slice_at_axis(slice(start, stop, -1), axis)
        right_chunk = padded[right_slice]

        if method == "odd":
            # 如果使用奇数方法(odd),对块进行取反并与边缘对齐
            edge_slice = _slice_at_axis(
                slice(-right_pad - 1, -right_pad), axis)
            right_chunk = 2 * padded[edge_slice] - right_chunk

        # 将右侧数据块插入到填充区域
        start = padded.shape[axis] - right_pad
        stop = start + chunk_length
        pad_area = _slice_at_axis(slice(start, stop), axis)
        padded[pad_area] = right_chunk
        # 调整右侧填充的指针位置,为下一次迭代做准备
        right_pad -= chunk_length

    # 返回更新后的左侧和右侧填充值
    return left_pad, right_pad
# 在输入数组 `x` 上操作,将其转换为形状符合 `ndim` 维度的元组对
def _as_pairs(x, ndim, as_index=False):
    # 如果输入 `x` 是整数,则将其视为在所有维度上相同的元组对
    if isinstance(x, int):
        return tuple((x,) * ndim)
    # 如果输入 `x` 是一个可迭代对象且 `as_index` 为真,则返回其前 `ndim` 个元素作为索引
    elif isinstance(x, collections.Iterable):
        return tuple(x[:ndim]) if as_index else tuple(x)
    # 其他情况,引发类型错误
    else:
        raise TypeError('Expected integer or iterable, got {}'.format(type(x)))
    # 如果 x 为 None,则返回一个形状为 (ndim, 2) 的嵌套迭代器,每个元素为 (None, None)
    if x is None:
        # Pass through None as a special case, otherwise np.round(x) fails
        # with an AttributeError
        return ((None, None),) * ndim

    # 将 x 转换为 numpy 数组,以便后续处理
    x = np.array(x)

    # 如果需要将 x 转换为索引形式(整数并确保非负),则进行相应处理
    if as_index:
        x = np.round(x).astype(np.intp, copy=False)

    # 如果 x 的维度小于 3,进行优化处理
    if x.ndim < 3:
        # 优化:对于 x 只有一个或两个元素的情况,可能采用更快的路径处理。
        # `np.broadcast_to` 也可以处理这些情况,但目前速度较慢

        # 如果 x 只有一个元素
        if x.size == 1:
            # 将 x 拉平,确保对 x.ndim == 0, 1, 2 的情况都适用
            x = x.ravel()
            # 如果 as_index 为 True 且 x 小于 0,则抛出 ValueError
            if as_index and x < 0:
                raise ValueError("index can't contain negative values")
            # 返回形状为 (ndim, 2) 的嵌套迭代器,每个元素为 (x[0], x[0])
            return ((x[0], x[0]),) * ndim

        # 如果 x 有两个元素且不是 (2, 1) 的形状
        if x.size == 2 and x.shape != (2, 1):
            # 将 x 拉平,确保对 x[0], x[1] 的操作适用
            x = x.ravel()
            # 如果 as_index 为 True 且 x[0] 或 x[1] 小于 0,则抛出 ValueError
            if as_index and (x[0] < 0 or x[1] < 0):
                raise ValueError("index can't contain negative values")
            # 返回形状为 (ndim, 2) 的嵌套迭代器,每个元素为 (x[0], x[1])
            return ((x[0], x[1]),) * ndim

    # 如果 as_index 为 True 且 x 中最小值小于 0,则抛出 ValueError
    if as_index and x.min() < 0:
        raise ValueError("index can't contain negative values")

    # 将 x 广播到形状为 (ndim, 2),并转换为列表形式返回
    # 使用 `tolist` 转换数组为列表似乎可以提高迭代和索引结果时的性能(见 `pad` 中的使用)
    return np.broadcast_to(x, (ndim, 2)).tolist()
# 定义一个内部函数 _pad_dispatcher,用于分发数组填充操作
def _pad_dispatcher(array, pad_width, mode=None, **kwargs):
    # 返回一个元组,其中包含输入的数组 array
    return (array,)


###############################################################################
# Public functions

# 使用装饰器 array_function_dispatch 将 _pad_dispatcher 函数与 numpy 模块关联起来
@array_function_dispatch(_pad_dispatcher, module='numpy')
# 定义公共函数 pad,用于对数组进行填充操作
def pad(array, pad_width, mode='constant', **kwargs):
    """
    Pad an array.

    Parameters
    ----------
    array : array_like of rank N
        The array to pad.
    pad_width : {sequence, array_like, int}
        Number of values padded to the edges of each axis.
        ``((before_1, after_1), ... (before_N, after_N))`` unique pad widths
        for each axis.
        ``(before, after)`` or ``((before, after),)`` yields same before
        and after pad for each axis.
        ``(pad,)`` or ``int`` is a shortcut for before = after = pad width
        for all axes.
    mode : str or function, optional
        One of the following string values or a user supplied function.

        'constant' (default)
            Pads with a constant value.
        'edge'
            Pads with the edge values of array.
        'linear_ramp'
            Pads with the linear ramp between end_value and the
            array edge value.
        'maximum'
            Pads with the maximum value of all or part of the
            vector along each axis.
        'mean'
            Pads with the mean value of all or part of the
            vector along each axis.
        'median'
            Pads with the median value of all or part of the
            vector along each axis.
        'minimum'
            Pads with the minimum value of all or part of the
            vector along each axis.
        'reflect'
            Pads with the reflection of the vector mirrored on
            the first and last values of the vector along each
            axis.
        'symmetric'
            Pads with the reflection of the vector mirrored
            along the edge of the array.
        'wrap'
            Pads with the wrap of the vector along the axis.
            The first values are used to pad the end and the
            end values are used to pad the beginning.
        'empty'
            Pads with undefined values.

            .. versionadded:: 1.17

        <function>
            Padding function, see Notes.
    stat_length : sequence or int, optional
        Used in 'maximum', 'mean', 'median', and 'minimum'.  Number of
        values at edge of each axis used to calculate the statistic value.

        ``((before_1, after_1), ... (before_N, after_N))`` unique statistic
        lengths for each axis.

        ``(before, after)`` or ``((before, after),)`` yields same before
        and after statistic lengths for each axis.

        ``(stat_length,)`` or ``int`` is a shortcut for
        ``before = after = statistic`` length for all axes.

        Default is ``None``, to use the entire axis.
    """
    constant_values : sequence or scalar, optional
        # 用于 'constant' 模式的参数。设置每个轴的填充值。
        # 格式为 ((before_1, after_1), ... (before_N, after_N)),每个轴有唯一的填充常量。
        # 如果是 (before, after) 或 ((before, after),),则每个轴使用相同的填充常量。
        # 如果是 (constant,) 或 constant,则所有轴的填充常量相同。
        默认值为 0.

    end_values : sequence or scalar, optional
        # 用于 'linear_ramp' 模式的参数。设置线性斜坡的结束值,形成填充数组的边缘。
        # 格式为 ((before_1, after_1), ... (before_N, after_N)),每个轴有唯一的结束值。
        # 如果是 (before, after) 或 ((before, after),),则每个轴使用相同的结束值。
        # 如果是 (constant,) 或 constant,则所有轴的结束值相同。
        默认值为 0.

    reflect_type : {'even', 'odd'}, optional
        # 用于 'reflect' 和 'symmetric' 模式的参数。
        # 'even' 表示默认的反射模式,围绕边缘值没有改变的反射。
        # 'odd' 表示扩展部分由边缘值的两倍减去反射值得到。
    # 导入 NumPy 库,用于科学计算和数组操作
    import numpy as np
    
    # 将输入参数 `array` 转换为 NumPy 数组
    array = np.asarray(array)
    
    # 将输入参数 `pad_width` 转换为 NumPy 数组
    pad_width = np.asarray(pad_width)
    
    # 检查 `pad_width` 数组元素的数据类型是否为整数,若不是则抛出类型错误异常
    if not pad_width.dtype.kind == 'i':
        raise TypeError('`pad_width` must be of integral type.')
    
    # 将 `pad_width` 数组广播为形状 (array.ndim, 2)
    pad_width = _as_pairs(pad_width, array.ndim, as_index=True)
    if callable(mode):
        # 如果 mode 是可调用对象,则使用用户提供的函数进行处理,结合 np.apply_along_axis 的旧行为
        function = mode
        # 创建一个新的零填充数组
        padded, _ = _pad_simple(array, pad_width, fill_value=0)
        # 然后沿着每个轴应用函数

        for axis in range(padded.ndim):
            # 使用 ndindex 迭代,类似于 apply_along_axis,但假设函数在填充数组上原地操作。

            # 视图,将迭代轴放在最后
            view = np.moveaxis(padded, axis, -1)

            # 计算迭代轴的索引,并添加一个尾随省略号,以防止 0 维数组衰减为标量 (gh-8642)
            inds = ndindex(view.shape[:-1])
            inds = (ind + (Ellipsis,) for ind in inds)
            for ind in inds:
                function(view[ind], pad_width[axis], axis, kwargs)

        return padded

    # 确保对于当前模式没有传递不支持的关键字参数
    allowed_kwargs = {
        'empty': [], 'edge': [], 'wrap': [],
        'constant': ['constant_values'],
        'linear_ramp': ['end_values'],
        'maximum': ['stat_length'],
        'mean': ['stat_length'],
        'median': ['stat_length'],
        'minimum': ['stat_length'],
        'reflect': ['reflect_type'],
        'symmetric': ['reflect_type'],
    }
    try:
        unsupported_kwargs = set(kwargs) - set(allowed_kwargs[mode])
    except KeyError:
        # 如果出现不支持的模式,抛出 ValueError 异常
        raise ValueError("mode '{}' is not supported".format(mode)) from None
    if unsupported_kwargs:
        # 如果有不支持的关键字参数,抛出 ValueError 异常
        raise ValueError("unsupported keyword arguments for mode '{}': {}"
                         .format(mode, unsupported_kwargs))

    # 统计函数字典,用于 mode 是 "maximum", "minimum", "mean", "median" 时选择对应的 numpy 统计函数
    stat_functions = {"maximum": np.amax, "minimum": np.amin,
                      "mean": np.mean, "median": np.median}

    # 创建具有最终形状和原始值的填充数组(填充区域未定义)
    padded, original_area_slice = _pad_simple(array, pad_width)
    # 准备在所有维度上进行迭代(使用 zip 比使用 enumerate 更可读)
    axes = range(padded.ndim)

    if mode == "constant":
        # 如果 mode 是 "constant"
        values = kwargs.get("constant_values", 0)
        values = _as_pairs(values, padded.ndim)
        for axis, width_pair, value_pair in zip(axes, pad_width, values):
            # 获取感兴趣区域(ROI),即填充数组上的视图
            roi = _view_roi(padded, original_area_slice, axis)
            # 设置填充区域的值
            _set_pad_area(roi, axis, width_pair, value_pair)

    elif mode == "empty":
        # 如果 mode 是 "empty",则什么都不做,因为 _pad_simple 已经返回了正确的结果
        pass
    elif array.size == 0:
        # 如果数组为空
        # 只有 "constant" 和 "empty" 模式可以扩展空轴,其它模式都要求数组非空
        # -> 确保每个空轴只能 "用0填充"
        for axis, width_pair in zip(axes, pad_width):
            if array.shape[axis] == 0 and any(width_pair):
                raise ValueError(
                    "can't extend empty axis {} using modes other than "
                    "'constant' or 'empty'".format(axis)
                )
        # 通过检查,不需要进行更多操作,因为 _pad_simple 已经返回了正确的结果

    elif mode == "edge":
        # 如果模式为 "edge"
        for axis, width_pair in zip(axes, pad_width):
            # 获取填充后区域的视图
            roi = _view_roi(padded, original_area_slice, axis)
            # 获取边缘值对
            edge_pair = _get_edges(roi, axis, width_pair)
            # 设置填充区域
            _set_pad_area(roi, axis, width_pair, edge_pair)

    elif mode == "linear_ramp":
        # 如果模式为 "linear_ramp"
        # 获取端点值,默认为0
        end_values = kwargs.get("end_values", 0)
        end_values = _as_pairs(end_values, padded.ndim)
        for axis, width_pair, value_pair in zip(axes, pad_width, end_values):
            # 获取填充后区域的视图
            roi = _view_roi(padded, original_area_slice, axis)
            # 获取线性斜坡
            ramp_pair = _get_linear_ramps(roi, axis, width_pair, value_pair)
            # 设置填充区域
            _set_pad_area(roi, axis, width_pair, ramp_pair)

    elif mode in stat_functions:
        # 如果模式在统计函数中
        # 获取统计函数
        func = stat_functions[mode]
        # 获取统计长度,默认为None
        length = kwargs.get("stat_length", None)
        length = _as_pairs(length, padded.ndim, as_index=True)
        for axis, width_pair, length_pair in zip(axes, pad_width, length):
            # 获取填充后区域的视图
            roi = _view_roi(padded, original_area_slice, axis)
            # 获取统计值对
            stat_pair = _get_stats(roi, axis, width_pair, length_pair, func)
            # 设置填充区域
            _set_pad_area(roi, axis, width_pair, stat_pair)

    elif mode in {"reflect", "symmetric"}:
        # 如果模式是 "reflect" 或 "symmetric"
        # 获取反射类型,默认为 "even"
        method = kwargs.get("reflect_type", "even")
        # 是否包含边缘
        include_edge = True if mode == "symmetric" else False
        for axis, (left_index, right_index) in zip(axes, pad_width):
            if array.shape[axis] == 1 and (left_index > 0 or right_index > 0):
                # 如果数组在该轴上的形状为1,并且左右索引大于0
                # 对于 'reflect' 扩展单例维度是旧行为,它实际上应该引发错误。
                # 获取边缘值对
                edge_pair = _get_edges(padded, axis, (left_index, right_index))
                # 设置填充区域
                _set_pad_area(padded, axis, (left_index, right_index), edge_pair)
                continue

            # 获取填充后区域的视图
            roi = _view_roi(padded, original_area_slice, axis)
            while left_index > 0 or right_index > 0:
                # 反复填充,直到该维度的填充区域被反射值填满。
                # 如果填充区域大于当前维度中原始值的长度,则此过程是必要的。
                left_index, right_index = _set_reflect_both(
                    roi, axis, (left_index, right_index),
                    method, array.shape[axis], include_edge
                )
    # 如果模式为 "wrap",则执行以下操作
    elif mode == "wrap":
        # 遍历每个轴以及其对应的填充宽度
        for axis, (left_index, right_index) in zip(axes, pad_width):
            # 获取当前轴上原始区域的视图
            roi = _view_roi(padded, original_area_slice, axis)
            # 计算当前轴上原始值的周期长度
            original_period = padded.shape[axis] - right_index - left_index
            # 当左右两侧的填充数量大于0时,进行循环填充
            while left_index > 0 or right_index > 0:
                # 调用函数 _set_wrap_both,设置填充方式为 "wrap"
                left_index, right_index = _set_wrap_both(
                    roi, axis, (left_index, right_index), original_period)

    # 返回填充后的数组
    return padded