NumPy-源码解析-十五-

39 阅读1小时+

NumPy 源码解析(十五)

.\numpy\numpy\lib\tests\test_loadtxt.py

"""
`np.loadtxt`的特定测试,用于在将loadtxt移至C代码后进行的补充测试。
这些测试是`test_io.py`中已有测试的补充。
"""

import sys  # 导入sys模块,用于系统相关操作
import os   # 导入os模块,用于操作系统相关功能
import pytest   # 导入pytest测试框架
from tempfile import NamedTemporaryFile, mkstemp   # 导入临时文件相关函数
from io import StringIO   # 导入StringIO用于内存中文件操作

import numpy as np   # 导入NumPy库
from numpy.ma.testutils import assert_equal   # 导入NumPy的测试工具函数
from numpy.testing import assert_array_equal, HAS_REFCOUNT, IS_PYPY   # 导入NumPy的测试工具函数和相关常量


def test_scientific_notation():
    """测试科学计数法中 'e' 和 'E' 的解析是否正确。"""
    data = StringIO(
        (
            "1.0e-1,2.0E1,3.0\n"
            "4.0e-2,5.0E-1,6.0\n"
            "7.0e-3,8.0E1,9.0\n"
            "0.0e-4,1.0E-1,2.0"
        )
    )
    expected = np.array(
        [[0.1, 20., 3.0], [0.04, 0.5, 6], [0.007, 80., 9], [0, 0.1, 2]]
    )
    assert_array_equal(np.loadtxt(data, delimiter=","), expected)


@pytest.mark.parametrize("comment", ["..", "//", "@-", "this is a comment:"])
def test_comment_multiple_chars(comment):
    """测试多字符注释在加载数据时的处理。"""
    content = "# IGNORE\n1.5, 2.5# ABC\n3.0,4.0# XXX\n5.5,6.0\n"
    txt = StringIO(content.replace("#", comment))
    a = np.loadtxt(txt, delimiter=",", comments=comment)
    assert_equal(a, [[1.5, 2.5], [3.0, 4.0], [5.5, 6.0]])


@pytest.fixture
def mixed_types_structured():
    """
    提供具有结构化dtype的异构输入数据和相关结构化数组的fixture。
    """
    data = StringIO(
        (
            "1000;2.4;alpha;-34\n"
            "2000;3.1;beta;29\n"
            "3500;9.9;gamma;120\n"
            "4090;8.1;delta;0\n"
            "5001;4.4;epsilon;-99\n"
            "6543;7.8;omega;-1\n"
        )
    )
    dtype = np.dtype(
        [('f0', np.uint16), ('f1', np.float64), ('f2', 'S7'), ('f3', np.int8)]
    )
    expected = np.array(
        [
            (1000, 2.4, "alpha", -34),
            (2000, 3.1, "beta", 29),
            (3500, 9.9, "gamma", 120),
            (4090, 8.1, "delta", 0),
            (5001, 4.4, "epsilon", -99),
            (6543, 7.8, "omega", -1)
        ],
        dtype=dtype
    )
    return data, dtype, expected


@pytest.mark.parametrize('skiprows', [0, 1, 2, 3])
def test_structured_dtype_and_skiprows_no_empty_lines(
        skiprows, mixed_types_structured):
    """测试结构化dtype和跳过行数(无空行)的情况。"""
    data, dtype, expected = mixed_types_structured
    a = np.loadtxt(data, dtype=dtype, delimiter=";", skiprows=skiprows)
    assert_array_equal(a, expected[skiprows:])


def test_unpack_structured(mixed_types_structured):
    """测试结构化dtype在解包时的处理。"""
    data, dtype, expected = mixed_types_structured

    a, b, c, d = np.loadtxt(data, dtype=dtype, delimiter=";", unpack=True)
    assert_array_equal(a, expected["f0"])
    assert_array_equal(b, expected["f1"])
    assert_array_equal(c, expected["f2"])
    assert_array_equal(d, expected["f3"])


def test_structured_dtype_with_shape():
    """测试带形状的结构化dtype的情况。"""
    dtype = np.dtype([("a", "u1", 2), ("b", "u1", 2)])
    data = StringIO("0,1,2,3\n6,7,8,9\n")
    expected = np.array([((0, 1), (2, 3)), ((6, 7), (8, 9))], dtype=dtype)
    # 使用 numpy 库中的 loadtxt 函数读取数据文件,并使用指定的逗号分隔符和数据类型进行加载
    assert_array_equal(np.loadtxt(data, delimiter=",", dtype=dtype), expected)
def test_structured_dtype_with_multi_shape():
    # 定义一个结构化的 NumPy 数据类型,包含字段 'a',每个元素是一个 2x2 的无符号整数数组
    dtype = np.dtype([("a", "u1", (2, 2))])
    # 创建一个包含数据的字符串流对象
    data = StringIO("0 1 2 3\n")
    # 期望的 NumPy 数组,包含一个元素,该元素是一个 2x2 的数组,元素值为 (0, 1, 2, 3)
    expected = np.array([(((0, 1), (2, 3)),)], dtype=dtype)
    # 断言使用 np.loadtxt 函数加载数据,并与期望的数组进行比较
    assert_array_equal(np.loadtxt(data, dtype=dtype), expected)


def test_nested_structured_subarray():
    # 测试来自 GitHub issue #16678
    # 定义一个结构化数据类型 'point',包含字段 'x' 和 'y',每个字段为浮点数
    point = np.dtype([('x', float), ('y', float)])
    # 定义一个结构化数据类型 'dt',包含字段 'code' 和 'points','points' 是一个包含两个 'point' 结构的数组
    dt = np.dtype([('code', int), ('points', point, (2,))])
    # 创建一个包含数据的字符串流对象
    data = StringIO("100,1,2,3,4\n200,5,6,7,8\n")
    # 期望的 NumPy 数组,包含两个元素,每个元素包含一个整数和两个点的数组
    expected = np.array(
        [
            (100, [(1., 2.), (3., 4.)]),
            (200, [(5., 6.), (7., 8.)]),
        ],
        dtype=dt
    )
    # 断言使用 np.loadtxt 函数加载数据,并与期望的数组进行比较,指定分隔符为逗号
    assert_array_equal(np.loadtxt(data, dtype=dt, delimiter=","), expected)


def test_structured_dtype_offsets():
    # 一个对齐的结构化数据类型会有额外的填充
    # 定义一个结构化数据类型 'dt',包含多个整数字段,对齐方式为 True
    dt = np.dtype("i1, i4, i1, i4, i1, i4", align=True)
    # 创建一个包含数据的字符串流对象
    data = StringIO("1,2,3,4,5,6\n7,8,9,10,11,12\n")
    # 期望的 NumPy 数组,包含两个元素,每个元素是一个包含整数的元组
    expected = np.array([(1, 2, 3, 4, 5, 6), (7, 8, 9, 10, 11, 12)], dtype=dt)
    # 断言使用 np.loadtxt 函数加载数据,并与期望的数组进行比较,指定分隔符为逗号
    assert_array_equal(np.loadtxt(data, delimiter=",", dtype=dt), expected)


@pytest.mark.parametrize("param", ("skiprows", "max_rows"))
def test_exception_negative_row_limits(param):
    """skiprows 和 max_rows 应当对负参数抛出异常。"""
    # 使用 pytest.raises 检查 np.loadtxt 函数在读取文件时,对于负的参数值抛出 ValueError 异常
    with pytest.raises(ValueError, match="argument must be nonnegative"):
        np.loadtxt("foo.bar", **{param: -3})


@pytest.mark.parametrize("param", ("skiprows", "max_rows"))
def test_exception_noninteger_row_limits(param):
    # 测试 np.loadtxt 函数对于非整数参数值抛出 TypeError 异常
    with pytest.raises(TypeError, match="argument must be an integer"):
        np.loadtxt("foo.bar", **{param: 1.0})


@pytest.mark.parametrize(
    "data, shape",
    [
        ("1 2 3 4 5\n", (1, 5)),  # 单行数据
        ("1\n2\n3\n4\n5\n", (5, 1)),  # 单列数据
    ]
)
def test_ndmin_single_row_or_col(data, shape):
    # 创建一个包含数据的字符串流对象
    arr = np.array([1, 2, 3, 4, 5])
    # 将一维数组 arr 重塑成 shape 指定的形状的二维数组 arr2d
    arr2d = arr.reshape(shape)

    # 断言使用 np.loadtxt 函数加载数据,并与一维数组 arr 进行比较
    assert_array_equal(np.loadtxt(StringIO(data), dtype=int), arr)
    # 断言使用 np.loadtxt 函数加载数据,并与一维数组 arr 进行比较,设置 ndmin=0
    assert_array_equal(np.loadtxt(StringIO(data), dtype=int, ndmin=0), arr)
    # 断言使用 np.loadtxt 函数加载数据,并与一维数组 arr 进行比较,设置 ndmin=1
    assert_array_equal(np.loadtxt(StringIO(data), dtype=int, ndmin=1), arr)
    # 断言使用 np.loadtxt 函数加载数据,并与二维数组 arr2d 进行比较,设置 ndmin=2
    assert_array_equal(np.loadtxt(StringIO(data), dtype=int, ndmin=2), arr2d)


@pytest.mark.parametrize("badval", [-1, 3, None, "plate of shrimp"])
def test_bad_ndmin(badval):
    # 测试 np.loadtxt 函数对于非法的 ndmin 参数值抛出 ValueError 异常
    with pytest.raises(ValueError, match="Illegal value of ndmin keyword"):
        np.loadtxt("foo.bar", ndmin=badval)


@pytest.mark.parametrize(
    "ws",
    (
            " ",  # 空格
            "\t",  # 制表符
            "\u2003",  # 空白字符
            "\u00A0",  # 不间断空格
            "\u3000",  # 表意空格
    )
)
def test_blank_lines_spaces_delimit(ws):
    txt = StringIO(
        f"1 2{ws}30\n\n{ws}\n"
        f"4 5 60{ws}\n  {ws}  \n"
        f"7 8 {ws} 90\n  # comment\n"
        f"3 2 1"
    )
    # 注意:`  # comment` 应当成功。除非 delimiter=None,应当使用任意空白字符(也许
    # 应当更接近 Python 实现
    # 创建一个预期的 NumPy 数组,包含指定的整数值
    expected = np.array([[1, 2, 30], [4, 5, 60], [7, 8, 90], [3, 2, 1]])
    # 使用 NumPy 的 assert_equal 函数比较两个数组是否相等
    assert_equal(
        # 使用 np.loadtxt 从文本文件中加载数据,指定数据类型为整数,分隔符为任意空白,忽略以 '#' 开始的注释
        np.loadtxt(txt, dtype=int, delimiter=None, comments="#"),
        # 将加载的数据与预期的数组进行比较
        expected
    )
# 定义一个测试函数,用于测试带有空行和注释的文本的解析
def test_blank_lines_normal_delimiter():
    # 创建一个包含特定内容的内存文本流对象
    txt = StringIO('1,2,30\n\n4,5,60\n\n7,8,90\n# comment\n3,2,1')
    # 预期的结果是一个包含特定数值的二维 NumPy 数组
    expected = np.array([[1, 2, 30], [4, 5, 60], [7, 8, 90], [3, 2, 1]])
    # 断言加载文本内容后的结果与预期结果相等
    assert_equal(
        np.loadtxt(txt, dtype=int, delimiter=',', comments="#"), expected
    )


# 使用参数化测试来测试不同数据类型的加载行数限制
@pytest.mark.parametrize("dtype", (float, object))
def test_maxrows_no_blank_lines(dtype):
    # 创建一个包含特定内容的内存文本流对象
    txt = StringIO("1.5,2.5\n3.0,4.0\n5.5,6.0")
    # 加载并限制最大行数为 2,数据类型由参数 dtype 决定
    res = np.loadtxt(txt, dtype=dtype, delimiter=",", max_rows=2)
    # 断言加载结果的数据类型与预期参数 dtype 相等
    assert_equal(res.dtype, dtype)
    # 断言加载的结果与预期的 NumPy 数组相等
    assert_equal(res, np.array([["1.5", "2.5"], ["3.0", "4.0"]], dtype=dtype))


# 使用参数化测试来测试异常情况下的错误消息处理
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype", (np.dtype("f8"), np.dtype("i2")))
def test_exception_message_bad_values(dtype):
    # 创建一个包含特定内容的内存文本流对象
    txt = StringIO("1,2\n3,XXX\n5,6")
    # 准备预期的错误消息
    msg = f"could not convert string 'XXX' to {dtype} at row 1, column 2"
    # 使用 pytest 断言捕获指定的 ValueError 异常,并匹配预期的错误消息
    with pytest.raises(ValueError, match=msg):
        np.loadtxt(txt, dtype=dtype, delimiter=",")


# 测试使用自定义转换器处理数据的加载
def test_converters_negative_indices():
    # 创建一个包含特定内容的内存文本流对象
    txt = StringIO('1.5,2.5\n3.0,XXX\n5.5,6.0')
    # 定义一个转换器,根据特定规则转换数据,例如将 'XXX' 转换为 NaN
    conv = {-1: lambda s: np.nan if s == 'XXX' else float(s)}
    # 预期的结果是一个包含特定数值的二维 NumPy 数组
    expected = np.array([[1.5, 2.5], [3.0, np.nan], [5.5, 6.0]])
    # 使用转换器加载数据,并断言加载结果与预期结果相等
    res = np.loadtxt(txt, dtype=np.float64, delimiter=",", converters=conv)
    assert_equal(res, expected)


# 测试在使用 usecols 限定列数的情况下,加载数据并处理负索引的转换
def test_converters_negative_indices_with_usecols():
    # 创建一个包含特定内容的内存文本流对象
    txt = StringIO('1.5,2.5,3.5\n3.0,4.0,XXX\n5.5,6.0,7.5\n')
    # 定义一个转换器,根据特定规则转换数据,例如将 'XXX' 转换为 NaN
    conv = {-1: lambda s: np.nan if s == 'XXX' else float(s)}
    # 预期的结果是一个包含特定数值的二维 NumPy 数组
    expected = np.array([[1.5, 3.5], [3.0, np.nan], [5.5, 7.5]])
    # 使用 usecols 参数限定要加载的列,并使用转换器处理数据加载
    res = np.loadtxt(
        txt,
        dtype=np.float64,
        delimiter=",",
        converters=conv,
        usecols=[0, -1],
    )
    # 断言加载结果与预期结果相等
    assert_equal(res, expected)

    # 第二个测试用例,用于测试变量行数的情况
    res = np.loadtxt(StringIO('''0,1,2\n0,1,2,3,4'''), delimiter=",",
                     usecols=[0, -1], converters={-1: (lambda x: -1)})
    # 断言加载结果与预期结果相等
    assert_array_equal(res, [[0, -1], [0, -1]])


# 测试在不同行数列数不一致情况下是否能正确抛出 ValueError 异常
def test_ragged_error():
    # 准备包含不同行数的数据列表
    rows = ["1,2,3", "1,2,3", "4,3,2,1"]
    # 使用 pytest 断言捕获指定的 ValueError 异常,并匹配预期的错误消息
    with pytest.raises(ValueError,
                       match="the number of columns changed from 3 to 4 at row 3"):
        np.loadtxt(rows, delimiter=",")


# 测试在不同行数列数不一致情况下是否能正确处理 usecols 参数
def test_ragged_usecols():
    # 测试即使在列数不一致的情况下,usecols 和负索引也能正确处理
    txt = StringIO("0,0,XXX\n0,XXX,0,XXX\n0,XXX,XXX,0,XXX\n")
    # 预期的结果是一个包含特定数值的二维 NumPy 数组
    expected = np.array([[0, 0], [0, 0], [0, 0]])
    # 使用 usecols 参数限定要加载的列,并使用负索引转换器处理数据加载
    res = np.loadtxt(txt, dtype=float, delimiter=",", usecols=[0, -2])
    # 断言加载结果与预期结果相等
    assert_equal(res, expected)

    # 准备另一个测试用例,包含不同行数和错误的 usecols 参数
    txt = StringIO("0,0,XXX\n0\n0,XXX,XXX,0,XXX\n")
    # 使用 pytest 断言捕获指定的 ValueError 异常,并匹配预期的错误消息
    with pytest.raises(ValueError,
                       match="invalid column index -2 at row 2 with 1 columns"):
        # 加载数据时,将会抛出错误,因为第二行不存在负索引为 -2 的列
        np.loadtxt(txt, dtype=float, delimiter=",", usecols=[0, -2])


# 测试空 usecols 参数的情况
def test_empty_usecols():
    txt = StringIO("0,0,XXX\n0,XXX,0,XXX\n0,XXX,XXX,0,XXX\n")
    # 使用 NumPy 加载文本文件 `txt`,返回一个 NumPy 数组 `res`
    res = np.loadtxt(txt, dtype=np.dtype([]), delimiter=",", usecols=[])
    # 断言:确保数组 `res` 的形状为 (3,)
    assert res.shape == (3,)
    # 断言:确保数组 `res` 的数据类型为一个空的结构化 NumPy 数据类型
    assert res.dtype == np.dtype([])
@pytest.mark.parametrize("c1", ["a", "の", "🫕"])
@pytest.mark.parametrize("c2", ["a", "の", "🫕"])
def test_large_unicode_characters(c1, c2):
    # 创建包含大量 Unicode 字符的测试用例,c1 和 c2 覆盖 ASCII、16 位和 32 位字符范围。
    txt = StringIO(f"a,{c1},c,1.0\ne,{c2},2.0,g")
    # 将文本数据封装为 StringIO 对象
    res = np.loadtxt(txt, dtype=np.dtype('U12'), delimiter=",")
    # 使用 NumPy 加载文本数据到数组 res 中,使用 Unicode 类型,每个元素最多12个字符,使用逗号分隔
    expected = np.array(
        [f"a,{c1},c,1.0".split(","), f"e,{c2},2.0,g".split(",")],
        dtype=np.dtype('U12')
    )
    # 创建预期结果数组,每个元素也是最多12个字符的 Unicode 类型
    assert_equal(res, expected)
    # 断言实际结果与预期结果相等


def test_unicode_with_converter():
    # 测试带有转换器的 Unicode 处理
    txt = StringIO("cat,dog\nαβγ,δεζ\nabc,def\n")
    # 将文本数据封装为 StringIO 对象
    conv = {0: lambda s: s.upper()}
    # 定义转换器,将第一列字符转换为大写
    res = np.loadtxt(
        txt,
        dtype=np.dtype("U12"),
        converters=conv,
        delimiter=",",
        encoding=None
    )
    # 使用 NumPy 加载文本数据到数组 res 中,使用 Unicode 类型,应用转换器,逗号分隔
    expected = np.array([['CAT', 'dog'], ['ΑΒΓ', 'δεζ'], ['ABC', 'def']])
    # 创建预期结果数组,每个元素最多12个字符的 Unicode 类型
    assert_equal(res, expected)
    # 断言实际结果与预期结果相等


def test_converter_with_structured_dtype():
    # 测试结构化数据类型和转换器的使用
    txt = StringIO('1.5,2.5,Abc\n3.0,4.0,dEf\n5.5,6.0,ghI\n')
    # 将文本数据封装为 StringIO 对象
    dt = np.dtype([('m', np.int32), ('r', np.float32), ('code', 'U8')])
    # 定义结构化数据类型,包括整数、浮点数和 Unicode 字符串
    conv = {0: lambda s: int(10*float(s)), -1: lambda s: s.upper()}
    # 定义转换器,将第一列乘以10转换为整数,将最后一列转换为大写
    res = np.loadtxt(txt, dtype=dt, delimiter=",", converters=conv)
    # 使用 NumPy 加载文本数据到结构化数组 res 中,应用转换器,逗号分隔
    expected = np.array(
        [(15, 2.5, 'ABC'), (30, 4.0, 'DEF'), (55, 6.0, 'GHI')], dtype=dt
    )
    # 创建预期结果结构化数组
    assert_equal(res, expected)
    # 断言实际结果与预期结果相等


def test_converter_with_unicode_dtype():
    """
    当使用 'bytes' 编码时,标记 tokens 之前编码。这意味着转换器的输出可能是字节而不是 `read_rows` 预期的 Unicode。
    此测试检查以上场景的输出是否在由 `read_rows` 解析之前被正确解码。
    """
    txt = StringIO('abc,def\nrst,xyz')
    # 将文本数据封装为 StringIO 对象
    conv = bytes.upper
    # 定义转换器,将输入的字节转换为大写
    res = np.loadtxt(
            txt, dtype=np.dtype("U3"), converters=conv, delimiter=",",
            encoding="bytes")
    # 使用 NumPy 加载文本数据到数组 res 中,使用最多3个字符的 Unicode 类型,应用转换器,逗号分隔,使用字节编码
    expected = np.array([['ABC', 'DEF'], ['RST', 'XYZ']])
    # 创建预期结果数组
    assert_equal(res, expected)
    # 断言实际结果与预期结果相等


def test_read_huge_row():
    # 测试读取超大行数据
    row = "1.5, 2.5," * 50000
    # 创建一个超大的行字符串
    row = row[:-1] + "\n"
    # 将字符串结尾替换为换行符
    txt = StringIO(row * 2)
    # 将文本数据封装为 StringIO 对象
    res = np.loadtxt(txt, delimiter=",", dtype=float)
    # 使用 NumPy 加载文本数据到数组 res 中,逗号分隔,数据类型为浮点数
    assert_equal(res, np.tile([1.5, 2.5], (2, 50000)))
    # 断言实际结果与预期结果相等


@pytest.mark.parametrize("dtype", "edfgFDG")
def test_huge_float(dtype):
    # 测试处理大浮点数的情况,覆盖一个不经常发生的非优化路径
    field = "0" * 1000 + ".123456789"
    # 创建一个大数值字段
    dtype = np.dtype(dtype)
    # 定义数据类型
    value = np.loadtxt([field], dtype=dtype)[()]
    # 使用 NumPy 加载文本数据到数组 value 中,使用指定的数据类型
    assert value == dtype.type("0.123456789")
    # 断言实际结果与预期结果相等


@pytest.mark.parametrize(
    ("given_dtype", "expected_dtype"),
    [
        ("S", np.dtype("S5")),
        ("U", np.dtype("U5")),
    ],
)
def test_string_no_length_given(given_dtype, expected_dtype):
    """
    给定的数据类型只有 'S' 或 'U' 而没有长度。在这些情况下,结果的长度由文件中找到的最长字符串决定。
    """
    txt = StringIO("AAA,5-1\nBBBBB,0-3\nC,4-9\n")
    # 将文本数据封装为 StringIO 对象
    res = np.loadtxt(txt, dtype=given_dtype, delimiter=",")
    # 使用 NumPy 加载文本数据到数组 res 中,使用给定的数据类型,逗号分隔
    # 创建一个预期的 NumPy 数组,包含指定的数据和数据类型
    expected = np.array(
        [['AAA', '5-1'], ['BBBBB', '0-3'], ['C', '4-9']], dtype=expected_dtype
    )
    # 使用 assert_equal 函数比较两个对象 res 和 expected 是否相等
    assert_equal(res, expected)
    # 使用 assert_equal 函数比较对象 res 的数据类型是否与预期的数据类型 expected_dtype 相等
    assert_equal(res.dtype, expected_dtype)
# 测试浮点数转换的准确性,验证转换为 float64 是否与 Python 内置的 float 函数一致。
def test_float_conversion():
    """
    Some tests that the conversion to float64 works as accurately as the
    Python built-in `float` function. In a naive version of the float parser,
    these strings resulted in values that were off by an ULP or two.
    """
    # 定义待转换的字符串列表
    strings = [
        '0.9999999999999999',
        '9876543210.123456',
        '5.43215432154321e+300',
        '0.901',
        '0.333',
    ]
    # 将字符串列表写入内存中的文本流
    txt = StringIO('\n'.join(strings))
    # 使用 numpy 的 loadtxt 函数加载数据
    res = np.loadtxt(txt)
    # 构建预期结果的 numpy 数组,通过 float 函数转换每个字符串为 float 类型
    expected = np.array([float(s) for s in strings])
    # 使用 assert_equal 断言 res 和 expected 数组相等
    assert_equal(res, expected)


# 测试布尔值转换
def test_bool():
    # 通过整数测试布尔值的简单情况
    txt = StringIO("1, 0\n10, -1")
    # 使用 numpy 的 loadtxt 函数加载数据,指定数据类型为 bool,分隔符为逗号
    res = np.loadtxt(txt, dtype=bool, delimiter=",")
    # 断言结果数组的数据类型为 bool
    assert res.dtype == bool
    # 断言数组内容与预期数组相等
    assert_array_equal(res, [[True, False], [True, True]])
    # 确保在字节级别上只使用 1 和 0
    assert_array_equal(res.view(np.uint8), [[1, 0], [1, 1]])


# 测试整数符号的处理
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype", np.typecodes["AllInteger"])
@pytest.mark.filterwarnings("error:.*integer via a float.*:DeprecationWarning")
def test_integer_signs(dtype):
    # 将 dtype 转换为 numpy 的数据类型
    dtype = np.dtype(dtype)
    # 断言加载包含 "+2" 的数据返回值为 2
    assert np.loadtxt(["+2"], dtype=dtype) == 2
    # 如果数据类型为无符号整数,断言加载包含 "-1\n" 的数据会引发 ValueError 异常
    if dtype.kind == "u":
        with pytest.raises(ValueError):
            np.loadtxt(["-1\n"], dtype=dtype)
    else:
        # 断言加载包含 "-2\n" 的数据返回值为 -2
        assert np.loadtxt(["-2\n"], dtype=dtype) == -2

    # 对于不合法的符号组合,如 "++", "+-", "--", "-+",断言加载时会引发 ValueError 异常
    for sign in ["++", "+-", "--", "-+"]:
        with pytest.raises(ValueError):
            np.loadtxt([f"{sign}2\n"], dtype=dtype)


# 测试隐式将浮点数转换为整数时的错误处理
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype", np.typecodes["AllInteger"])
@pytest.mark.filterwarnings("error:.*integer via a float.*:DeprecationWarning")
def test_implicit_cast_float_to_int_fails(dtype):
    # 定义包含浮点数和整数的文本流
    txt = StringIO("1.0, 2.1, 3.7\n4, 5, 6")
    # 断言加载时会引发 ValueError 异常
    with pytest.raises(ValueError):
        np.loadtxt(txt, dtype=dtype, delimiter=",")


# 测试复数的解析
@pytest.mark.parametrize("dtype", (np.complex64, np.complex128))
@pytest.mark.parametrize("with_parens", (False, True))
def test_complex_parsing(dtype, with_parens):
    # 定义包含复数字符串的文本流
    s = "(1.0-2.5j),3.75,(7+-5.0j)\n(4),(-19e2j),(0)"
    if not with_parens:
        s = s.replace("(", "").replace(")", "")

    # 使用 numpy 的 loadtxt 函数加载数据,指定数据类型为复数类型,分隔符为逗号
    res = np.loadtxt(StringIO(s), dtype=dtype, delimiter=",")
    # 构建预期结果的 numpy 数组
    expected = np.array(
        [[1.0-2.5j, 3.75, 7-5j], [4.0, -1900j, 0]], dtype=dtype
    )
    # 使用 assert_equal 断言 res 和 expected 数组相等
    assert_equal(res, expected)


# 测试从生成器中读取数据
def test_read_from_generator():
    # 定义生成器函数
    def gen():
        for i in range(4):
            yield f"{i},{2*i},{i**2}"

    # 使用 numpy 的 loadtxt 函数加载生成器生成的数据,指定数据类型为整数,分隔符为逗号
    res = np.loadtxt(gen(), dtype=int, delimiter=",")
    # 构建预期结果的 numpy 数组
    expected = np.array([[0, 0, 0], [1, 2, 1], [2, 4, 4], [3, 6, 9]])
    # 使用 assert_equal 断言 res 和 expected 数组相等
    assert_equal(res, expected)


# 测试从生成器中读取多种类型的数据
def test_read_from_generator_multitype():
    # 定义生成器函数
    def gen():
        for i in range(3):
            yield f"{i} {i / 4}"

    # 使用 numpy 的 loadtxt 函数加载生成器生成的数据,指定数据类型为 "i, d",分隔符为空格
    res = np.loadtxt(gen(), dtype="i, d", delimiter=" ")
    # 定义预期的 NumPy 数组,包含两列,第一列为整数类型,第二列为双精度浮点数类型
    expected = np.array([(0, 0.0), (1, 0.25), (2, 0.5)], dtype="i, d")
    # 使用 assert_equal 函数比较 res 和 expected,确保它们相等
    assert_equal(res, expected)
def test_read_from_bad_generator():
    # 定义一个生成器函数 `gen()`,生成器会依次产生字符串、字节串和整数
    def gen():
        yield from ["1,2", b"3, 5", 12738]

    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 TypeError 异常,并验证异常消息
    with pytest.raises(
            TypeError, match=r"non-string returned while reading data"):
        np.loadtxt(gen(), dtype="i, i", delimiter=",")


@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_object_cleanup_on_read_error():
    # 创建一个对象 sentinel 作为测试目的
    sentinel = object()
    # 初始化一个计数器 already_read,记录已经读取的次数
    already_read = 0

    # 定义一个转换函数 conv(x),用于处理每一行的数据并返回 sentinel
    def conv(x):
        nonlocal already_read
        # 如果 already_read 大于 4999,抛出 ValueError 异常
        if already_read > 4999:
            raise ValueError("failed half-way through!")
        already_read += 1
        return sentinel

    # 创建一个包含大量数据的 StringIO 对象 txt
    txt = StringIO("x\n" * 10000)

    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 ValueError 异常,并验证异常消息
    with pytest.raises(ValueError, match="at row 5000, column 1"):
        np.loadtxt(txt, dtype=object, converters={0: conv})

    # 检查 sentinel 的引用计数是否为 2
    assert sys.getrefcount(sentinel) == 2


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
def test_character_not_bytes_compatible():
    """Test exception when a character cannot be encoded as 'S'."""
    # 创建一个包含特殊字符 '–'(Unicode码点 \u2013)的 StringIO 对象 data
    data = StringIO("–")
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 ValueError 异常
    with pytest.raises(ValueError):
        np.loadtxt(data, dtype="S5")


@pytest.mark.parametrize("conv", (0, [float], ""))
def test_invalid_converter(conv):
    # 定义期望的错误消息
    msg = (
        "converters must be a dictionary mapping columns to converter "
        "functions or a single callable."
    )
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 TypeError 异常,并验证异常消息
    with pytest.raises(TypeError, match=msg):
        np.loadtxt(StringIO("1 2\n3 4"), converters=conv)


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
def test_converters_dict_raises_non_integer_key():
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 TypeError 异常,并验证异常消息
    with pytest.raises(TypeError, match="keys of the converters dict"):
        np.loadtxt(StringIO("1 2\n3 4"), converters={"a": int})
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 TypeError 异常,并验证异常消息
    with pytest.raises(TypeError, match="keys of the converters dict"):
        np.loadtxt(StringIO("1 2\n3 4"), converters={"a": int}, usecols=0)


@pytest.mark.parametrize("bad_col_ind", (3, -3))
def test_converters_dict_raises_non_col_key(bad_col_ind):
    # 创建一个包含数据的 StringIO 对象 data
    data = StringIO("1 2\n3 4")
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 ValueError 异常,并验证异常消息
    with pytest.raises(ValueError, match="converter specified for column"):
        np.loadtxt(data, converters={bad_col_ind: int})


def test_converters_dict_raises_val_not_callable():
    # 使用 pytest 检查调用 `np.loadtxt()` 时抛出的 TypeError 异常,并验证异常消息
    with pytest.raises(TypeError,
                match="values of the converters dictionary must be callable"):
        np.loadtxt(StringIO("1 2\n3 4"), converters={0: 1})


@pytest.mark.parametrize("q", ('"', "'", "`"))
def test_quoted_field(q):
    # 创建一个包含带引号字段的数据的 StringIO 对象 txt
    txt = StringIO(
        f"{q}alpha, x{q}, 2.5\n{q}beta, y{q}, 4.5\n{q}gamma, z{q}, 5.0\n"
    )
    # 定义期望的数据类型 dtype
    dtype = np.dtype([('f0', 'U8'), ('f1', np.float64)])
    # 定义期望的结果数组 expected
    expected = np.array(
        [("alpha, x", 2.5), ("beta, y", 4.5), ("gamma, z", 5.0)], dtype=dtype
    )

    # 调用 `np.loadtxt()` 加载数据,并将结果存储在 res 中
    res = np.loadtxt(txt, dtype=dtype, delimiter=",", quotechar=q)
    # 使用 assert_array_equal 检查 res 是否与期望的结果数组 expected 相等
    assert_array_equal(res, expected)


@pytest.mark.parametrize("q", ('"', "'", "`"))
def test_quoted_field_with_whitepace_delimiter(q):
    # 此测试未提供完整的代码示例,因此无法添加注释
    pass
    # 创建一个包含指定文本的内存中的文本流对象
    txt = StringIO(
        f"{q}alpha, x{q}     2.5\n{q}beta, y{q} 4.5\n{q}gamma, z{q}   5.0\n"
    )
    # 定义一个 NumPy 数据类型,包含两个字段:一个是最大长度为 8 的 Unicode 字符串,另一个是 64 位浮点数
    dtype = np.dtype([('f0', 'U8'), ('f1', np.float64)])
    # 创建一个 NumPy 数组,用于存储预期的数据,每个元素是一个元组,元组包含一个字符串和一个浮点数
    expected = np.array(
        [("alpha, x", 2.5), ("beta, y", 4.5), ("gamma, z", 5.0)], dtype=dtype
    )
    
    # 使用 np.loadtxt 从文本流中加载数据,并指定数据类型、分隔符和引用字符
    res = np.loadtxt(txt, dtype=dtype, delimiter=None, quotechar=q)
    # 使用 assert_array_equal 断言函数,检查加载的数据是否与预期数据一致
    assert_array_equal(res, expected)
def test_quoted_field_is_not_empty_nonstrict():
    # Same as test_quoted_field_is_not_empty but check that we are not strict
    # about missing closing quote (this is the `csv.reader` default also)
    # 创建包含数据的字符串文件对象
    txt = StringIO('1\n\n"4"\n"')
    # 期望的结果数组
    expected = np.array(["1", "4", ""], dtype="U1")
    # 使用 NumPy 的 loadtxt 函数从文本文件中加载数据
    res = np.loadtxt(txt, delimiter=",", dtype="U1", quotechar='"')
    # 断言,验证加载的数据 res 是否等于预期的数据 expected
    assert_equal(res, expected)
def test_consecutive_quotechar_escaped():
    # 创建一个字符串缓冲区,内容为包含连续引号的文本
    txt = StringIO('"Hello, my name is ""Monty""!"')
    # 创建预期的 NumPy 数组,包含解析后的字符串
    expected = np.array('Hello, my name is "Monty"!', dtype="U40")
    # 使用 np.loadtxt 从文本中加载数据到 res 变量中
    res = np.loadtxt(txt, dtype="U40", delimiter=",", quotechar='"')
    # 断言 res 和 expected 数组相等
    assert_equal(res, expected)


@pytest.mark.parametrize("data", ("", "\n\n\n", "# 1 2 3\n# 4 5 6\n"))
@pytest.mark.parametrize("ndmin", (0, 1, 2))
@pytest.mark.parametrize("usecols", [None, (1, 2, 3)])
def test_warn_on_no_data(data, ndmin, usecols):
    """检查当输入数据为空时是否发出 UserWarning。"""
    if usecols is not None:
        expected_shape = (0, 3)
    elif ndmin == 2:
        expected_shape = (0, 1)  # 猜测只有一列数据?!
    else:
        expected_shape = (0,)

    # 创建一个包含指定数据的字符串缓冲区
    txt = StringIO(data)
    # 使用 pytest 的 warn 环境,检查是否发出 UserWarning 并匹配指定消息
    with pytest.warns(UserWarning, match="input contained no data"):
        # 使用 np.loadtxt 从文本中加载数据到 res 变量中
        res = np.loadtxt(txt, ndmin=ndmin, usecols=usecols)
    # 断言加载后的数据形状与预期形状相同
    assert res.shape == expected_shape

    # 使用临时文件写入指定数据
    with NamedTemporaryFile(mode="w") as fh:
        fh.write(data)
        fh.seek(0)
        # 使用 pytest 的 warn 环境,检查是否发出 UserWarning 并匹配指定消息
        with pytest.warns(UserWarning, match="input contained no data"):
            # 使用 np.loadtxt 从文本中加载数据到 res 变量中
            res = np.loadtxt(txt, ndmin=ndmin, usecols=usecols)
        # 断言加载后的数据形状与预期形状相同
        assert res.shape == expected_shape


@pytest.mark.parametrize("skiprows", (2, 3))
def test_warn_on_skipped_data(skiprows):
    # 创建包含数据的字符串缓冲区
    data = "1 2 3\n4 5 6"
    txt = StringIO(data)
    # 使用 pytest 的 warn 环境,检查是否发出 UserWarning 并匹配指定消息
    with pytest.warns(UserWarning, match="input contained no data"):
        # 使用 np.loadtxt 从文本中加载数据,跳过指定行数
        np.loadtxt(txt, skiprows=skiprows)


@pytest.mark.parametrize(["dtype", "value"], [
        ("i2", 0x0001), ("u2", 0x0001),
        ("i4", 0x00010203), ("u4", 0x00010203),
        ("i8", 0x0001020304050607), ("u8", 0x0001020304050607),
        ("float16", 3.07e-05),
        ("float32", 9.2557e-41), ("complex64", 9.2557e-41+2.8622554e-29j),
        ("float64", -1.758571353180402e-24),
        ("complex128", repr(5.406409232372729e-29-1.758571353180402e-24j)),
        ("longdouble", 0x01020304050607),
        ("clongdouble", repr(0x01020304050607 + (0x00121314151617 * 1j))),
        ("U2", "\U00010203\U000a0b0c")])
@pytest.mark.parametrize("swap", [True, False])
def test_byteswapping_and_unaligned(dtype, value, swap):
    # 尝试创建具有 "有趣" 值的数据,确保在有效的 Unicode 范围内
    dtype = np.dtype(dtype)
    # 创建包含指定数据的列表
    data = [f"x,{value}\n"]
    # 如果 swap 为 True,则交换字节顺序
    if swap:
        dtype = dtype.newbyteorder()
    # 创建具有指定结构的 dtype
    full_dt = np.dtype([("a", "S1"), ("b", dtype)], align=False)
    # 确保 "b" 字段的对齐方式为非对齐
    assert full_dt.fields["b"][1] == 1
    # 使用 numpy 的 loadtxt 函数从数据中加载内容,指定数据类型为 full_dt,分隔符为逗号
    # 使用 max_rows 参数限制加载的行数,防止过度分配内存
    res = np.loadtxt(data, dtype=full_dt, delimiter=",", max_rows=1)

    # 使用断言确保 res 数组中字段 "b" 的值等于给定的 value 值
    assert res["b"] == dtype.type(value)
# 使用 pytest 的 parametrize 装饰器为单元测试函数提供多组参数化输入
@pytest.mark.parametrize("dtype",
        np.typecodes["AllInteger"] + "efdFD" + "?")
def test_unicode_whitespace_stripping(dtype):
    # 测试所有数字类型(包括布尔型)是否能正确去除空白字符
    # \u202F 是一个窄的不换行空格,`\n` 表示一个普通的换行符
    # 目前跳过 float128,因为它不总是支持此功能且没有“自定义”解析
    txt = StringIO(' 3 ,"\u202F2\n"')
    # 使用 np.loadtxt 函数从文本流中加载数据,并指定数据类型、分隔符和引号字符
    res = np.loadtxt(txt, dtype=dtype, delimiter=",", quotechar='"')
    # 断言加载的数据与预期的数组相等
    assert_array_equal(res, np.array([3, 2]).astype(dtype))


@pytest.mark.parametrize("dtype", "FD")
def test_unicode_whitespace_stripping_complex(dtype):
    # 复数有一些额外的情况,因为它有两个组件和括号
    line = " 1 , 2+3j , ( 4+5j ), ( 6+-7j )  , 8j , ( 9j ) \n"
    data = [line, line.replace(" ", "\u202F")]
    # 测试加载包含复数的数据时是否正确去除空白字符
    res = np.loadtxt(data, dtype=dtype, delimiter=',')
    # 断言加载的数据与预期的二维数组相等
    assert_array_equal(res, np.array([[1, 2+3j, 4+5j, 6-7j, 8j, 9j]] * 2))


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype", "FD")
@pytest.mark.parametrize("field",
        ["1 +2j", "1+ 2j", "1+2 j", "1+-+3", "(1j", "(1", "(1+2j", "1+2j)"])
def test_bad_complex(dtype, field):
    # 使用 pytest.raises 检查是否会抛出 ValueError 异常
    with pytest.raises(ValueError):
        # 测试加载包含错误格式的复数字符串时是否会抛出异常
        np.loadtxt([field + "\n"], dtype=dtype, delimiter=",")


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype",
            np.typecodes["AllInteger"] + "efgdFDG" + "?")
def test_nul_character_error(dtype):
    # 测试是否能正确识别 `\0` 字符,并抛出 ValueError 异常
    # 即使前面的内容是有效的(不是所有内容都能在内部解析)
    if dtype.lower() == "g":
        pytest.xfail("longdouble/clongdouble assignment may misbehave.")
    with pytest.raises(ValueError):
        np.loadtxt(["1\000"], dtype=dtype, delimiter=",", quotechar='"')


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
@pytest.mark.parametrize("dtype",
        np.typecodes["AllInteger"] + "efgdFDG" + "?")
def test_no_thousands_support(dtype):
    # 主要用于文档说明行为,Python 支持像 1_1 这样的千分位表示
    # (e 和 G 可能会使用不同的转换和支持,这是一个 bug 但确实发生了...)
    if dtype == "e":
        pytest.skip("half assignment currently uses Python float converter")
    if dtype in "eG":
        pytest.xfail("clongdouble assignment is buggy (uses `complex`?).")

    assert int("1_1") == float("1_1") == complex("1_1") == 11
    with pytest.raises(ValueError):
        np.loadtxt(["1_1\n"], dtype=dtype)


@pytest.mark.parametrize("data", [
    ["1,2\n", "2\n,3\n"],
    ["1,2\n", "2\r,3\n"]])
def test_bad_newline_in_iterator(data):
    # 在 NumPy <=1.22 中这是被接受的,因为换行符是完全
    # 设置错误消息字符串,用于匹配 pytest 抛出的 ValueError 异常
    msg = "Found an unquoted embedded newline within a single line"
    # 使用 pytest 提供的上下文管理器 `pytest.raises` 来捕获 ValueError 异常,
    # 并检查其异常消息是否与预设的 `msg` 相匹配
    with pytest.raises(ValueError, match=msg):
        # 调用 numpy 的 loadtxt 函数来加载数据,指定分隔符为逗号 `,`
        np.loadtxt(data, delimiter=",")
@pytest.mark.parametrize("data", [
    ["1,2\n", "2,3\r\n"],  # 定义测试参数,包括包含不同换行符的数据
    ["1,2\n", "'2\n',3\n"],  # 含有引号的换行数据
    ["1,2\n", "'2\r',3\n"],  # 含有引号的回车数据
    ["1,2\n", "'2\r\n',3\n"],  # 含有引号的回车换行数据
])
def test_good_newline_in_iterator(data):
    # 在这里引号内的换行符不会被转换,但会被视为空白字符。
    res = np.loadtxt(data, delimiter=",", quotechar="'")  # 使用 numpy 的 loadtxt 函数加载数据
    assert_array_equal(res, [[1., 2.], [2., 3.]])


@pytest.mark.parametrize("newline", ["\n", "\r", "\r\n"])
def test_universal_newlines_quoted(newline):
    # 检查在引用字段中不应用通用换行符支持的情况下的情况
    # (注意,行必须以换行符结尾,否则引用字段将不包括换行符)
    data = ['1,"2\n"\n', '3,"4\n', '1"\n']
    data = [row.replace("\n", newline) for row in data]  # 替换每行的换行符为指定的换行符
    res = np.loadtxt(data, dtype=object, delimiter=",", quotechar='"')  # 使用 numpy 的 loadtxt 函数加载数据
    assert_array_equal(res, [['1', f'2{newline}'], ['3', f'4{newline}1']])


def test_null_character():
    # 检查 NUL 字符是否不具有特殊性的基本测试:
    res = np.loadtxt(["1\0002\0003\n", "4\0005\0006"], delimiter="\000")  # 使用 numpy 的 loadtxt 函数加载数据
    assert_array_equal(res, [[1, 2, 3], [4, 5, 6]])

    # 同样不作为字段的一部分(避免 Unicode/数组会将 \0 去掉)
    res = np.loadtxt(["1\000,2\000,3\n", "4\000,5\000,6"],
                     delimiter=",", dtype=object)  # 使用 numpy 的 loadtxt 函数加载数据
    assert res.tolist() == [["1\000", "2\000", "3"], ["4\000", "5\000", "6"]]


def test_iterator_fails_getting_next_line():
    class BadSequence:
        def __len__(self):
            return 100

        def __getitem__(self, item):
            if item == 50:
                raise RuntimeError("Bad things happened!")
            return f"{item}, {item+1}"

    with pytest.raises(RuntimeError, match="Bad things happened!"):
        np.loadtxt(BadSequence(), dtype=int, delimiter=",")  # 使用 numpy 的 loadtxt 函数加载数据


class TestCReaderUnitTests:
    # 这些是路径上不应该触发的内部测试,除非出现非常严重的问题。
    def test_not_an_filelike(self):
        with pytest.raises(AttributeError, match=".*read"):
            np._core._multiarray_umath._load_from_filelike(
                object(), dtype=np.dtype("i"), filelike=True)

    def test_filelike_read_fails(self):
        # 只有当 loadtxt 打开文件时才能到达,所以很难通过公共接口实现
        # (尽管在当前的 "DataClass" 支持下可能不是不可能的)。
        class BadFileLike:
            counter = 0

            def read(self, size):
                self.counter += 1
                if self.counter > 20:
                    raise RuntimeError("Bad bad bad!")
                return "1,2,3\n"

        with pytest.raises(RuntimeError, match="Bad bad bad!"):
            np._core._multiarray_umath._load_from_filelike(
                BadFileLike(), dtype=np.dtype("i"), filelike=True)
    # 定义一个测试用例,用于测试当 read 方法返回非字符串类型时的情况
    def test_filelike_bad_read(self):
        # 如果 loadtxt 打开文件,则可以到达此处,所以很难通过公共接口完成
        # 虽然在当前的“DataClass”支持下可能并非不可能。

        # 定义一个模拟的文件类 BadFileLike
        class BadFileLike:
            counter = 0

            # 重载 read 方法,返回一个整数而不是字符串
            def read(self, size):
                return 1234  # not a string!

        # 使用 pytest 检查是否会抛出 TypeError 异常,并匹配特定的错误信息
        with pytest.raises(TypeError,
                    match="non-string returned while reading data"):
            # 调用被测试的函数,传入 BadFileLike 实例作为文件对象
            np._core._multiarray_umath._load_from_filelike(
                BadFileLike(), dtype=np.dtype("i"), filelike=True)

    # 定义一个测试用例,用于测试当对象不是可迭代对象时的情况
    def test_not_an_iter(self):
        # 使用 pytest 检查是否会抛出 TypeError 异常,并匹配特定的错误信息
        with pytest.raises(TypeError,
                    match="error reading from object, expected an iterable"):
            # 调用被测试的函数,传入普通对象而不是可迭代对象
            np._core._multiarray_umath._load_from_filelike(
                object(), dtype=np.dtype("i"), filelike=False)

    # 定义一个测试用例,用于测试当 dtype 参数不正确时的情况
    def test_bad_type(self):
        # 使用 pytest 检查是否会抛出 TypeError 异常,并匹配特定的错误信息
        with pytest.raises(TypeError, match="internal error: dtype must"):
            # 调用被测试的函数,传入错误的 dtype 类型
            np._core._multiarray_umath._load_from_filelike(
                object(), dtype="i", filelike=False)

    # 定义一个测试用例,用于测试当 encoding 参数不正确时的情况
    def test_bad_encoding(self):
        # 使用 pytest 检查是否会抛出 TypeError 异常,并匹配特定的错误信息
        with pytest.raises(TypeError, match="encoding must be a unicode"):
            # 调用被测试的函数,传入错误的 encoding 类型
            np._core._multiarray_umath._load_from_filelike(
                object(), dtype=np.dtype("i"), filelike=False, encoding=123)

    # 使用 pytest 的参数化功能定义一个测试用例,测试不同的 newline 参数
    @pytest.mark.parametrize("newline", ["\r", "\n", "\r\n"])
    def test_manual_universal_newlines(self, newline):
        # 这部分当前对用户不可用,因为我们应该始终以启用了 universal newlines 的方式打开文件 `newlines=None`
        # (从迭代器读取数据使用了稍微不同的代码路径)。
        # 我们对 `newline="\r"` 或 `newline="\n"` 没有真正的支持,因为用户不能指定这些选项。

        # 创建一个 StringIO 对象,模拟包含特定 newline 的数据
        data = StringIO('0\n1\n"2\n"\n3\n4 #\n'.replace("\n", newline),
                        newline="")

        # 调用被测试的函数,传入 StringIO 对象以及其他参数
        res = np._core._multiarray_umath._load_from_filelike(
            data, dtype=np.dtype("U10"), filelike=True,
            quote='"', comment="#", skiplines=1)
        
        # 使用 assert_array_equal 断言函数验证结果的正确性
        assert_array_equal(res[:, 0], ["1", f"2{newline}", "3", "4 "])
# 当分隔符与注释字符冲突时,应该抛出TypeError异常,提示控制字符不兼容
def test_delimiter_comment_collision_raises():
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“control characters”和“incompatible”
    with pytest.raises(TypeError, match=".*control characters.*incompatible"):
        # 使用 numpy 的 loadtxt 函数加载以逗号分隔的文本数据,指定分隔符为逗号,注释字符也为逗号
        np.loadtxt(StringIO("1, 2, 3"), delimiter=",", comments=",")


# 当分隔符与引用字符冲突时,应该抛出TypeError异常,提示控制字符不兼容
def test_delimiter_quotechar_collision_raises():
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“control characters”和“incompatible”
    with pytest.raises(TypeError, match=".*control characters.*incompatible"):
        # 使用 numpy 的 loadtxt 函数加载以逗号分隔的文本数据,指定分隔符为逗号,引用字符也为逗号
        np.loadtxt(StringIO("1, 2, 3"), delimiter=",", quotechar=",")


# 当注释字符与引用字符冲突时,应该抛出TypeError异常,提示控制字符不兼容
def test_comment_quotechar_collision_raises():
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“control characters”和“incompatible”
    with pytest.raises(TypeError, match=".*control characters.*incompatible"):
        # 使用 numpy 的 loadtxt 函数加载空格分隔的文本数据,指定注释字符为井号,引用字符也为井号
        np.loadtxt(StringIO("1 2 3"), comments="#", quotechar="#")


# 当分隔符与多个注释字符冲突时,应该抛出TypeError异常,提示注释字符不能包括分隔符
def test_delimiter_and_multiple_comments_collision_raises():
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“Comment characters”和“cannot include the delimiter”
    with pytest.raises(
        TypeError, match="Comment characters.*cannot include the delimiter"
    ):
        # 使用 numpy 的 loadtxt 函数加载以逗号分隔的文本数据,指定分隔符为逗号,注释字符包括井号和逗号
        np.loadtxt(StringIO("1, 2, 3"), delimiter=",", comments=["#", ","])


# 使用 pytest.mark.parametrize 注册的参数化测试,测试空白字符与默认分隔符冲突时是否抛出TypeError异常
@pytest.mark.parametrize(
    "ws",
    (
        " ",  # 空格
        "\t",  # 制表符
        "\u2003",  # EM 空白
        "\u00A0",  # 不间断空白
        "\u3000",  # 表意字符空白
    )
)
def test_collision_with_default_delimiter_raises(ws):
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“control characters”和“incompatible”
    with pytest.raises(TypeError, match=".*control characters.*incompatible"):
        # 使用 numpy 的 loadtxt 函数加载带有空白字符分隔的文本数据,指定注释字符为当前空白字符
        np.loadtxt(StringIO(f"1{ws}2{ws}3\n4{ws}5{ws}6\n"), comments=ws)
    with pytest.raises(TypeError, match=".*control characters.*incompatible"):
        # 使用 numpy 的 loadtxt 函数加载带有空白字符分隔的文本数据,指定引用字符为当前空白字符
        np.loadtxt(StringIO(f"1{ws}2{ws}3\n4{ws}5{ws}6\n"), quotechar=ws)


# 使用 pytest.mark.parametrize 注册的参数化测试,测试控制字符与换行符冲突时是否抛出TypeError异常
@pytest.mark.parametrize("nl", ("\n", "\r"))
def test_control_character_newline_raises(nl):
    # 准备包含换行符的文本数据
    txt = StringIO(f"1{nl}2{nl}3{nl}{nl}4{nl}5{nl}6{nl}{nl}")
    # 准备异常消息
    msg = "control character.*cannot be a newline"
    # 使用 pytest 模块验证加载文本时抛出TypeError异常,异常消息中包含“control character”和“cannot be a newline”
    with pytest.raises(TypeError, match=msg):
        # 使用 numpy 的 loadtxt 函数加载文本数据,指定分隔符为当前换行符
        np.loadtxt(txt, delimiter=nl)
    with pytest.raises(TypeError, match=msg):
        # 使用 numpy 的 loadtxt 函数加载文本数据,指定注释字符为当前换行符
        np.loadtxt(txt, comments=nl)
    with pytest.raises(TypeError, match=msg):
        # 使用 numpy 的 loadtxt 函数加载文本数据,指定引用字符为当前换行符
        np.loadtxt(txt, quotechar=nl)


# 使用 pytest.mark.parametrize 注册的参数化测试,测试用户指定的数据类型发现功能
@pytest.mark.parametrize(
    ("generic_data", "long_datum", "unitless_dtype", "expected_dtype"),
    [
        ("2012-03", "2013-01-15", "M8", "M8[D]"),  # 日期时间类型
        ("spam-a-lot", "tis_but_a_scratch", "U", "U17"),  # 字符串类型
    ],
)
@pytest.mark.parametrize("nrows", (10, 50000, 60000))  # 小于、等于、大于分块大小
def test_parametric_unit_discovery(
    generic_data, long_datum, unitless_dtype, expected_dtype, nrows
):
    """检查当用户指定无单位的日期时间时,从数据中正确识别单位(例如月、日、秒)。"""
    # 准备数据,包含重复数据和长日期时间数据
    data = [generic_data] * 50000 + [long_datum]
    expected = np.array(data, dtype=expected_dtype)

    # 准备文件对象路径
    txt = StringIO("\n".join(data))
    # 使用 numpy 的 loadtxt 函数加载文本数据,指定数据类型为无单位的日期时间类型
    a = np.loadtxt(txt, dtype=unitless_dtype)
    assert a.dtype == expected.dtype
    assert_equal(a, expected)

    # 准备文件路径
    fd, fname = mkstemp()
    os.close(fd)
    with open(fname, "w") as fh:
        fh.write("\n".join(data))
    # 使用 numpy 的 loadtxt 函数加载文件中的文本数据,指定数据类型为无单位的日期时间类型
    a = np.loadtxt(fname, dtype=unitless_dtype)
    os.remove(fname)
    assert a.dtype == expected.dtype
    assert_equal(a, expected)
def test_str_dtype_unit_discovery_with_converter():
    # 创建一个包含大量字符串的列表,其中包括一个特殊的字符串
    data = ["spam-a-lot"] * 60000 + ["XXXtis_but_a_scratch"]
    # 创建预期的 NumPy 数组,指定数据类型为 Unicode 字符串,长度为 17
    expected = np.array(["spam-a-lot"] * 60000 + ["tis_but_a_scratch"], dtype="U17")
    # 定义一个字符串转换器,去除字符串两端的 "XXX"
    conv = lambda s: s.strip("XXX")

    # 创建一个类似文件的路径,将数据作为文本流写入 StringIO 对象
    txt = StringIO("\n".join(data))
    # 使用 np.loadtxt 从文本流中加载数据,指定数据类型为 Unicode,应用字符串转换器
    a = np.loadtxt(txt, dtype="U", converters=conv)
    # 断言加载后的数组的数据类型与预期相符
    assert a.dtype == expected.dtype
    # 断言加载后的数组内容与预期相等
    assert_equal(a, expected)

    # 创建一个文件对象路径,写入数据并读取
    fd, fname = mkstemp()
    os.close(fd)
    with open(fname, "w") as fh:
        fh.write("\n".join(data))
    # 使用 np.loadtxt 从文件中加载数据,指定数据类型为 Unicode,应用字符串转换器
    a = np.loadtxt(fname, dtype="U", converters=conv)
    os.remove(fname)
    # 断言加载后的数组的数据类型与预期相符
    assert a.dtype == expected.dtype
    # 断言加载后的数组内容与预期相等
    assert_equal(a, expected)


@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
                    reason="PyPy bug in error formatting")
def test_control_character_empty():
    # 使用 pytest 检测加载数据时的异常情况,期望抛出 TypeError
    with pytest.raises(TypeError, match="Text reading control character must"):
        np.loadtxt(StringIO("1 2 3"), delimiter="")
    with pytest.raises(TypeError, match="Text reading control character must"):
        np.loadtxt(StringIO("1 2 3"), quotechar="")
    # 使用 pytest 检测加载数据时的异常情况,期望抛出 ValueError
    with pytest.raises(ValueError, match="comments cannot be an empty string"):
        np.loadtxt(StringIO("1 2 3"), comments="")
    with pytest.raises(ValueError, match="comments cannot be an empty string"):
        np.loadtxt(StringIO("1 2 3"), comments=["#", ""])


def test_control_characters_as_bytes():
    """Byte control characters (comments, delimiter) are supported."""
    # 使用字节形式的控制字符(注释符号和分隔符)加载数据
    a = np.loadtxt(StringIO("#header\n1,2,3"), comments=b"#", delimiter=b",")
    # 断言加载后的数组内容与预期相等
    assert_equal(a, [1, 2, 3])


@pytest.mark.filterwarnings('ignore::UserWarning')
def test_field_growing_cases():
    # 测试在每个字段仍然占据一个字符的情况下进行空字段的追加/增长
    res = np.loadtxt([""], delimiter=",", dtype=bytes)
    # 断言加载结果数组的长度为 0
    assert len(res) == 0

    # 循环测试不同长度的字段字符串,检查最终字段追加不会产生问题
    for i in range(1, 1024):
        res = np.loadtxt(["," * i], delimiter=",", dtype=bytes)
        # 断言加载结果数组的长度与预期相符
        assert len(res) == i+1

.\numpy\numpy\lib\tests\test_mixins.py

# 导入必要的模块
import numbers           # 导入 numbers 模块,用于处理数字相关操作
import operator          # 导入 operator 模块,提供了各种运算符的函数实现

import numpy as np       # 导入 NumPy 库,用于科学计算
from numpy.testing import assert_, assert_equal, assert_raises  # 导入 NumPy 测试模块中的断言函数

# NOTE: This class should be kept as an exact copy of the example from the
# docstring for NDArrayOperatorsMixin.

# 定义一个类 ArrayLike,继承自 NDArrayOperatorsMixin,用于模拟数组的行为
class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
    def __init__(self, value):
        self.value = np.asarray(value)  # 将输入的值转换为 NumPy 数组并存储在 self.value 中

    # One might also consider adding the built-in list type to this
    # list, to support operations like np.add(array_like, list)
    # 定义 _HANDLED_TYPES 元组,指定支持的数据类型,包括 ndarray 和 numbers.Number
    _HANDLED_TYPES = (np.ndarray, numbers.Number)

    # 实现 __array_ufunc__ 方法,处理 NumPy 的通用函数
    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        out = kwargs.get('out', ())  # 获取关键字参数中的 out 参数
        for x in inputs + out:
            # 只支持 _HANDLED_TYPES 中指定的数据类型或者 ArrayLike 类型的实例
            if not isinstance(x, self._HANDLED_TYPES + (ArrayLike,)):
                return NotImplemented

        # 对输入参数进行处理,如果是 ArrayLike 类型,则取其 value 属性作为操作对象
        inputs = tuple(x.value if isinstance(x, ArrayLike) else x
                       for x in inputs)
        if out:
            kwargs['out'] = tuple(
                x.value if isinstance(x, ArrayLike) else x
                for x in out)
        # 调用 ufunc 的具体方法进行计算
        result = getattr(ufunc, method)(*inputs, **kwargs)

        if type(result) is tuple:
            # 处理多返回值的情况,将每个返回值转换为 ArrayLike 类型
            return tuple(type(self)(x) for x in result)
        elif method == 'at':
            # 对于 'at' 方法,没有返回值
            return None
        else:
            # 处理单返回值的情况,将结果转换为 ArrayLike 类型
            return type(self)(result)

    # 重写 __repr__ 方法,返回对象的字符串表示形式
    def __repr__(self):
        return '%s(%r)' % (type(self).__name__, self.value)


# 定义一个函数 wrap_array_like,用于将返回值转换为 ArrayLike 类型
def wrap_array_like(result):
    if type(result) is tuple:
        return tuple(ArrayLike(r) for r in result)
    else:
        return ArrayLike(result)


# 定义一个私有函数 _assert_equal_type_and_value,用于断言两个对象的类型和值是否相等
def _assert_equal_type_and_value(result, expected, err_msg=None):
    assert_equal(type(result), type(expected), err_msg=err_msg)  # 断言结果类型与期望类型相等
    if isinstance(result, tuple):
        assert_equal(len(result), len(expected), err_msg=err_msg)  # 断言元组长度相等
        for result_item, expected_item in zip(result, expected):
            _assert_equal_type_and_value(result_item, expected_item, err_msg)  # 递归比较元组内的每个元素
    else:
        assert_equal(result.value, expected.value, err_msg=err_msg)  # 断言值属性相等
        assert_equal(getattr(result.value, 'dtype', None),
                     getattr(expected.value, 'dtype', None), err_msg=err_msg)  # 断言 dtype 相等


# 定义一个列表 _ALL_BINARY_OPERATORS,包含常用的二元运算符函数
_ALL_BINARY_OPERATORS = [
    operator.lt,         # 小于运算符 <
    operator.le,         # 小于等于运算符 <=
    operator.eq,         # 等于运算符 ==
    operator.ne,         # 不等于运算符 !=
    operator.gt,         # 大于运算符 >
    operator.ge,         # 大于等于运算符 >=
    operator.add,        # 加法运算符 +
    operator.sub,        # 减法运算符 -
    operator.mul,        # 乘法运算符 *
    operator.truediv,    # 真除法运算符 /
    operator.floordiv,   # 地板除法运算符 //
    operator.mod,        # 取模运算符 %
    divmod,              # 返回商和余数的元组
    pow,                 # 幂运算符 **
    operator.lshift,     # 左移运算符 <<
    operator.rshift,     # 右移运算符 >>
    operator.and_,       # 按位与运算符 &
    operator.xor,        # 按位异或运算符 ^
    operator.or_,        # 按位或运算符 |
]


# 定义一个测试类 TestNDArrayOperatorsMixin,用于测试 NDArrayOperatorsMixin 类的功能
class TestNDArrayOperatorsMixin:
    def test_array_like_add(self):
        # 定义内部函数check,用于验证结果是否符合预期
        def check(result):
            _assert_equal_type_and_value(result, ArrayLike(0))

        # 测试 ArrayLike 对象与数字的加法
        check(ArrayLike(0) + 0)
        check(0 + ArrayLike(0))

        # 测试 ArrayLike 对象与 NumPy 数组的加法
        check(ArrayLike(0) + np.array(0))
        check(np.array(0) + ArrayLike(0))

        # 测试 ArrayLike 对象封装的 NumPy 数组与数字的加法
        check(ArrayLike(np.array(0)) + 0)
        check(0 + ArrayLike(np.array(0)))

        # 测试 ArrayLike 对象封装的 NumPy 数组与 NumPy 数组的加法
        check(ArrayLike(np.array(0)) + np.array(0))
        check(np.array(0) + ArrayLike(np.array(0)))

    def test_inplace(self):
        # 创建一个封装了 NumPy 数组 [0] 的 ArrayLike 对象
        array_like = ArrayLike(np.array([0]))
        # 对 ArrayLike 对象执行就地操作,加上 1
        array_like += 1
        # 验证执行操作后的类型和值是否与预期一致
        _assert_equal_type_and_value(array_like, ArrayLike(np.array([1])))

        # 创建一个 NumPy 数组 [0]
        array = np.array([0])
        # 对 NumPy 数组执行就地操作,加上 ArrayLike 对象封装的值 1
        array += ArrayLike(1)
        # 验证执行操作后的类型和值是否与预期一致
        _assert_equal_type_and_value(array, ArrayLike(np.array([1])))

    def test_opt_out(self):
        # 定义一个类 OptOut,该类不支持 __array_ufunc__,并定义了加法运算
        class OptOut:
            """Object that opts out of __array_ufunc__."""
            __array_ufunc__ = None

            def __add__(self, other):
                return self

            def __radd__(self, other):
                return self

        # 创建一个封装了值 1 的 ArrayLike 对象
        array_like = ArrayLike(1)
        # 创建一个 OptOut 对象
        opt_out = OptOut()

        # 支持的操作:ArrayLike 对象与 OptOut 对象的加法返回 OptOut 对象
        assert_(array_like + opt_out is opt_out)
        assert_(opt_out + array_like is opt_out)

        # 不支持的操作:预期会抛出 TypeError 异常
        with assert_raises(TypeError):
            # 不要使用默认的 Python 操作,避免执行 array_like = array_like + opt_out
            array_like += opt_out
        with assert_raises(TypeError):
            array_like - opt_out
        with assert_raises(TypeError):
            opt_out - array_like

    def test_subclass(self):
        # 定义一个继承自 ArrayLike 的子类 SubArrayLike
        class SubArrayLike(ArrayLike):
            """Should take precedence over ArrayLike."""

        # 创建 ArrayLike 对象 x 和 SubArrayLike 对象 y
        x = ArrayLike(0)
        y = SubArrayLike(1)
        # 验证 ArrayLike 对象与 SubArrayLike 对象的加法结果与预期一致
        _assert_equal_type_and_value(x + y, y)
        _assert_equal_type_and_value(y + x, y)

    def test_object(self):
        # 创建一个封装了值 0 的 ArrayLike 对象
        x = ArrayLike(0)
        # 创建一个普通对象 obj
        obj = object()
        # 预期会抛出 TypeError 异常,因为 ArrayLike 对象与普通对象不能进行加法操作
        with assert_raises(TypeError):
            x + obj
        with assert_raises(TypeError):
            obj + x
        with assert_raises(TypeError):
            x += obj

    def test_unary_methods(self):
        # 创建一个 NumPy 数组 [-1, 0, 1, 2]
        array = np.array([-1, 0, 1, 2])
        # 创建一个封装了该数组的 ArrayLike 对象
        array_like = ArrayLike(array)
        # 遍历一组一元操作函数,验证其对 ArrayLike 对象的操作结果是否与预期一致
        for op in [operator.neg,
                   operator.pos,
                   abs,
                   operator.invert]:
            _assert_equal_type_and_value(op(array_like), ArrayLike(op(array)))

    def test_forward_binary_methods(self):
        # 创建一个 NumPy 数组 [-1, 0, 1, 2]
        array = np.array([-1, 0, 1, 2])
        # 创建一个封装了该数组的 ArrayLike 对象
        array_like = ArrayLike(array)
        # 遍历一组二元操作函数,验证其对 ArrayLike 对象与标量的操作结果是否与预期一致
        for op in _ALL_BINARY_OPERATORS:
            expected = wrap_array_like(op(array, 1))
            actual = op(array_like, 1)
            err_msg = 'failed for operator {}'.format(op)
            _assert_equal_type_and_value(expected, actual, err_msg=err_msg)
    # 对每个二元运算符进行测试
    def test_reflected_binary_methods(self):
        for op in _ALL_BINARY_OPERATORS:
            # 使用操作符对(2, ArrayLike(1))执行运算,并包装结果为类似数组
            expected = wrap_array_like(op(2, 1))
            # 使用操作符对ArrayLike(1)和2执行运算
            actual = op(2, ArrayLike(1))
            err_msg = 'failed for operator {}'.format(op)
            # 断言期望值和实际值类型与数值都相等
            _assert_equal_type_and_value(expected, actual, err_msg=err_msg)

    # 测试矩阵乘法运算
    def test_matmul(self):
        # 创建浮点64位数组
        array = np.array([1, 2], dtype=np.float64)
        # 将数组转为类似数组对象
        array_like = ArrayLike(array)
        # 预期结果是包装后的浮点64位数
        expected = ArrayLike(np.float64(5))
        # 断言类似数组对象与原数组的矩阵乘法结果类型和数值相等
        _assert_equal_type_and_value(expected, np.matmul(array_like, array))
        # 断言类似数组对象与原数组的矩阵乘法运算结果类型和数值相等
        _assert_equal_type_and_value(
            expected, operator.matmul(array_like, array))
        # 断言原数组与类似数组对象的矩阵乘法运算结果类型和数值相等
        _assert_equal_type_and_value(
            expected, operator.matmul(array, array_like))

    # 测试 ufunc 的 at 方法
    def test_ufunc_at(self):
        # 创建类似数组对象
        array = ArrayLike(np.array([1, 2, 3, 4]))
        # 断言负数 ufunc 在指定索引位置的运算结果为 None
        assert_(np.negative.at(array, np.array([0, 1])) is None)
        # 断言类似数组对象的值为 [-1, -2, 3, 4]
        _assert_equal_type_and_value(array, ArrayLike([-1, -2, 3, 4]))

    # 测试 ufunc 返回两个输出的情况
    def test_ufunc_two_outputs(self):
        # 计算 2 的 -3 次方的尾数和指数
        mantissa, exponent = np.frexp(2 ** -3)
        # 预期结果是封装后的尾数和指数
        expected = (ArrayLike(mantissa), ArrayLike(exponent))
        # 断言类似数组对象计算 2 的 -3 次方的 frexp 结果与预期结果类型和数值相等
        _assert_equal_type_and_value(
            np.frexp(ArrayLike(2 ** -3)), expected)
        # 断言类似数组对象计算 np.array(2 的 -3 次方) 的 frexp 结果与预期结果类型和数值相等
        _assert_equal_type_and_value(
            np.frexp(ArrayLike(np.array(2 ** -3))), expected)

.\numpy\numpy\lib\tests\test_nanfunctions.py

# 引入警告模块,用于管理和控制警告信息的显示
import warnings
# 引入 pytest 模块,用于编写和运行测试用例
import pytest
# 引入 inspect 模块,提供了对 Python 对象内部结构的访问
import inspect
# 从 functools 模块中引入 partial 函数,用于部分应用函数
from functools import partial

# 引入 numpy 库,用于科学计算
import numpy as np
# 引入 normalize_axis_tuple 函数,用于规范化轴元组
from numpy._core.numeric import normalize_axis_tuple
# 引入 AxisError 和 ComplexWarning 异常类,用于处理轴错误和复数警告
from numpy.exceptions import AxisError, ComplexWarning
# 引入 _nan_mask 和 _replace_nan 函数,用于处理 NaN 值的掩码和替换
from numpy.lib._nanfunctions_impl import _nan_mask, _replace_nan
# 引入 numpy.testing 模块中的测试函数,用于进行各种断言和测试
from numpy.testing import (
    assert_, assert_equal, assert_almost_equal, assert_raises,
    assert_raises_regex, assert_array_equal, suppress_warnings
    )

# 测试数据,包含 NaN 值
_ndat = np.array([[0.6244, np.nan, 0.2692, 0.0116, np.nan, 0.1170],
                  [0.5351, -0.9403, np.nan, 0.2100, 0.4759, 0.2833],
                  [np.nan, np.nan, np.nan, 0.1042, np.nan, -0.5954],
                  [0.1610, np.nan, np.nan, 0.1859, 0.3146, np.nan]])

# 移除 NaN 值后的数据行
_rdat = [np.array([0.6244, 0.2692, 0.0116, 0.1170]),
         np.array([0.5351, -0.9403, 0.2100, 0.4759, 0.2833]),
         np.array([0.1042, -0.5954]),
         np.array([0.1610, 0.1859, 0.3146])]

# 将 NaN 值替换为 1.0 后的数据
_ndat_ones = np.array([[0.6244, 1.0, 0.2692, 0.0116, 1.0, 0.1170],
                       [0.5351, -0.9403, 1.0, 0.2100, 0.4759, 0.2833],
                       [1.0, 1.0, 1.0, 0.1042, 1.0, -0.5954],
                       [0.1610, 1.0, 1.0, 0.1859, 0.3146, 1.0]])

# 将 NaN 值替换为 0.0 后的数据
_ndat_zeros = np.array([[0.6244, 0.0, 0.2692, 0.0116, 0.0, 0.1170],
                        [0.5351, -0.9403, 0.0, 0.2100, 0.4759, 0.2833],
                        [0.0, 0.0, 0.0, 0.1042, 0.0, -0.5954],
                        [0.1610, 0.0, 0.0, 0.1859, 0.3146, 0.0]])


class TestSignatureMatch:
    # 定义一个字典,将 numpy 中处理 NaN 的函数映射到其对应的非 NaN 版本
    NANFUNCS = {
        np.nanmin: np.amin,
        np.nanmax: np.amax,
        np.nanargmin: np.argmin,
        np.nanargmax: np.argmax,
        np.nansum: np.sum,
        np.nanprod: np.prod,
        np.nancumsum: np.cumsum,
        np.nancumprod: np.cumprod,
        np.nanmean: np.mean,
        np.nanmedian: np.median,
        np.nanpercentile: np.percentile,
        np.nanquantile: np.quantile,
        np.nanvar: np.var,
        np.nanstd: np.std,
    }
    # 使用函数名作为参数化测试的标识
    IDS = [k.__name__ for k in NANFUNCS]

    @staticmethod
    def get_signature(func, default="..."):
        """构造函数签名并替换所有默认参数值。"""
        # 初始化参数列表
        prm_list = []
        # 获取函数的签名信息
        signature = inspect.signature(func)
        # 遍历签名中的每个参数
        for prm in signature.parameters.values():
            # 如果参数没有默认值,则直接添加到参数列表中
            if prm.default is inspect.Parameter.empty:
                prm_list.append(prm)
            else:
                # 否则,用指定的默认值替换参数的默认值
                prm_list.append(prm.replace(default=default))
        # 返回替换后的函数签名对象
        return inspect.Signature(prm_list)

    # 使用 pytest 的参数化装饰器,传入处理 NaN 的函数和对应的非 NaN 函数
    @pytest.mark.parametrize("nan_func,func", NANFUNCS.items(), ids=IDS)
    # 测试函数签名是否匹配的方法
    def test_signature_match(self, nan_func, func):
        # 忽略默认参数值,因为它们有时可能不同
        # 一个函数可能为 `False`,而另一个可能为 `np._NoValue`
        signature = self.get_signature(func)
        nan_signature = self.get_signature(nan_func)
        # 使用 NumPy 的测试工具检查两个函数的签名是否相等
        np.testing.assert_equal(signature, nan_signature)

    # 测试方法,验证所有的 NaN 函数是否都被测试到
    def test_exhaustiveness(self):
        """Validate that all nan functions are actually tested."""
        # 使用 NumPy 的测试工具,比较已测试的函数集合和 NumPy 内部所有 NaN 函数的集合
        np.testing.assert_equal(
            set(self.IDS), set(np.lib._nanfunctions_impl.__all__)
        )
# 定义一个测试类 TestNanFunctions_MinMax,用于测试处理 NaN 值的函数 np.nanmin 和 np.nanmax
class TestNanFunctions_MinMax:

    # 初始化类变量,包含处理 NaN 的函数列表和标准函数列表
    nanfuncs = [np.nanmin, np.nanmax]
    stdfuncs = [np.min, np.max]

    # 测试数组是否被修改的方法
    def test_mutation(self):
        # 复制原始数组 _ndat 到 ndat,确保不修改原始数据
        ndat = _ndat.copy()
        # 对 nanfuncs 中的每个函数 f,应用于 ndat
        for f in self.nanfuncs:
            f(ndat)
            # 断言 ndat 未被修改
            assert_equal(ndat, _ndat)

    # 测试 keepdims 参数的方法
    def test_keepdims(self):
        # 创建一个3x3的单位矩阵 mat
        mat = np.eye(3)
        # 对 nanfuncs 和 stdfuncs 中的每一对函数 nf 和 rf 进行迭代
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 对于每个可能的轴 axis:None, 0, 1
            for axis in [None, 0, 1]:
                # 计算使用 rf 函数在 mat 上的结果 tgt,并保持维度不变
                tgt = rf(mat, axis=axis, keepdims=True)
                # 计算使用 nf 函数在 mat 上的结果 res,并保持维度不变
                res = nf(mat, axis=axis, keepdims=True)
                # 断言 res 的维度与 tgt 的维度相同
                assert_(res.ndim == tgt.ndim)

    # 测试 out 参数的方法
    def test_out(self):
        # 创建一个3x3的单位矩阵 mat
        mat = np.eye(3)
        # 对 nanfuncs 和 stdfuncs 中的每一对函数 nf 和 rf 进行迭代
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 创建一个用于存储输出结果的 resout 数组
            resout = np.zeros(3)
            # 使用 rf 函数计算 mat 的结果 tgt,仅在 axis=1 时
            tgt = rf(mat, axis=1)
            # 使用 nf 函数计算 mat 的结果 res,将结果存储在 resout 中
            res = nf(mat, axis=1, out=resout)
            # 断言 res 与 resout 的值接近
            assert_almost_equal(res, resout)
            # 断言 res 与 tgt 的值接近
            assert_almost_equal(res, tgt)

    # 测试根据输入的 dtype 类型来确定输出的 dtype 的方法
    def test_dtype_from_input(self):
        # 定义一组 dtype 代码
        codes = 'efdgFDG'
        # 对 nanfuncs 和 stdfuncs 中的每一对函数 nf 和 rf 进行迭代
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 对于每个 dtype 代码 c
            for c in codes:
                # 创建一个 dtype 为 c 的3x3单位矩阵 mat
                mat = np.eye(3, dtype=c)
                # 使用 rf 函数计算 mat 的结果 tgt,仅在 axis=1 时,并确定其 dtype 类型
                tgt = rf(mat, axis=1).dtype.type
                # 使用 nf 函数计算 mat 的结果 res,并确定其 dtype 类型
                res = nf(mat, axis=1).dtype.type
                # 断言 res 的 dtype 类型与 tgt 的 dtype 类型相同
                assert_(res is tgt)
                # 在标量情况下进行断言
                tgt = rf(mat, axis=None).dtype.type
                res = nf(mat, axis=None).dtype.type
                assert_(res is tgt)

    # 测试函数返回值的方法
    def test_result_values(self):
        # 对 nanfuncs 和 stdfuncs 中的每一对函数 nf 和 rf 进行迭代
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 计算标准函数 rf 在 _rdat 中每个数组 d 上的结果列表 tgt
            tgt = [rf(d) for d in _rdat]
            # 使用 nf 函数计算 _ndat 在 axis=1 上的结果 res
            res = nf(_ndat, axis=1)
            # 断言 res 与 tgt 的值接近
            assert_almost_equal(res, tgt)

    # 使用 pytest.mark.parametrize 标记测试用例的方法,测试处理全为 NaN 的数组情况
    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),  # 0维数组情况
        np.full((3, 3), np.nan),  # 2维数组情况
    ], ids=["0d", "2d"])
    def test_allnans(self, axis, dtype, array):
        # 如果 axis 不为 None 且 array 的维度为 0,则跳过该测试用例
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        # 将 array 转换为指定的 dtype 类型
        array = array.astype(dtype)
        # 匹配字符串 "All-NaN slice encountered"
        match = "All-NaN slice encountered"
        # 对 nanfuncs 中的每个函数 func 进行迭代
        for func in self.nanfuncs:
            # 使用 pytest.warns 检查是否会发出 RuntimeWarning 警告,匹配警告信息为 match
            with pytest.warns(RuntimeWarning, match=match):
                # 使用 func 函数计算 array 在指定 axis 上的结果 out
                out = func(array, axis=axis)
            # 断言 out 中所有的值都为 NaN
            assert np.isnan(out).all()
            # 断言 out 的 dtype 类型与 array 的 dtype 类型相同
            assert out.dtype == array.dtype

    # 测试处理带掩码的数组的方法
    def test_masked(self):
        # 创建一个包含无效值修正的 _ndat 的掩码数组 mat
        mat = np.ma.fix_invalid(_ndat)
        # 复制 mat 的掩码到 msk
        msk = mat._mask.copy()
        # 对于函数 f 中的每个函数 f,仅使用 np.nanmin
        for f in [np.nanmin]:
            # 使用 f 函数计算 mat 在 axis=1 上的结果 res
            res = f(mat, axis=1)
            # 使用 f 函数计算 _ndat 在 axis=1 上的结果 tgt
            tgt = f(_ndat, axis=1)
            # 断言 res 等于 tgt
            assert_equal(res, tgt)
            # 断言 mat 的掩码与 msk 相同
            assert_equal(mat._mask, msk)
            # 断言 mat 中不包含任何无穷值
            assert_(not np.isinf(mat).any())

    # 测试处理标量输入的方法
    def test_scalar(self):
        # 对 nanfuncs 中的每个函数 f 进行迭代
        for f in self.nanfuncs:
            # 断言 f 函数在输入为标量 0.0 时的结果为 0.0
            assert_(f(0.) == 0.)
    def test_subclass(self):
        # 定义一个自定义的 ndarray 子类 MyNDArray
        class MyNDArray(np.ndarray):
            pass

        # 创建一个 3x3 的单位矩阵,并将其视图转换为 MyNDArray 类型
        mine = np.eye(3).view(MyNDArray)

        # 对每个在 self.nanfuncs 中的函数进行测试
        for f in self.nanfuncs:
            # 测试沿 axis=0 方向的函数调用结果
            res = f(mine, axis=0)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))

            # 测试沿 axis=1 方向的函数调用结果
            res = f(mine, axis=1)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))

            # 测试没有指定 axis 的函数调用结果
            res = f(mine)
            assert_(res.shape == ())

        # 对包含 NaN 的行进行处理的测试 (#4628)
        mine[1] = np.nan
        for f in self.nanfuncs:
            # 捕获可能的警告信息
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')

                # 测试沿 axis=0 方向处理 NaN 行的结果
                res = f(mine, axis=0)
                assert_(isinstance(res, MyNDArray))
                assert_(not np.any(np.isnan(res)))
                assert_(len(w) == 0)

            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')

                # 测试沿 axis=1 方向处理 NaN 行的结果
                res = f(mine, axis=1)
                assert_(isinstance(res, MyNDArray))
                assert_(np.isnan(res[1]) and not np.isnan(res[0])
                        and not np.isnan(res[2]))
                assert_(len(w) == 1, 'no warning raised')
                assert_(issubclass(w[0].category, RuntimeWarning))

            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')

                # 测试没有指定 axis 处理 NaN 的结果
                res = f(mine)
                assert_(res.shape == ())
                assert_(res != np.nan)
                assert_(len(w) == 0)

    def test_object_array(self):
        # 创建一个包含 NaN 的对象数组
        arr = np.array([[1.0, 2.0], [np.nan, 4.0], [np.nan, np.nan]], dtype=object)

        # 测试 np.nanmin 在对象数组上的表现
        assert_equal(np.nanmin(arr), 1.0)
        assert_equal(np.nanmin(arr, axis=0), [1.0, 2.0])

        # 测试对对象数组使用 np.nanmin 时的警告情况
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter('always')

            # 对比 np.nanmin 在 axis=1 上的结果
            # 注意:assert_equal 在处理对象数组的 NaN 时不适用
            assert_equal(list(np.nanmin(arr, axis=1)), [1.0, 4.0, np.nan])
            assert_(len(w) == 1, 'no warning raised')
            assert_(issubclass(w[0].category, RuntimeWarning))

    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    def test_initial(self, dtype):
        # 定义一个自定义的 ndarray 子类 MyNDArray
        class MyNDArray(np.ndarray):
            pass

        # 创建一个浮点类型的数组,并将前五个元素设为 NaN
        ar = np.arange(9).astype(dtype)
        ar[:5] = np.nan

        for f in self.nanfuncs:
            initial = 100 if f is np.nanmax else 0

            # 测试带有 initial 参数的函数调用结果
            ret1 = f(ar, initial=initial)
            assert ret1.dtype == dtype
            assert ret1 == initial

            # 测试对 MyNDArray 类型的视图进行函数调用的结果
            ret2 = f(ar.view(MyNDArray), initial=initial)
            assert ret2.dtype == dtype
            assert ret2 == initial
    # 定义一个测试方法,用于测试特定数据类型的函数
    def test_where(self, dtype):
        # 定义一个继承自 numpy.ndarray 的子类 MyNDArray
        class MyNDArray(np.ndarray):
            pass

        # 创建一个3x3的数组,元素为0到8,转换为指定的数据类型并设置第一行为 NaN
        ar = np.arange(9).reshape(3, 3).astype(dtype)
        ar[0, :] = np.nan

        # 创建一个与 ar 形状相同的全为 True 的布尔数组 where,并将第一列设为 False
        where = np.ones_like(ar, dtype=np.bool)
        where[:, 0] = False

        # 遍历 nanfuncs 中的每一个函数 f
        for f in self.nanfuncs:
            # 如果 f 是 np.nanmin,则 reference 为 4;否则为 8
            reference = 4 if f is np.nanmin else 8

            # 使用函数 f 计算 ar 数组中符合条件的最小值或最大值,初始值为 5
            ret1 = f(ar, where=where, initial=5)
            # 断言返回值的数据类型与指定的 dtype 相同
            assert ret1.dtype == dtype
            # 断言返回值等于预期的 reference 值
            assert ret1 == reference

            # 使用函数 f 计算 ar 数组(视图)中符合条件的最小值或最大值,初始值为 5
            ret2 = f(ar.view(MyNDArray), where=where, initial=5)
            # 断言返回值的数据类型与指定的 dtype 相同
            assert ret2.dtype == dtype
            # 断言返回值等于预期的 reference 值
            assert ret2 == reference
class TestNanFunctions_ArgminArgmax:
    # 定义一个测试类,用于测试 np.nanargmin 和 np.nanargmax 函数
    nanfuncs = [np.nanargmin, np.nanargmax]

    def test_mutation(self):
        # 检查传入的数组不会被修改
        ndat = _ndat.copy()
        for f in self.nanfuncs:
            f(ndat)
            assert_equal(ndat, _ndat)

    def test_result_values(self):
        for f, fcmp in zip(self.nanfuncs, [np.greater, np.less]):
            for row in _ndat:
                with suppress_warnings() as sup:
                    sup.filter(RuntimeWarning, "invalid value encountered in")
                    ind = f(row)
                    val = row[ind]
                    # 比较 NaN 可能有些棘手,因为结果总是 False,除了 NaN != NaN
                    assert_(not np.isnan(val))
                    assert_(not fcmp(val, row).any())
                    assert_(not np.equal(val, row[:ind]).any())

    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),
        np.full((3, 3), np.nan),
    ], ids=["0d", "2d"])
    def test_allnans(self, axis, dtype, array):
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        array = array.astype(dtype)
        for func in self.nanfuncs:
            with pytest.raises(ValueError, match="All-NaN slice encountered"):
                func(array, axis=axis)

    def test_empty(self):
        mat = np.zeros((0, 3))
        for f in self.nanfuncs:
            for axis in [0, None]:
                assert_raises_regex(
                        ValueError,
                        "attempt to get argm.. of an empty sequence",
                        f, mat, axis=axis)
            for axis in [1]:
                res = f(mat, axis=axis)
                assert_equal(res, np.zeros(0))

    def test_scalar(self):
        for f in self.nanfuncs:
            assert_(f(0.) == 0.)

    def test_subclass(self):
        class MyNDArray(np.ndarray):
            pass

        # 检查函数能正常工作,并且类型和形状得到保留
        mine = np.eye(3).view(MyNDArray)
        for f in self.nanfuncs:
            res = f(mine, axis=0)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))
            res = f(mine, axis=1)
            assert_(isinstance(res, MyNDArray))
            assert_(res.shape == (3,))
            res = f(mine)
            assert_(res.shape == ())

    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    def test_keepdims(self, dtype):
        ar = np.arange(9).astype(dtype)
        ar[:5] = np.nan

        for f in self.nanfuncs:
            reference = 5 if f is np.nanargmin else 8
            ret = f(ar, keepdims=True)
            assert ret.ndim == ar.ndim
            assert ret == reference

    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    # 定义一个测试方法,用于测试NaN处理函数的行为
    def test_out(self, dtype):
        # 创建一个包含0到8的数组,并将其转换为指定数据类型(dtype)
        ar = np.arange(9).astype(dtype)
        # 将数组的前5个元素设置为NaN
        ar[:5] = np.nan
    
        # 遍历NaN处理函数列表
        for f in self.nanfuncs:
            # 创建一个dtype为np.intp的零维数组out,用于接收函数的输出
            out = np.zeros((), dtype=np.intp)
            # 根据函数类型设置参考值,如果是np.nanargmin,则参考值为5,否则为8
            reference = 5 if f is np.nanargmin else 8
            # 调用NaN处理函数f,将ar作为输入,将结果存入out
            ret = f(ar, out=out)
            # 断言返回值ret与out相同
            assert ret is out
            # 断言返回值ret与参考值reference相同
            assert ret == reference
# 定义测试用例中使用的示例数组集合
_TEST_ARRAYS = {
    "0d": np.array(5),                             # 创建一个0维的NumPy数组,包含单个整数5
    "1d": np.array([127, 39, 93, 87, 46])           # 创建一个1维的NumPy数组,包含多个整数
}

# 设置所有数组为不可写以确保测试不会修改它们
for _v in _TEST_ARRAYS.values():
    _v.setflags(write=False)

# 使用pytest的参数化标记定义多个测试参数
@pytest.mark.parametrize(
    "dtype",                                         # 参数名为dtype,用于测试不同的数据类型
    np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "O",  # 测试所有整数、浮点数和对象类型
)
@pytest.mark.parametrize(
    "mat", _TEST_ARRAYS.values(),                    # 参数名为mat,用于测试_TEST_ARRAYS中的所有数组
    ids=_TEST_ARRAYS.keys()                          # 用_TEST_ARRAYS中的键作为数组的标识符
)
class TestNanFunctions_NumberTypes:
    # 定义NaN相关函数与其对应的标准函数的映射关系
    nanfuncs = {
        np.nanmin: np.min,                           # NaN最小值与最小值函数的映射关系
        np.nanmax: np.max,                           # NaN最大值与最大值函数的映射关系
        np.nanargmin: np.argmin,                     # NaN最小值位置与最小值位置函数的映射关系
        np.nanargmax: np.argmax,                     # NaN最大值位置与最大值位置函数的映射关系
        np.nansum: np.sum,                           # NaN求和与求和函数的映射关系
        np.nanprod: np.prod,                         # NaN累积乘积与累积乘积函数的映射关系
        np.nancumsum: np.cumsum,                     # NaN累积求和与累积求和函数的映射关系
        np.nancumprod: np.cumprod,                   # NaN累积乘积与累积乘积函数的映射关系
        np.nanmean: np.mean,                         # NaN均值与均值函数的映射关系
        np.nanmedian: np.median,                     # NaN中位数与中位数函数的映射关系
        np.nanvar: np.var,                           # NaN方差与方差函数的映射关系
        np.nanstd: np.std                            # NaN标准差与标准差函数的映射关系
    }
    nanfunc_ids = [i.__name__ for i in nanfuncs]      # 提取函数名用于测试标识符的参数化

    # 使用参数化标记定义测试函数,测试NaN函数与其对应的标准函数
    @pytest.mark.parametrize("nanfunc,func", nanfuncs.items(), ids=nanfunc_ids)
    @np.errstate(over="ignore")
    def test_nanfunc(self, mat, dtype, nanfunc, func):
        mat = mat.astype(dtype)                       # 将mat数组转换为指定的数据类型
        tgt = func(mat)                               # 计算标准函数的结果
        out = nanfunc(mat)                            # 计算NaN函数的结果

        assert_almost_equal(out, tgt)                 # 断言NaN函数的结果与标准函数的结果几乎相等
        if dtype == "O":
            assert type(out) is type(tgt)             # 如果数据类型为对象类型,断言NaN函数与标准函数的结果类型相同
        else:
            assert out.dtype == tgt.dtype             # 否则,断言NaN函数的结果与标准函数的结果的数据类型相同

    # 使用参数化标记定义测试函数,测试NaN分位数和百分位数函数
    @pytest.mark.parametrize(
        "nanfunc,func",
        [(np.nanquantile, np.quantile), (np.nanpercentile, np.percentile)],
        ids=["nanquantile", "nanpercentile"],
    )
    def test_nanfunc_q(self, mat, dtype, nanfunc, func):
        mat = mat.astype(dtype)                       # 将mat数组转换为指定的数据类型
        if mat.dtype.kind == "c":
            assert_raises(TypeError, func, mat, q=1)  # 复数数组不支持分位数和百分位数计算,断言引发TypeError
            assert_raises(TypeError, nanfunc, mat, q=1)

        else:
            tgt = func(mat, q=1)                      # 计算标准分位数或百分位数的结果
            out = nanfunc(mat, q=1)                   # 计算NaN分位数或百分位数的结果

            assert_almost_equal(out, tgt)             # 断言NaN函数的结果与标准函数的结果几乎相等

            if dtype == "O":
                assert type(out) is type(tgt)         # 如果数据类型为对象类型,断言NaN函数与标准函数的结果类型相同
            else:
                assert out.dtype == tgt.dtype         # 否则,断言NaN函数的结果与标准函数的结果的数据类型相同

    # 使用参数化标记定义测试函数,测试NaN方差和标准差函数的ddof参数
    @pytest.mark.parametrize(
        "nanfunc,func",
        [(np.nanvar, np.var), (np.nanstd, np.std)],
        ids=["nanvar", "nanstd"],
    )
    def test_nanfunc_ddof(self, mat, dtype, nanfunc, func):
        mat = mat.astype(dtype)                       # 将mat数组转换为指定的数据类型
        tgt = func(mat, ddof=0.5)                     # 计算标准函数的结果,使用ddof参数为0.5
        out = nanfunc(mat, ddof=0.5)                  # 计算NaN函数的结果,使用ddof参数为0.5

        assert_almost_equal(out, tgt)                 # 断言NaN函数的结果与标准函数的结果几乎相等
        if dtype == "O":
            assert type(out) is type(tgt)             # 如果数据类型为对象类型,断言NaN函数与标准函数的结果类型相同
        else:
            assert out.dtype == tgt.dtype             # 否则,断言NaN函数的结果与标准函数的结果的数据类型相同

    # 使用参数化标记定义测试函数,测试NaN方差和标准差函数的correction参数
    @pytest.mark.parametrize(
        "nanfunc", [np.nanvar, np.nanstd]
    )
    def test_nanfunc_correction(self, mat, dtype, nanfunc):
        mat = mat.astype(dtype)                       # 将mat数组转换为指定的数据类型
        assert_almost_equal(
            nanfunc(mat, correction=0.5),             # 断言使用correction参数0.5计算的NaN函数结果与使用ddof参数0.5计算的结果几乎相等
            nanfunc(mat, ddof=0.5)
        )

        err_msg = "ddof and correction can't be provided simultaneously."
        with assert_raises_regex(ValueError, err_msg):
            nanfunc(mat, ddof=0.5, correction=0.5)    # 断言当同时提供ddof和correction参数时,会引发ValueError异常

        with assert_raises_regex(ValueError, err_msg):
            nanfunc(mat, ddof=1, correction=0)        # 断言当提供ddof参数为1和correction参数时,会引发ValueError异常
    def test_mutation(self):
        # 检查传入的数组未被修改
        ndat = _ndat.copy()  # 复制 _ndat 数组以防止修改原始数据
        for f in self.nanfuncs:
            f(ndat)  # 调用函数 f 对 ndat 进行操作
            assert_equal(ndat, _ndat)  # 断言 ndat 与 _ndat 相等,验证未修改原始数据

    def test_keepdims(self):
        mat = np.eye(3)  # 创建一个 3x3 的单位矩阵
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for axis in [None, 0, 1]:
                tgt = rf(mat, axis=axis, keepdims=True)  # 调用 rf 函数计算结果并保持维度
                res = nf(mat, axis=axis, keepdims=True)  # 调用 nf 函数计算结果并保持维度
                assert_(res.ndim == tgt.ndim)  # 断言 res 和 tgt 的维度相同

    def test_out(self):
        mat = np.eye(3)  # 创建一个 3x3 的单位矩阵
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            resout = np.zeros(3)  # 创建一个长度为 3 的零向量
            tgt = rf(mat, axis=1)  # 调用 rf 函数计算结果
            res = nf(mat, axis=1, out=resout)  # 调用 nf 函数计算结果,并将结果存储到 resout 中
            assert_almost_equal(res, resout)  # 断言 nf 计算结果与 resout 几乎相等
            assert_almost_equal(res, tgt)  # 断言 nf 计算结果与 rf 计算结果几乎相等

    def test_dtype_from_dtype(self):
        mat = np.eye(3)  # 创建一个 3x3 的单位矩阵
        codes = 'efdgFDG'
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for c in codes:
                with suppress_warnings() as sup:
                    if nf in {np.nanstd, np.nanvar} and c in 'FDG':
                        sup.filter(ComplexWarning)  # 过滤掉复杂类型警告
                    tgt = rf(mat, dtype=np.dtype(c), axis=1).dtype.type  # 指定数据类型进行计算并获取类型
                    res = nf(mat, dtype=np.dtype(c), axis=1).dtype.type  # 使用 nf 函数计算相同数据类型的结果类型
                    assert_(res is tgt)  # 断言 nf 和 rf 的结果类型相同
                    # scalar case
                    tgt = rf(mat, dtype=np.dtype(c), axis=None).dtype.type  # 沿单个轴进行计算并获取类型
                    res = nf(mat, dtype=np.dtype(c), axis=None).dtype.type  # 使用 nf 函数计算相同数据类型的结果类型
                    assert_(res is tgt)  # 断言 nf 和 rf 的结果类型相同

    def test_dtype_from_char(self):
        mat = np.eye(3)  # 创建一个 3x3 的单位矩阵
        codes = 'efdgFDG'
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for c in codes:
                with suppress_warnings() as sup:
                    if nf in {np.nanstd, np.nanvar} and c in 'FDG':
                        sup.filter(ComplexWarning)  # 过滤掉复杂类型警告
                    tgt = rf(mat, dtype=c, axis=1).dtype.type  # 使用字符指定数据类型进行计算并获取类型
                    res = nf(mat, dtype=c, axis=1).dtype.type  # 使用 nf 函数计算相同数据类型的结果类型
                    assert_(res is tgt)  # 断言 nf 和 rf 的结果类型相同
                    # scalar case
                    tgt = rf(mat, dtype=c, axis=None).dtype.type  # 沿单个轴进行计算并获取类型
                    res = nf(mat, dtype=c, axis=None).dtype.type  # 使用 nf 函数计算相同数据类型的结果类型
                    assert_(res is tgt)  # 断言 nf 和 rf 的结果类型相同

    def test_dtype_from_input(self):
        codes = 'efdgFDG'
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            for c in codes:
                mat = np.eye(3, dtype=c)  # 创建指定数据类型的 3x3 单位矩阵
                tgt = rf(mat, axis=1).dtype.type  # 指定轴进行计算并获取结果类型
                res = nf(mat, axis=1).dtype.type  # 使用 nf 函数计算相同轴的结果类型
                assert_(res is tgt, "res %s, tgt %s" % (res, tgt))  # 断言 nf 和 rf 的结果类型相同,否则输出详细信息
                # scalar case
                tgt = rf(mat, axis=None).dtype.type  # 沿单个轴进行计算并获取类型
                res = nf(mat, axis=None).dtype.type  # 使用 nf 函数计算相同轴的结果类型
                assert_(res is tgt)  # 断言 nf 和 rf 的结果类型相同
    # 定义测试方法以验证结果值
    def test_result_values(self):
        # 对于每对自定义函数和标准函数,分别进行测试
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 创建目标结果列表,其中包含对标准函数在_rdat数据上的应用结果
            tgt = [rf(d) for d in _rdat]
            # 对自定义函数在_ndat上执行轴向为1的操作,得到结果
            res = nf(_ndat, axis=1)
            # 断言结果近似等于目标结果列表
            assert_almost_equal(res, tgt)

    # 定义测试标量情况的方法
    def test_scalar(self):
        # 对于每个自定义函数,测试其对标量0.的返回值是否为0.
        for f in self.nanfuncs:
            assert_(f(0.) == 0.)

    # 定义测试子类情况的方法
    def test_subclass(self):
        # 定义一个继承自np.ndarray的子类MyNDArray
        class MyNDArray(np.ndarray):
            pass

        # 创建一个3x3的单位矩阵
        array = np.eye(3)
        # 将array视图转换为MyNDArray类型的对象mine
        mine = array.view(MyNDArray)

        # 对于每个自定义函数,验证其在不同轴上操作后返回的类型和形状与预期一致
        for f in self.nanfuncs:
            # 预期轴为0时的形状
            expected_shape = f(array, axis=0).shape
            # 对mine在轴为0上执行函数f,得到结果res
            res = f(mine, axis=0)
            # 断言res的类型为MyNDArray
            assert_(isinstance(res, MyNDArray))
            # 断言res的形状与预期一致
            assert_(res.shape == expected_shape)

            # 预期轴为1时的形状
            expected_shape = f(array, axis=1).shape
            # 对mine在轴为1上执行函数f,得到结果res
            res = f(mine, axis=1)
            # 断言res的类型为MyNDArray
            assert_(isinstance(res, MyNDArray))
            # 断言res的形状与预期一致
            assert_(res.shape == expected_shape)

            # 对于不指定轴的情况,验证返回结果的形状
            expected_shape = f(array).shape
            # 对mine执行函数f,得到结果res
            res = f(mine)
            # 断言res的类型为MyNDArray
            assert_(isinstance(res, MyNDArray))
            # 断言res的形状与预期一致
            assert_(res.shape == expected_shape)
# 定义一个测试类 TestNanFunctions_SumProd,继承自 SharedNanFunctionsTestsMixin
class TestNanFunctions_SumProd(SharedNanFunctionsTestsMixin):

    # nanfuncs 列表包含 np.nansum 和 np.nanprod 函数
    nanfuncs = [np.nansum, np.nanprod]
    # stdfuncs 列表包含 np.sum 和 np.prod 函数
    stdfuncs = [np.sum, np.prod]

    # 使用 pytest.mark.parametrize 标记的参数化测试方法
    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),  # 创建一个包含单个 NaN 值的数组
        np.full((3, 3), np.nan),  # 创建一个全部元素为 NaN 的 3x3 数组
    ], ids=["0d", "2d"])  # 分别用 "0d" 和 "2d" 标识两个测试用例
    def test_allnans(self, axis, dtype, array):
        # 如果 axis 不为 None 且 array 的维度为 0,则跳过测试,并显示相应的提示信息
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        # 将 array 转换为指定的 dtype 类型
        array = array.astype(dtype)
        # 对于 nanfuncs 列表中的每个函数 func 和对应的 identity 值
        for func, identity in zip(self.nanfuncs, [0, 1]):
            # 调用 func 函数计算 array 在指定 axis 上的操作结果
            out = func(array, axis=axis)
            # 断言结果 out 中所有的元素等于预期的 identity 值
            assert np.all(out == identity)
            # 断言结果 out 的数据类型与 array 的数据类型相同
            assert out.dtype == array.dtype

    # 定义测试空数组情况的方法
    def test_empty(self):
        # 对于 nanfuncs 列表中的每个函数 f 和其对应的目标值 tgt_value
        for f, tgt_value in zip([np.nansum, np.nanprod], [0, 1]):
            # 创建一个形状为 (0, 3) 的全零数组 mat
            mat = np.zeros((0, 3))
            # 设置目标值 tgt 为长度为 3 的列表,其元素均为 tgt_value
            tgt = [tgt_value]*3
            # 调用函数 f 计算 mat 在 axis=0 上的结果 res
            res = f(mat, axis=0)
            # 断言 res 与目标值 tgt 相等
            assert_equal(res, tgt)
            # 设置目标值 tgt 为空列表
            tgt = []
            # 调用函数 f 计算 mat 在 axis=1 上的结果 res
            res = f(mat, axis=1)
            # 断言 res 与目标值 tgt 相等
            assert_equal(res, tgt)
            # 设置目标值 tgt 为单一的 tgt_value
            tgt = tgt_value
            # 调用函数 f 计算 mat 在 axis=None 上的结果 res
            res = f(mat, axis=None)
            # 断言 res 等于目标值 tgt
            assert_equal(res, tgt)

    # 使用 pytest.mark.parametrize 标记的参数化测试方法
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    def test_initial(self, dtype):
        # 创建一个长度为 9 的数组 ar,并转换为指定的 dtype 类型
        ar = np.arange(9).astype(dtype)
        # 将数组 ar 中的前 5 个元素设为 NaN
        ar[:5] = np.nan

        # 对于 nanfuncs 列表中的每个函数 f
        for f in self.nanfuncs:
            # 设置参考值 reference,根据 f 是 np.nansum 还是 np.nanprod 不同而不同
            reference = 28 if f is np.nansum else 3360
            # 调用函数 f 计算 ar 的结果 ret,并指定 initial 参数为 2
            ret = f(ar, initial=2)
            # 断言 ret 的数据类型与 dtype 相同
            assert ret.dtype == dtype
            # 断言 ret 等于参考值 reference
            assert ret == reference

    # 使用 pytest.mark.parametrize 标记的参数化测试方法
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    def test_where(self, dtype):
        # 创建一个形状为 (3, 3) 的数组 ar,并转换为指定的 dtype 类型
        ar = np.arange(9).reshape(3, 3).astype(dtype)
        # 将数组 ar 的第一行所有元素设为 NaN
        ar[0, :] = np.nan
        # 创建一个与 ar 相同形状的布尔数组 where,并将所有元素初始化为 True
        where = np.ones_like(ar, dtype=np.bool)
        # 将 where 的第一列所有元素设为 False
        where[:, 0] = False

        # 对于 nanfuncs 列表中的每个函数 f
        for f in self.nanfuncs:
            # 设置参考值 reference,根据 f 是 np.nansum 还是 np.nanprod 不同而不同
            reference = 26 if f is np.nansum else 2240
            # 调用函数 f 计算 ar 在给定 where 和 initial=2 的条件下的结果 ret
            ret = f(ar, where=where, initial=2)
            # 断言 ret 的数据类型与 dtype 相同
            assert ret.dtype == dtype
            # 断言 ret 等于参考值 reference
            assert ret == reference


# 定义一个测试类 TestNanFunctions_CumSumProd,继承自 SharedNanFunctionsTestsMixin
class TestNanFunctions_CumSumProd(SharedNanFunctionsTestsMixin):

    # nanfuncs 列表包含 np.nancumsum 和 np.nancumprod 函数
    nanfuncs = [np.nancumsum, np.nancumprod]
    # stdfuncs 列表包含 np.cumsum 和 np.cumprod 函数
    stdfuncs = [np.cumsum, np.cumprod]

    # 使用 pytest.mark.parametrize 标记的参数化测试方法
    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),  # 创建一个包含单个 NaN 值的数组
        np.full((3, 3), np.nan)  # 创建一个全部元素为 NaN 的 3x3 数组
    ], ids=["0d", "2d"])  # 分别用 "0d" 和 "2d" 标识两个测试用例
    def test_allnans(self, axis, dtype, array):
        # 如果 axis 不为 None 且 array 的维度为 0,则跳过测试,并显示相应的提示信息
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        # 将 array 转换为指定的 dtype 类型
        array = array.astype(dtype)
        # 对于 nanfuncs 列表中的每个函数 func 和对应的 identity 值
        for func, identity in zip(self.nanfuncs, [0, 1]):
            # 调用 func 函数计算 array 上的累积操作结果 out
            out = func(array)
            # 断言结果 out 中所有的元素等于预期的 identity 值
            assert np.all(out == identity)
            # 断言结果 out 的数据类型与 array 的数据类型相同
            assert out.dtype == array.dtype
    # 测试空矩阵情况下各函数的行为
    def test_empty(self):
        # 遍历 nanfuncs 和对应的目标值列表
        for f, tgt_value in zip(self.nanfuncs, [0, 1]):
            # 创建一个空的 0x3 的矩阵
            mat = np.zeros((0, 3))
            # 创建一个与 mat 相同形状的矩阵,填充为目标值的倍数
            tgt = tgt_value * np.ones((0, 3))
            # 使用函数 f 计算 mat 沿 axis=0 的结果
            res = f(mat, axis=0)
            # 断言 res 与目标值 tgt 相等
            assert_equal(res, tgt)
            # 将目标值设为 mat 自身
            tgt = mat
            # 使用函数 f 计算 mat 沿 axis=1 的结果
            res = f(mat, axis=1)
            # 断言 res 与目标值 tgt 相等
            assert_equal(res, tgt)
            # 创建一个空的 0 维数组
            tgt = np.zeros((0))
            # 使用函数 f 计算 mat 沿 axis=None 的结果
            res = f(mat, axis=None)
            # 断言 res 与目标值 tgt 相等
            assert_equal(res, tgt)

    # 测试 keepdims 参数对函数行为的影响
    def test_keepdims(self):
        # 遍历 nanfuncs 和 stdfuncs
        for f, g in zip(self.nanfuncs, self.stdfuncs):
            # 创建一个 3x3 的单位矩阵
            mat = np.eye(3)
            # 遍历 axis 参数的可能取值
            for axis in [None, 0, 1]:
                # 使用函数 f 计算 mat 的结果,并且不指定输出
                tgt = f(mat, axis=axis, out=None)
                # 使用函数 g 计算 mat 的结果,并且不指定输出
                res = g(mat, axis=axis, out=None)
                # 断言 res 和 tgt 的维度相等
                assert_(res.ndim == tgt.ndim)

        # 再次遍历 nanfuncs
        for f in self.nanfuncs:
            # 创建一个形状为 (3, 5, 7, 11) 的全为 1 的数组
            d = np.ones((3, 5, 7, 11))
            # 随机将一些元素设为 NaN
            rs = np.random.RandomState(0)
            d[rs.rand(*d.shape) < 0.5] = np.nan
            # 使用函数 f 计算数组 d 沿 axis=None 的结果
            res = f(d, axis=None)
            # 断言 res 的形状为 (1155,)
            assert_equal(res.shape, (1155,))
            # 遍历 axis 的所有可能取值
            for axis in np.arange(4):
                # 使用函数 f 计算数组 d 沿指定 axis 的结果
                res = f(d, axis=axis)
                # 断言 res 的形状与数组 d 相同
                assert_equal(res.shape, (3, 5, 7, 11))

    # 测试结果值是否正确的断言
    def test_result_values(self):
        # 遍历 axis 的多个可能取值
        for axis in (-2, -1, 0, 1, None):
            # 计算 _ndat_ones 沿指定 axis 的累积乘积
            tgt = np.cumprod(_ndat_ones, axis=axis)
            # 计算 _ndat 沿指定 axis 的累积乘积,跳过 NaN 值
            res = np.nancumprod(_ndat, axis=axis)
            # 断言 res 和 tgt 的近似相等
            assert_almost_equal(res, tgt)
            # 计算 _ndat_zeros 沿指定 axis 的累积和
            tgt = np.cumsum(_ndat_zeros, axis=axis)
            # 计算 _ndat 沿指定 axis 的累积和,跳过 NaN 值
            res = np.nancumsum(_ndat, axis=axis)
            # 断言 res 和 tgt 的近似相等
            assert_almost_equal(res, tgt)

    # 测试输出参数 out 对函数行为的影响
    def test_out(self):
        # 创建一个 3x3 的单位矩阵
        mat = np.eye(3)
        # 遍历 nanfuncs 和 stdfuncs
        for nf, rf in zip(self.nanfuncs, self.stdfuncs):
            # 创建一个与 mat 相同形状的单位矩阵作为输出容器
            resout = np.eye(3)
            # 遍历 axis 的多个可能取值
            for axis in (-2, -1, 0, 1):
                # 使用函数 rf 计算 mat 沿指定 axis 的结果
                tgt = rf(mat, axis=axis)
                # 使用函数 nf 计算 mat 沿指定 axis 的结果,并将结果写入 resout
                res = nf(mat, axis=axis, out=resout)
                # 断言 res 与 resout 的近似相等
                assert_almost_equal(res, resout)
                # 断言 res 与 tgt 的近似相等
                assert_almost_equal(res, tgt)
class TestNanFunctions_MeanVarStd(SharedNanFunctionsTestsMixin):
    # 继承自 SharedNanFunctionsTestsMixin 的测试类,用于测试 NaN 相关函数的行为

    nanfuncs = [np.nanmean, np.nanvar, np.nanstd]
    # 包含 NaN 函数的列表:nanmean, nanvar, nanstd

    stdfuncs = [np.mean, np.var, np.std]
    # 标准函数的列表:mean, var, std

    def test_dtype_error(self):
        # 测试数据类型错误的情况
        for f in self.nanfuncs:
            for dtype in [np.bool, np.int_, np.object_]:
                # 对于每个 NaN 函数和指定的数据类型
                assert_raises(TypeError, f, _ndat, axis=1, dtype=dtype)

    def test_out_dtype_error(self):
        # 测试输出数据类型错误的情况
        for f in self.nanfuncs:
            for dtype in [np.bool, np.int_, np.object_]:
                # 对于每个 NaN 函数和指定的数据类型
                out = np.empty(_ndat.shape[0], dtype=dtype)
                assert_raises(TypeError, f, _ndat, axis=1, out=out)

    def test_ddof(self):
        # 测试自由度参数 ddof 的影响
        nanfuncs = [np.nanvar, np.nanstd]
        stdfuncs = [np.var, np.std]
        for nf, rf in zip(nanfuncs, stdfuncs):
            for ddof in [0, 1]:
                # 对于每个 NaN 方差和标准差函数,以及不同的 ddof 值
                tgt = [rf(d, ddof=ddof) for d in _rdat]
                res = nf(_ndat, axis=1, ddof=ddof)
                assert_almost_equal(res, tgt)

    def test_ddof_too_big(self):
        # 测试 ddof 参数过大的情况
        nanfuncs = [np.nanvar, np.nanstd]
        stdfuncs = [np.var, np.std]
        dsize = [len(d) for d in _rdat]
        for nf, rf in zip(nanfuncs, stdfuncs):
            for ddof in range(5):
                # 对于每个 NaN 方差和标准差函数,以及不同的 ddof 值
                with suppress_warnings() as sup:
                    sup.record(RuntimeWarning)
                    sup.filter(ComplexWarning)
                    tgt = [ddof >= d for d in dsize]
                    res = nf(_ndat, axis=1, ddof=ddof)
                    assert_equal(np.isnan(res), tgt)
                    if any(tgt):
                        assert_(len(sup.log) == 1)
                    else:
                        assert_(len(sup.log) == 0)

    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["AllFloat"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),
        np.full((3, 3), np.nan),
    ], ids=["0d", "2d"])
    def test_allnans(self, axis, dtype, array):
        # 测试所有元素为 NaN 的情况
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        array = array.astype(dtype)
        match = "(Degrees of freedom <= 0 for slice.)|(Mean of empty slice)"
        for func in self.nanfuncs:
            with pytest.warns(RuntimeWarning, match=match):
                out = func(array, axis=axis)
            assert np.isnan(out).all()

            # `nanvar` and `nanstd` convert complex inputs to their
            # corresponding floating dtype
            if func is np.nanmean:
                assert out.dtype == array.dtype
            else:
                assert out.dtype == np.abs(array).dtype
    # 定义一个测试方法,测试处理空数组的情况
    def test_empty(self):
        # 创建一个形状为 (0, 3) 的全零数组
        mat = np.zeros((0, 3))
        # 遍历 nanfuncs 列表中的函数
        for f in self.nanfuncs:
            # 对于 axis 参数为 [0, None] 的情况
            for axis in [0, None]:
                # 使用 warnings.catch_warnings 捕获警告信息
                with warnings.catch_warnings(record=True) as w:
                    # 设置警告过滤器,捕获所有警告
                    warnings.simplefilter('always')
                    # 断言调用 f 函数处理 mat 时所有结果都是 NaN
                    assert_(np.isnan(f(mat, axis=axis)).all())
                    # 断言警告列表 w 的长度为 1
                    assert_(len(w) == 1)
                    # 断言第一个警告是 RuntimeWarning 的子类
                    assert_(issubclass(w[0].category, RuntimeWarning))
            # 对于 axis 参数为 1 的情况
            for axis in [1]:
                # 使用 warnings.catch_warnings 捕获警告信息
                with warnings.catch_warnings(record=True) as w:
                    # 设置警告过滤器,捕获所有警告
                    warnings.simplefilter('always')
                    # 断言调用 f 函数处理 mat 时结果与形状为 [] 的全零数组相等
                    assert_equal(f(mat, axis=axis), np.zeros([]))
                    # 断言警告列表 w 的长度为 0
                    assert_(len(w) == 0)

    # 使用 pytest.mark.parametrize 装饰器,参数为 np.typecodes["AllFloat"] 中的数据类型
    def test_where(self, dtype):
        # 创建一个形状为 (3, 3) 的数组 ar,转换为指定数据类型 dtype
        ar = np.arange(9).reshape(3, 3).astype(dtype)
        # 将第一行设置为 NaN
        ar[0, :] = np.nan
        # 创建一个与 ar 相同形状的布尔数组 where,并设置第一列为 False
        where = np.ones_like(ar, dtype=np.bool)
        where[:, 0] = False

        # 遍历 nanfuncs 和 stdfuncs 列表中的函数
        for f, f_std in zip(self.nanfuncs, self.stdfuncs):
            # 使用 where 数组的条件,对 ar 的第三行及之后的数据应用 f_std 函数计算参考值
            reference = f_std(ar[where][2:])
            # 如果 f 是 np.nanmean,则使用指定数据类型 dtype 作为参考值的数据类型
            dtype_reference = dtype if f is np.nanmean else ar.real.dtype

            # 调用 f 函数处理 ar 和 where 数组
            ret = f(ar, where=where)
            # 断言返回结果的数据类型与 dtype_reference 相同
            assert ret.dtype == dtype_reference
            # 使用 np.testing.assert_allclose 断言返回结果与 reference 的接近程度

    # 定义一个测试方法,测试带有 mean 关键字参数的 np.nanstd 函数
    def test_nanstd_with_mean_keyword(self):
        # 设置随机种子以保证测试的可复现性
        rng = np.random.RandomState(1234)
        # 创建一个形状为 (10, 20, 5) 的随机数组 A,并添加 NaN 值
        A = rng.randn(10, 20, 5) + 0.5
        A[:, 5, :] = np.nan

        # 创建形状为 (10, 1, 5) 的全零数组 mean_out 和 std_out
        mean_out = np.zeros((10, 1, 5))
        std_out = np.zeros((10, 1, 5))

        # 使用 np.nanmean 计算 A 的均值,输出到 mean_out,沿 axis=1,保持维度为 True
        mean = np.nanmean(A,
                          out=mean_out,
                          axis=1,
                          keepdims=True)

        # 断言 mean_out 与 mean 是同一个对象
        assert mean_out is mean

        # 使用 np.nanstd 计算 A 的标准差,输出到 std_out,沿 axis=1,保持维度为 True,指定 mean 参数为 mean
        std = np.nanstd(A,
                        out=std_out,
                        axis=1,
                        keepdims=True,
                        mean=mean)

        # 断言 std_out 与 std 是同一个对象
        assert std_out is std

        # 断言 mean 和 std 的形状相同,应为 (10, 1, 5)
        assert std.shape == mean.shape
        assert std.shape == (10, 1, 5)

        # 使用 np.nanstd 计算 A 的标准差,沿 axis=1,保持维度为 True,作为旧方法的参考值
        std_old = np.nanstd(A, axis=1, keepdims=True)

        # 断言 std_old 与 mean 的形状相同
        assert std_old.shape == mean.shape
        # 使用 assert_almost_equal 断言 std 与 std_old 的接近程度
_TIME_UNITS = (
    "Y", "M", "W", "D", "h", "m", "s", "ms", "us", "ns", "ps", "fs", "as"
)
# 定义了时间单位的元组

# All `inexact` + `timdelta64` type codes
_TYPE_CODES = list(np.typecodes["AllFloat"])
_TYPE_CODES += [f"m8[{unit}]" for unit in _TIME_UNITS]
# 将所有浮点数类型代码添加到 _TYPE_CODES 列表中,同时添加了时间单位对应的 m8[unit] 类型代码

class TestNanFunctions_Median:

    def test_mutation(self):
        # 检查传递的数组未被修改
        ndat = _ndat.copy()
        np.nanmedian(ndat)
        assert_equal(ndat, _ndat)

    def test_keepdims(self):
        mat = np.eye(3)
        for axis in [None, 0, 1]:
            tgt = np.median(mat, axis=axis, out=None, overwrite_input=False)
            res = np.nanmedian(mat, axis=axis, out=None, overwrite_input=False)
            assert_(res.ndim == tgt.ndim)

        d = np.ones((3, 5, 7, 11))
        # 随机将一些元素设为 NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            res = np.nanmedian(d, axis=None, keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanmedian(d, axis=(0, 1), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 11))
            res = np.nanmedian(d, axis=(0, 3), keepdims=True)
            assert_equal(res.shape, (1, 5, 7, 1))
            res = np.nanmedian(d, axis=(1,), keepdims=True)
            assert_equal(res.shape, (3, 1, 7, 11))
            res = np.nanmedian(d, axis=(0, 1, 2, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanmedian(d, axis=(0, 1, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 1))

    @pytest.mark.parametrize(
        argnames='axis',
        argvalues=[
            None,
            1,
            (1, ),
            (0, 1),
            (-3, -1),
        ]
    )
    @pytest.mark.filterwarnings("ignore:All-NaN slice:RuntimeWarning")
    def test_keepdims_out(self, axis):
        d = np.ones((3, 5, 7, 11))
        # 随机将一些元素设为 NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan
        if axis is None:
            shape_out = (1,) * d.ndim
        else:
            axis_norm = normalize_axis_tuple(axis, d.ndim)
            shape_out = tuple(
                1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
        out = np.empty(shape_out)
        result = np.nanmedian(d, axis=axis, keepdims=True, out=out)
        assert result is out
        assert_equal(result.shape, shape_out)
    def test_out(self):
        # 创建一个 3x3 的随机数矩阵
        mat = np.random.rand(3, 3)
        # 在矩阵中插入 NaN 值,每行插入两个 NaN
        nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
        # 创建一个全零的长度为 3 的数组
        resout = np.zeros(3)
        # 计算原始矩阵每行的中位数
        tgt = np.median(mat, axis=1)
        # 计算插入 NaN 值后的矩阵每行的中位数,结果存入 resout 中
        res = np.nanmedian(nan_mat, axis=1, out=resout)
        # 检查计算结果与预期是否几乎相等
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        
        # 对于零维输出:
        resout = np.zeros(())
        # 计算原始矩阵所有元素的中位数
        tgt = np.median(mat, axis=None)
        # 计算插入 NaN 值后的矩阵所有元素的中位数,结果存入 resout 中
        res = np.nanmedian(nan_mat, axis=None, out=resout)
        # 检查计算结果与预期是否几乎相等
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)
        
        # 计算插入 NaN 值后的矩阵在指定轴(0 和 1)上的中位数,结果存入 resout 中
        res = np.nanmedian(nan_mat, axis=(0, 1), out=resout)
        # 检查计算结果与预期是否几乎相等
        assert_almost_equal(res, resout)
        assert_almost_equal(res, tgt)

    def test_small_large(self):
        # 测试小型和大型代码路径,当前截断为 400 个元素
        for s in [5, 20, 51, 200, 1000]:
            # 创建一个大小为 4x(s+1) 的随机数矩阵
            d = np.random.randn(4, s)
            # 随机将部分元素设为 NaN
            w = np.random.randint(0, d.size, size=d.size // 5)
            d.ravel()[w] = np.nan
            d[:, 0] = 1.  # 确保至少有一个有效值
            # 使用没有 NaN 的普通中位数进行比较
            tgt = []
            for x in d:
                nonan = np.compress(~np.isnan(x), x)
                tgt.append(np.median(nonan, overwrite_input=True))
            
            # 检查 np.nanmedian 函数计算结果与预期是否相等
            assert_array_equal(np.nanmedian(d, axis=-1), tgt)

    def test_result_values(self):
        # 计算 _ndat 沿第二个轴的每行的中位数
        tgt = [np.median(d) for d in _rdat]
        res = np.nanmedian(_ndat, axis=1)
        # 检查计算结果与预期是否几乎相等
        assert_almost_equal(res, tgt)

    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", _TYPE_CODES)
    def test_allnans(self, dtype, axis):
        # 创建一个全为 NaN 的 3x3 数组,并转换为指定的 dtype
        mat = np.full((3, 3), np.nan).astype(dtype)
        with suppress_warnings() as sup:
            sup.record(RuntimeWarning)

            # 计算 mat 在指定轴上的中位数
            output = np.nanmedian(mat, axis=axis)
            # 检查输出的数据类型与 mat 的数据类型是否相等
            assert output.dtype == mat.dtype
            # 检查输出是否全为 NaN
            assert np.isnan(output).all()

            if axis is None:
                # 如果 axis 为 None,检查警告记录数是否为 1
                assert_(len(sup.log) == 1)
            else:
                # 如果 axis 不为 None,检查警告记录数是否为 3
                assert_(len(sup.log) == 3)

            # 检查标量情况下的中位数计算
            scalar = np.array(np.nan).astype(dtype)[()]
            output_scalar = np.nanmedian(scalar)
            # 检查输出的数据类型与标量的数据类型是否相等
            assert output_scalar.dtype == scalar.dtype
            # 检查输出是否为 NaN
            assert np.isnan(output_scalar)

            if axis is None:
                # 如果 axis 为 None,检查警告记录数是否为 2
                assert_(len(sup.log) == 2)
            else:
                # 如果 axis 不为 None,检查警告记录数是否为 4
                assert_(len(sup.log) == 4)
    # 定义测试函数,测试处理空数组时的行为
    def test_empty(self):
        # 创建一个空的 0x3 的 NumPy 数组
        mat = np.zeros((0, 3))
        # 针对不同的轴进行循环测试
        for axis in [0, None]:
            # 捕获警告信息
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                # 使用 np.nanmedian 计算空数组的中位数,并检查是否全为 NaN
                assert_(np.isnan(np.nanmedian(mat, axis=axis)).all())
                # 断言捕获到一条警告
                assert_(len(w) == 1)
                # 断言该警告是 RuntimeWarning 的子类
                assert_(issubclass(w[0].category, RuntimeWarning))
        # 针对 axis=1 的情况进行测试
        for axis in [1]:
            # 捕获警告信息
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                # 使用 np.nanmedian 计算空数组的中位数,并与空数组 [] 进行比较
                assert_equal(np.nanmedian(mat, axis=axis), np.zeros([]))
                # 断言没有捕获到任何警告
                assert_(len(w) == 0)

    # 定义测试函数,测试处理标量时的行为
    def test_scalar(self):
        # 断言 np.nanmedian(0.) 的结果等于 0.
        assert_(np.nanmedian(0.) == 0.)

    # 定义测试函数,测试处理超出范围的轴参数时的行为
    def test_extended_axis_invalid(self):
        # 创建一个形状为 (3, 5, 7, 11) 的全为 1 的 NumPy 数组
        d = np.ones((3, 5, 7, 11))
        # 断言处理超出负轴索引的 AxisError 异常
        assert_raises(AxisError, np.nanmedian, d, axis=-5)
        # 断言处理包含负轴索引的 AxisError 异常
        assert_raises(AxisError, np.nanmedian, d, axis=(0, -5))
        # 断言处理超出正轴索引的 AxisError 异常
        assert_raises(AxisError, np.nanmedian, d, axis=4)
        # 断言处理包含超出正轴索引的 AxisError 异常
        assert_raises(AxisError, np.nanmedian, d, axis=(0, 4))
        # 断言处理重复轴的 ValueError 异常
        assert_raises(ValueError, np.nanmedian, d, axis=(1, 1))
    # 定义测试函数,用于测试处理特殊浮点数情况的函数
    def test_float_special(self):
        # 使用 suppress_warnings 上下文管理器,过滤掉 RuntimeWarning 警告
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            
            # 对于正无穷和负无穷两种情况进行迭代测试
            for inf in [np.inf, -np.inf]:
                # 创建包含特殊值的二维数组 a
                a = np.array([[inf,  np.nan], [np.nan, np.nan]])
                # 检查按列计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=0), [inf,  np.nan])
                # 检查按行计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=1), [inf,  np.nan])
                # 检查忽略 NaN 后的整体中位数是否符合预期
                assert_equal(np.nanmedian(a), inf)
                
                # 最小填充值检查
                a = np.array([[np.nan, np.nan, inf],
                             [np.nan, np.nan, inf]])
                # 检查忽略 NaN 后的整体中位数是否符合预期
                assert_equal(np.nanmedian(a), inf)
                # 检查按列计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=0), [np.nan, np.nan, inf])
                # 检查按行计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=1), inf)
                
                # 无遮罩路径
                a = np.array([[inf, inf], [inf, inf]])
                # 检查按行计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=1), inf)
                
                # 创建包含特殊浮点数的二维数组 a
                a = np.array([[inf, 7, -inf, -9],
                              [-10, np.nan, np.nan, 5],
                              [4, np.nan, np.nan, inf]],
                              dtype=np.float32)
                # 根据正无穷的值进行条件判断,检查按列计算忽略 NaN 后的中位数是否符合预期
                if inf > 0:
                    assert_equal(np.nanmedian(a, axis=0), [4., 7., -inf, 5.])
                    assert_equal(np.nanmedian(a), 4.5)
                else:
                    assert_equal(np.nanmedian(a, axis=0), [-10., 7., -inf, -9.])
                    assert_equal(np.nanmedian(a), -2.5)
                # 检查按行计算忽略 NaN 后的中位数是否符合预期
                assert_equal(np.nanmedian(a, axis=-1), [-1., -2.5, inf])
                
                # 针对不同长度的 i 和 j 进行迭代测试
                for i in range(0, 10):
                    for j in range(1, 10):
                        # 创建特殊值数组 a
                        a = np.array([([np.nan] * i) + ([inf] * j)] * 2)
                        # 检查忽略 NaN 后的整体中位数是否符合预期
                        assert_equal(np.nanmedian(a), inf)
                        # 检查按行计算忽略 NaN 后的中位数是否符合预期
                        assert_equal(np.nanmedian(a, axis=1), inf)
                        # 检查按列计算忽略 NaN 后的中位数是否符合预期
                        assert_equal(np.nanmedian(a, axis=0),
                                     ([np.nan] * i) + [inf] * j)
                        
                        # 创建特殊值数组 a
                        a = np.array([([np.nan] * i) + ([-inf] * j)] * 2)
                        # 检查忽略 NaN 后的整体中位数是否符合预期
                        assert_equal(np.nanmedian(a), -inf)
                        # 检查按行计算忽略 NaN 后的中位数是否符合预期
                        assert_equal(np.nanmedian(a, axis=1), -inf)
                        # 检查按列计算忽略 NaN 后的中位数是否符合预期
                        assert_equal(np.nanmedian(a, axis=0),
                                     ([np.nan] * i) + [-inf] * j)
class TestNanFunctions_Percentile:

    def test_mutation(self):
        # 检查传入的数组是否被修改
        ndat = _ndat.copy()  # 复制_ndat数组的副本,确保不改变原始数据
        np.nanpercentile(ndat, 30)  # 计算ndat数组的30th百分位数,忽略NaN值
        assert_equal(ndat, _ndat)  # 断言复制后的数组与原始数组相等,验证原始数据未被修改

    def test_keepdims(self):
        mat = np.eye(3)  # 创建一个3x3的单位矩阵
        for axis in [None, 0, 1]:
            tgt = np.percentile(mat, 70, axis=axis, out=None,
                                overwrite_input=False)
            res = np.nanpercentile(mat, 70, axis=axis, out=None,
                                   overwrite_input=False)
            assert_(res.ndim == tgt.ndim)  # 断言计算结果的维度与目标维度相等

        d = np.ones((3, 5, 7, 11))  # 创建一个全为1的4维数组
        # 随机将一些元素设为NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan  # 在d数组中随机设置一些元素为NaN
        with suppress_warnings() as sup:
            sup.filter(RuntimeWarning)
            # 测试不同轴上的百分位数计算,保持维度为True
            res = np.nanpercentile(d, 90, axis=None, keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanpercentile(d, 90, axis=(0, 1), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 11))
            res = np.nanpercentile(d, 90, axis=(0, 3), keepdims=True)
            assert_equal(res.shape, (1, 5, 7, 1))
            res = np.nanpercentile(d, 90, axis=(1,), keepdims=True)
            assert_equal(res.shape, (3, 1, 7, 11))
            res = np.nanpercentile(d, 90, axis=(0, 1, 2, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 1, 1))
            res = np.nanpercentile(d, 90, axis=(0, 1, 3), keepdims=True)
            assert_equal(res.shape, (1, 1, 7, 1))

    @pytest.mark.parametrize('q', [7, [1, 7]])
    @pytest.mark.parametrize(
        argnames='axis',
        argvalues=[
            None,
            1,
            (1,),
            (0, 1),
            (-3, -1),
        ]
    )
    @pytest.mark.filterwarnings("ignore:All-NaN slice:RuntimeWarning")
    def test_keepdims_out(self, q, axis):
        d = np.ones((3, 5, 7, 11))  # 创建一个全为1的4维数组
        # 随机将一些元素设为NaN:
        w = np.random.random((4, 200)) * np.array(d.shape)[:, None]
        w = w.astype(np.intp)
        d[tuple(w)] = np.nan  # 在d数组中随机设置一些元素为NaN
        if axis is None:
            shape_out = (1,) * d.ndim  # 如果axis为None,输出形状为全1
        else:
            axis_norm = normalize_axis_tuple(axis, d.ndim)
            # 根据指定的轴计算输出形状
            shape_out = tuple(
                1 if i in axis_norm else d.shape[i] for i in range(d.ndim))
        shape_out = np.shape(q) + shape_out  # 在q的形状前加上计算得到的输出形状

        out = np.empty(shape_out)  # 创建一个空数组作为输出
        result = np.nanpercentile(d, q, axis=axis, keepdims=True, out=out)
        assert result is out  # 断言返回的结果与指定的输出数组相同
        assert_equal(result.shape, shape_out)  # 断言返回的结果的形状与预期的形状相同

    @pytest.mark.parametrize("weighted", [False, True])
    # 定义一个测试方法,用于测试特定条件下的函数行为
    def test_out(self, weighted):
        # 创建一个 3x3 的随机数矩阵
        mat = np.random.rand(3, 3)
        # 在矩阵中插入 NaN 值,构成一个包含 NaN 的矩阵
        nan_mat = np.insert(mat, [0, 2], np.nan, axis=1)
        # 创建一个长度为 3 的零向量,用于存储结果
        resout = np.zeros(3)
        # 根据权重条件选择参数
        if weighted:
            # 如果使用权重,定义带有权重和方法的参数字典
            w_args = {"weights": np.ones_like(mat), "method": "inverted_cdf"}
            nan_w_args = {
                "weights": np.ones_like(nan_mat), "method": "inverted_cdf"
            }
        else:
            # 否则,参数字典为空
            w_args = dict()
            nan_w_args = dict()
        # 计算 mat 矩阵的百分位数,返回目标数组
        tgt = np.percentile(mat, 42, axis=1, **w_args)
        # 计算 nan_mat 矩阵的带有 NaN 的百分位数,将结果存储到 resout 中
        res = np.nanpercentile(nan_mat, 42, axis=1, out=resout, **nan_w_args)
        # 断言结果近似相等
        assert_almost_equal(res, resout)
        # 断言结果近似相等
        assert_almost_equal(res, tgt)
        # 处理 0 维输出的情况:
        resout = np.zeros(())
        # 计算 mat 矩阵的全局百分位数,返回目标值
        tgt = np.percentile(mat, 42, axis=None, **w_args)
        # 计算 nan_mat 矩阵的带有 NaN 的全局百分位数,将结果存储到 resout 中
        res = np.nanpercentile(
            nan_mat, 42, axis=None, out=resout, **nan_w_args
        )
        # 断言结果近似相等
        assert_almost_equal(res, resout)
        # 断言结果近似相等
        assert_almost_equal(res, tgt)
        # 计算 nan_mat 矩阵在多轴 (0, 1) 上的带有 NaN 的百分位数,将结果存储到 resout 中
        res = np.nanpercentile(
            nan_mat, 42, axis=(0, 1), out=resout, **nan_w_args
        )
        # 断言结果近似相等
        assert_almost_equal(res, resout)
        # 断言结果近似相等
        assert_almost_equal(res, tgt)

    # 定义一个测试复杂情况的方法
    def test_complex(self):
        # 创建一个复数数组,测试在复数数组上调用 nanpercentile 会引发 TypeError
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='G')
        assert_raises(TypeError, np.nanpercentile, arr_c, 0.5)
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='D')
        assert_raises(TypeError, np.nanpercentile, arr_c, 0.5)
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='F')
        assert_raises(TypeError, np.nanpercentile, arr_c, 0.5)

    # 使用 pytest 的参数化装饰器,定义测试不同参数组合下的函数行为
    @pytest.mark.parametrize("weighted", [False, True])
    @pytest.mark.parametrize("use_out", [False, True])
    def test_result_values(self, weighted, use_out):
        # 根据 weighted 参数选择相应的百分位数函数和生成权重的函数
        if weighted:
            percentile = partial(np.percentile, method="inverted_cdf")
            nanpercentile = partial(np.nanpercentile, method="inverted_cdf")

            def gen_weights(d):
                return np.ones_like(d)

        else:
            percentile = np.percentile
            nanpercentile = np.nanpercentile

            def gen_weights(d):
                return None

        # 对给定数据集 _rdat 计算目标百分位数,并存储到 tgt 中
        tgt = [percentile(d, 28, weights=gen_weights(d)) for d in _rdat]
        # 根据 use_out 参数决定是否使用 out 参数
        out = np.empty_like(tgt) if use_out else None
        # 计算 _ndat 数据集的带有 NaN 的百分位数,存储到 res 中
        res = nanpercentile(_ndat, 28, axis=1,
                            weights=gen_weights(_ndat), out=out)
        # 断言结果近似相等
        assert_almost_equal(res, tgt)
        # 将结果数组转置以符合 numpy.percentile 的输出约定
        tgt = np.transpose([percentile(d, (28, 98), weights=gen_weights(d))
                            for d in _rdat])
        # 根据 use_out 参数决定是否使用 out 参数
        out = np.empty_like(tgt) if use_out else None
        # 计算 _ndat 数据集的带有 NaN 的多轴 (1, 2) 百分位数,存储到 res 中
        res = nanpercentile(_ndat, (28, 98), axis=1,
                            weights=gen_weights(_ndat), out=out)
        # 断言结果近似相等
        assert_almost_equal(res, tgt)

    # 使用 pytest 的参数化装饰器,测试不同的轴和浮点数类型
    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["Float"])
    # 使用 pytest 的参数化装饰器,为测试方法 test_allnans 提供不同的输入数组
    @pytest.mark.parametrize("array", [
        # 创建一个包含单个 NaN 值的 NumPy 数组
        np.array(np.nan),
        # 创建一个 3x3 的 NumPy 数组,每个元素都是 NaN
        np.full((3, 3), np.nan),
    ], ids=["0d", "2d"])
    # 测试所有元素为 NaN 的情况
    def test_allnans(self, axis, dtype, array):
        # 如果指定了 axis 且数组维度为 0,则跳过测试,并给出相应提示
        if axis is not None and array.ndim == 0:
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        # 将数组转换为指定的数据类型
        array = array.astype(dtype)
        # 在执行计算时,捕获所有的 RuntimeWarning,其中包含 All-NaN slice encountered 的警告
        with pytest.warns(RuntimeWarning, match="All-NaN slice encountered"):
            # 计算数组的第 60 百分位数,可以指定计算的轴
            out = np.nanpercentile(array, 60, axis=axis)
        # 断言计算结果中所有元素都是 NaN
        assert np.isnan(out).all()
        # 断言计算结果的数据类型与原始数组的数据类型相同
        assert out.dtype == array.dtype

    # 测试空数组的情况
    def test_empty(self):
        # 创建一个空的 0x3 的 NumPy 数组
        mat = np.zeros((0, 3))
        # 分别测试 axis 为 0 和 None 的情况
        for axis in [0, None]:
            # 在测试期间捕获所有警告
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                # 断言对空数组计算百分位数时所有元素都是 NaN
                assert_(np.isnan(np.nanpercentile(mat, 40, axis=axis)).all())
                # 断言捕获到一条警告
                assert_(len(w) == 1)
                # 断言该警告是 RuntimeWarning 类型的
                assert_(issubclass(w[0].category, RuntimeWarning))
        # 对于 axis 为 1 的情况,不应有警告产生
        for axis in [1]:
            with warnings.catch_warnings(record=True) as w:
                warnings.simplefilter('always')
                # 断言对空数组在 axis=1 上计算百分位数的结果是一个空的数组
                assert_equal(np.nanpercentile(mat, 40, axis=axis), np.zeros([]))
                # 断言没有捕获到任何警告
                assert_(len(w) == 0)

    # 测试标量输入的情况
    def test_scalar(self):
        # 断言对标量 0 计算百分位数得到的结果是 0
        assert_equal(np.nanpercentile(0., 100), 0.)
        # 创建一个 0 到 5 的数组
        a = np.arange(6)
        # 计算数组的第 50 百分位数
        r = np.nanpercentile(a, 50, axis=0)
        # 断言计算结果为 2.5
        assert_equal(r, 2.5)
        # 断言计算结果是一个标量
        assert_(np.isscalar(r))

    # 测试扩展轴参数无效的情况
    def test_extended_axis_invalid(self):
        # 创建一个形状为 (3, 5, 7, 11) 的全为 1 的数组
        d = np.ones((3, 5, 7, 11))
        # 断言对于超出范围的轴索引会抛出 AxisError
        assert_raises(AxisError, np.nanpercentile, d, q=5, axis=-5)
        # 断言对于同时指定有效和无效轴索引的情况会抛出 AxisError
        assert_raises(AxisError, np.nanpercentile, d, q=5, axis=(0, -5))
        # 断言对于超出范围的轴索引会抛出 AxisError
        assert_raises(AxisError, np.nanpercentile, d, q=5, axis=4)
        # 断言对于同时指定有效和超出范围的轴索引的情况会抛出 AxisError
        assert_raises(AxisError, np.nanpercentile, d, q=5, axis=(0, 4))
        # 断言当指定轴索引为重复时会抛出 ValueError
        assert_raises(ValueError, np.nanpercentile, d, q=5, axis=(1, 1))
    # 定义一个测试方法,用于测试多个百分位数的计算
    def test_multiple_percentiles(self):
        # 设定百分位数的列表
        perc = [50, 100]
        # 创建一个4x3的全1矩阵
        mat = np.ones((4, 3))
        # 创建一个与mat相同大小的NaN矩阵
        nan_mat = np.nan * mat
        # 在更高维度情况下检查一致性
        large_mat = np.ones((3, 4, 5))
        # 将large_mat的第1维和第3维的特定切片置为0
        large_mat[:, 0:2:4, :] = 0
        # 将large_mat的第3维后的所有元素乘以2
        large_mat[:, :, 3:] *= 2

        # 遍历不同的轴和保持维度的选项
        for axis in [None, 0, 1]:
            for keepdim in [False, True]:
                # 使用suppress_warnings上下文管理器以过滤特定的运行时警告
                with suppress_warnings() as sup:
                    # 过滤掉特定的运行时警告信息
                    sup.filter(RuntimeWarning, "All-NaN slice encountered")
                    # 计算mat的百分位数,返回值val
                    val = np.percentile(mat, perc, axis=axis, keepdims=keepdim)
                    # 计算nan_mat的百分位数,返回值nan_val
                    nan_val = np.nanpercentile(nan_mat, perc, axis=axis,
                                               keepdims=keepdim)
                    # 断言nan_val的形状与val的形状相同
                    assert_equal(nan_val.shape, val.shape)

                    # 计算large_mat的百分位数,返回值val
                    val = np.percentile(large_mat, perc, axis=axis,
                                        keepdims=keepdim)
                    # 计算large_mat中NaN值排除后的百分位数,返回值nan_val
                    nan_val = np.nanpercentile(large_mat, perc, axis=axis,
                                               keepdims=keepdim)
                    # 断言nan_val等于val
                    assert_equal(nan_val, val)

        # 创建一个更大的矩阵megamat,形状为3x4x5x6
        megamat = np.ones((3, 4, 5, 6))
        # 断言计算megamat在指定轴(1, 2)上的NaN值排除后的百分位数的形状
        assert_equal(
            np.nanpercentile(megamat, perc, axis=(1, 2)).shape, (2, 3, 6)
        )

    # 使用pytest的参数化标记定义一个测试方法,用于测试带有权重的NaN值处理
    @pytest.mark.parametrize("nan_weight", [0, 1, 2, 3, 1e200])
    def test_nan_value_with_weight(self, nan_weight):
        # 创建一个包含NaN的列表x
        x = [1, np.nan, 2, 3]
        # 预期的非NaN位置上的结果
        result = np.float64(2.0)
        # 计算未加权情况下的百分位数,返回值q_unweighted
        q_unweighted = np.nanpercentile(x, 50, method="inverted_cdf")
        # 断言q_unweighted等于预期结果result
        assert_equal(q_unweighted, result)

        # 创建一个权重列表w,在NaN位置处的权重值为nan_weight
        w = [1.0, nan_weight, 1.0, 1.0]
        # 计算带权重情况下的百分位数,返回值q_weighted
        q_weighted = np.nanpercentile(x, 50, weights=w, method="inverted_cdf")
        # 断言q_weighted等于预期结果result
        assert_equal(q_weighted, result)
    # 定义一个测试方法,用于测试带有权重和多维数组的 NaN 值处理
    def test_nan_value_with_weight_ndim(self, axis):
        # 创建一个多维数组进行测试
        np.random.seed(1)
        x_no_nan = np.random.random(size=(100, 99, 2))
        
        # 将部分位置设置为 NaN(不是特别聪明的做法),以确保始终存在非 NaN 值
        x = x_no_nan.copy()
        x[np.arange(99), np.arange(99), 0] = np.nan

        # 设置权重为全 1 数组,但在下面的 NaN 位置用 0 或 1e200 替换
        weights = np.ones_like(x)

        # 对比使用带有 NaN 权重的加权正常百分位,其中 NaN 位置的权重为 0(没有 NaN)
        weights[np.isnan(x)] = 0
        p_expected = np.percentile(
            x_no_nan, p, axis=axis, weights=weights, method="inverted_cdf")

        # 使用 np.nanpercentile 计算未加权的百分位
        p_unweighted = np.nanpercentile(
            x, p, axis=axis, method="inverted_cdf")
        
        # 正常版本和未加权版本应该是相同的:
        assert_equal(p_unweighted, p_expected)

        # 将 NaN 位置的权重设置为 1e200(一个很大的值,不应影响结果)
        weights[np.isnan(x)] = 1e200
        p_weighted = np.nanpercentile(
            x, p, axis=axis, weights=weights, method="inverted_cdf")
        
        # 断言加权版本的结果与预期结果相等
        assert_equal(p_weighted, p_expected)

        # 还可以传递输出数组进行检查:
        out = np.empty_like(p_weighted)
        res = np.nanpercentile(
            x, p, axis=axis, weights=weights, out=out, method="inverted_cdf")
        
        # 断言结果数组是传递的输出数组,并且其内容与预期结果相等
        assert res is out
        assert_equal(out, p_expected)
class TestNanFunctions_Quantile:
    # most of this is already tested by TestPercentile

    @pytest.mark.parametrize("weighted", [False, True])
    def test_regression(self, weighted):
        # 创建一个3维的浮点数数组,形状为(2, 3, 4),数值为0到23
        ar = np.arange(24).reshape(2, 3, 4).astype(float)
        # 将第一个子数组的第二个子数组全部设为NaN
        ar[0][1] = np.nan
        # 根据weighted参数设置权重参数w_args
        if weighted:
            w_args = {"weights": np.ones_like(ar), "method": "inverted_cdf"}
        else:
            w_args = dict()

        # 断言np.nanquantile和np.nanpercentile的结果相等
        assert_equal(np.nanquantile(ar, q=0.5, **w_args),
                     np.nanpercentile(ar, q=50, **w_args))
        assert_equal(np.nanquantile(ar, q=0.5, axis=0, **w_args),
                     np.nanpercentile(ar, q=50, axis=0, **w_args))
        assert_equal(np.nanquantile(ar, q=0.5, axis=1, **w_args),
                     np.nanpercentile(ar, q=50, axis=1, **w_args))
        assert_equal(np.nanquantile(ar, q=[0.5], axis=1, **w_args),
                     np.nanpercentile(ar, q=[50], axis=1, **w_args))
        assert_equal(np.nanquantile(ar, q=[0.25, 0.5, 0.75], axis=1, **w_args),
                     np.nanpercentile(ar, q=[25, 50, 75], axis=1, **w_args))

    def test_basic(self):
        # 创建一个包含8个元素的浮点数数组
        x = np.arange(8) * 0.5
        # 断言np.nanquantile的结果与预期相等
        assert_equal(np.nanquantile(x, 0), 0.)
        assert_equal(np.nanquantile(x, 1), 3.5)
        assert_equal(np.nanquantile(x, 0.5), 1.75)

    def test_complex(self):
        # 创建一个复数数组,包含三个复数元素
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='G')
        # 断言对于复数数组,调用np.nanquantile会引发TypeError异常
        assert_raises(TypeError, np.nanquantile, arr_c, 0.5)
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='D')
        assert_raises(TypeError, np.nanquantile, arr_c, 0.5)
        arr_c = np.array([0.5+3.0j, 2.1+0.5j, 1.6+2.3j], dtype='F')
        assert_raises(TypeError, np.nanquantile, arr_c, 0.5)

    def test_no_p_overwrite(self):
        # 这个测试值得重新测试,因为quantile函数不会创建副本
        p0 = np.array([0, 0.75, 0.25, 0.5, 1.0])
        p = p0.copy()
        # 调用np.nanquantile,验证参数p是否被修改
        np.nanquantile(np.arange(100.), p, method="midpoint")
        assert_array_equal(p, p0)

        p0 = p0.tolist()
        p = p.tolist()
        # 调用np.nanquantile,验证参数p是否被修改
        np.nanquantile(np.arange(100.), p, method="midpoint")
        assert_array_equal(p, p0)

    @pytest.mark.parametrize("axis", [None, 0, 1])
    @pytest.mark.parametrize("dtype", np.typecodes["Float"])
    @pytest.mark.parametrize("array", [
        np.array(np.nan),
        np.full((3, 3), np.nan),
    ], ids=["0d", "2d"])
    def test_allnans(self, axis, dtype, array):
        if axis is not None and array.ndim == 0:
            # 对于0维数组,不支持axis参数不为None的情况,跳过测试
            pytest.skip(f"`axis != None` not supported for 0d arrays")

        # 将array转换为指定dtype
        array = array.astype(dtype)
        # 断言调用np.nanquantile后,返回的结果全为NaN,并且dtype与array一致
        with pytest.warns(RuntimeWarning, match="All-NaN slice encountered"):
            out = np.nanquantile(array, 1, axis=axis)
        assert np.isnan(out).all()
        assert out.dtype == array.dtype
    (np.array([1, 5, 7, 9], dtype=np.int64),
     True),
    # 创建一个包含整数的一维数组,数据类型为64位整数,不包含 NaN
    (np.array([False, True, False, True]),
     True),
    # 创建一个包含布尔值的一维数组,所有值都是布尔类型,不包含 NaN
    (np.array([[np.nan, 5.0],
               [np.nan, np.inf]], dtype=np.complex64),
     np.array([[False, True],
               [False, True]])),
    # 创建一个包含复数的二维数组,数据类型为64位复数,包含 NaN 和无穷大(inf)
    # 同时创建一个布尔类型的二维数组,标识对应位置是否包含 NaN 或无穷大
    ])
# 测试函数,验证 _nan_mask 函数的行为是否符合预期
def test__nan_mask(arr, expected):
    # 针对两种输出情况进行循环测试
    for out in [None, np.empty(arr.shape, dtype=np.bool)]:
        # 调用 _nan_mask 函数计算实际结果
        actual = _nan_mask(arr, out=out)
        # 断言实际结果与期望结果相等
        assert_equal(actual, expected)
        # 如果期望结果不是 np.ndarray 类型,则需进一步验证 actual 是否为 True
        # 用于无法包含 NaN 的数据类型,确保 actual 是 True 而非 True 数组
        if type(expected) is not np.ndarray:
            assert actual is True


# 测试函数,验证 _replace_nan 函数的不同数据类型情况下的行为
def test__replace_nan():
    """ Test that _replace_nan returns the original array if there are no
    NaNs, not a copy.
    """
    # 针对不同的数据类型进行测试
    for dtype in [np.bool, np.int32, np.int64]:
        # 创建指定类型的数组 arr
        arr = np.array([0, 1], dtype=dtype)
        # 调用 _replace_nan 函数,替换 NaN,并获取结果及 mask
        result, mask = _replace_nan(arr, 0)
        # 断言 mask 为 None,表明没有 NaN 存在时不进行复制操作
        assert mask is None
        # 断言 result 与 arr 是同一个对象,即不进行复制
        assert result is arr

    # 针对浮点类型进行测试
    for dtype in [np.float32, np.float64]:
        # 创建指定类型的数组 arr
        arr = np.array([0, 1], dtype=dtype)
        # 调用 _replace_nan 函数,替换 NaN,并获取结果及 mask
        result, mask = _replace_nan(arr, 2)
        # 断言 mask 全为 False,表明没有 NaN 存在时不进行复制操作
        assert (mask == False).all()
        # 断言 result 不是 arr,表明需要进行复制操作
        assert result is not arr
        # 断言 result 与 arr 的内容相等
        assert_equal(result, arr)

        # 创建包含 NaN 的数组 arr_nan
        arr_nan = np.array([0, 1, np.nan], dtype=dtype)
        # 调用 _replace_nan 函数,替换 NaN,并获取结果及 mask
        result_nan, mask_nan = _replace_nan(arr_nan, 2)
        # 断言 mask_nan 的值与预期一致
        assert_equal(mask_nan, np.array([False, False, True]))
        # 断言 result_nan 不是 arr_nan,表明需要进行复制操作
        assert result_nan is not arr_nan
        # 断言 result_nan 的内容与预期一致
        assert_equal(result_nan, np.array([0, 1, 2]))
        # 断言 arr_nan 最后一个元素仍然为 NaN
        assert np.isnan(arr_nan[-1])

.\numpy\numpy\lib\tests\test_packbits.py

# 导入必要的库
import numpy as np  # 导入NumPy库
from numpy.testing import assert_array_equal, assert_equal, assert_raises  # 导入NumPy测试相关的函数和类
import pytest  # 导入pytest库
from itertools import chain  # 导入itertools库中的chain函数

# 定义测试函数:测试np.packbits函数
def test_packbits():
    # 定义测试数据a,这里使用了多维列表表示布尔值数组
    a = [[[1, 0, 1], [0, 1, 0]],
         [[1, 1, 0], [0, 0, 1]]]
    
    # 对于不同的数据类型dt进行测试
    for dt in '?bBhHiIlLqQ':
        # 将a转换为NumPy数组,指定数据类型为dt
        arr = np.array(a, dtype=dt)
        # 使用np.packbits函数进行位压缩,压缩轴为最后一个轴
        b = np.packbits(arr, axis=-1)
        # 断言压缩结果的数据类型为np.uint8
        assert_equal(b.dtype, np.uint8)
        # 断言压缩后的结果数组与预期的数组相等
        assert_array_equal(b, np.array([[[160], [64]], [[192], [32]]]))

    # 测试处理异常情况:输入数据类型为float时抛出TypeError异常
    assert_raises(TypeError, np.packbits, np.array(a, dtype=float))


# 定义测试函数:测试处理空数组的np.packbits函数
def test_packbits_empty():
    # 定义不同的空数组形状
    shapes = [
        (0,), (10, 20, 0), (10, 0, 20), (0, 10, 20), (20, 0, 0), (0, 20, 0),
        (0, 0, 20), (0, 0, 0),
    ]
    # 对于不同的数据类型dt和形状shape进行测试
    for dt in '?bBhHiIlLqQ':
        for shape in shapes:
            # 创建指定形状的空数组a,指定数据类型为dt
            a = np.empty(shape, dtype=dt)
            # 使用np.packbits函数对空数组a进行位压缩
            b = np.packbits(a)
            # 断言压缩结果的数据类型为np.uint8
            assert_equal(b.dtype, np.uint8)
            # 断言压缩后的结果数组形状为(0,)
            assert_equal(b.shape, (0,))


# 定义测试函数:测试带有轴参数的np.packbits函数处理空数组
def test_packbits_empty_with_axis():
    # 定义原始形状和不同轴的压缩后形状列表
    shapes = [
        ((0,), [(0,)]),
        ((10, 20, 0), [(2, 20, 0), (10, 3, 0), (10, 20, 0)]),
        ((10, 0, 20), [(2, 0, 20), (10, 0, 20), (10, 0, 3)]),
        ((0, 10, 20), [(0, 10, 20), (0, 2, 20), (0, 10, 3)]),
        ((20, 0, 0), [(3, 0, 0), (20, 0, 0), (20, 0, 0)]),
        ((0, 20, 0), [(0, 20, 0), (0, 3, 0), (0, 20, 0)]),
        ((0, 0, 20), [(0, 0, 20), (0, 0, 20), (0, 0, 3)]),
        ((0, 0, 0), [(0, 0, 0), (0, 0, 0), (0, 0, 0)]),
    ]
    # 对于不同的数据类型dt、输入形状in_shape和输出形状out_shape进行测试
    for dt in '?bBhHiIlLqQ':
        for in_shape, out_shapes in shapes:
            for ax, out_shape in enumerate(out_shapes):
                # 创建指定形状的空数组a,指定数据类型为dt
                a = np.empty(in_shape, dtype=dt)
                # 使用np.packbits函数对空数组a进行位压缩,指定压缩轴为ax
                b = np.packbits(a, axis=ax)
                # 断言压缩结果的数据类型为np.uint8
                assert_equal(b.dtype, np.uint8)
                # 断言压缩后的结果数组形状与预期的out_shape相等
                assert_equal(b.shape, out_shape)


@pytest.mark.parametrize('bitorder', ('little', 'big'))
# 定义测试函数:测试大数据量情况下的np.packbits函数
def test_packbits_large(bitorder):
    # test data large enough for 16 byte vectorization
    # 创建一个包含二进制整数数组的 NumPy 数组
    a = np.array([1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,
                  0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1,
                  1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0,
                  1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1,
                  1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1,
                  1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1,
                  1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1,
                  0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1,
                  1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
                  1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1,
                  1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
                  0, 1, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1,
                  1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0,
                  1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0,
                  1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0])
    # 将数组元素重复三次
    a = a.repeat(3)

    # 使用不同的数据类型测试以下代码的结果是否相同
    for dtype in 'bBhHiIlLqQ':
        # 将数组 a 转换为指定的数据类型
        arr = np.array(a, dtype=dtype)
        # 生成与 arr 相同大小的随机整数数组,范围在指定数据类型的最小值和最大值之间
        rnd = np.random.randint(low=np.iinfo(dtype).min,
                                high=np.iinfo(dtype).max, size=arr.size,
                                dtype=dtype)
        # 将 rnd 数组中的零值替换为1
        rnd[rnd == 0] = 1
        # 将 arr 数组与 rnd 数组中对应元素相乘,结果保持 dtype 数据类型
        arr *= rnd.astype(dtype)
        # 将 arr 数组打包成二进制位数组
        b = np.packbits(arr, axis=-1)
        # 验证解包后的结果是否与原始数组 a(去除末尾四个元素)相等
        assert_array_equal(np.unpackbits(b)[:-4], a)

    # 测试将浮点数数组作为输入时是否会引发 TypeError 异常
    assert_raises(TypeError, np.packbits, np.array(a, dtype=float))
def test_packbits_very_large():
    # 对 np.packbits 函数进行大数组测试,用于解决 gh-8637 中的问题
    # 大数组可能更容易触发潜在的 bug
    for s in range(950, 1050):
        # 遍历不同的数据类型
        for dt in '?bBhHiIlLqQ':
            # 创建一个形状为 (200, s) 的全为 True 的数组
            x = np.ones((200, s), dtype=bool)
            # 对数组 x 进行按位压缩,沿着 axis=1 的方向
            np.packbits(x, axis=1)


def test_unpackbits():
    # 从文档字符串中复制的示例
    # 创建一个包含 [[2], [7], [23]] 的 numpy 数组,数据类型为 uint8
    a = np.array([[2], [7], [23]], dtype=np.uint8)
    # 对数组 a 进行按位解压缩,沿着 axis=1 的方向
    b = np.unpackbits(a, axis=1)
    # 断言解压缩后的数组 b 的数据类型为 uint8
    assert_equal(b.dtype, np.uint8)
    # 断言解压缩后的数组 b 与给定的数组相等
    assert_array_equal(b, np.array([[0, 0, 0, 0, 0, 0, 1, 0],
                                    [0, 0, 0, 0, 0, 1, 1, 1],
                                    [0, 0, 0, 1, 0, 1, 1, 1]]))


def test_pack_unpack_order():
    # 创建一个包含 [[2], [7], [23]] 的 numpy 数组,数据类型为 uint8
    a = np.array([[2], [7], [23]], dtype=np.uint8)
    # 对数组 a 进行按位解压缩,沿着 axis=1 的方向
    b = np.unpackbits(a, axis=1)
    # 断言解压缩后的数组 b 的数据类型为 uint8
    assert_equal(b.dtype, np.uint8)
    # 使用 'big' 字节顺序对数组 a 进行解压缩,得到数组 b_big
    b_big = np.unpackbits(a, axis=1, bitorder='big')
    # 使用 'little' 字节顺序对数组 a 进行解压缩,得到数组 b_little
    b_little = np.unpackbits(a, axis=1, bitorder='little')
    # 断言数组 b 与数组 b_big 相等
    assert_array_equal(b, b_big)
    # 断言数组 a 与以 'little' 字节顺序对数组 b_little 进行按位压缩后的结果相等
    assert_array_equal(a, np.packbits(b_little, axis=1, bitorder='little'))
    # 断言数组 b 的逆序列与数组 b_little 相等
    assert_array_equal(b[:,::-1], b_little)
    # 断言数组 a 与以 'big' 字节顺序对数组 b_big 进行按位压缩后的结果相等
    assert_array_equal(a, np.packbits(b_big, axis=1, bitorder='big'))
    # 断言当 'bitorder' 参数为 'r' 时,解压缩函数会引发 ValueError 异常
    assert_raises(ValueError, np.unpackbits, a, bitorder='r')
    # 断言当 'bitorder' 参数为整数 10 时,解压缩函数会引发 TypeError 异常
    assert_raises(TypeError, np.unpackbits, a, bitorder=10)


def test_unpackbits_empty():
    # 创建一个空的 numpy 数组,数据类型为 uint8
    a = np.empty((0,), dtype=np.uint8)
    # 对数组 a 进行按位解压缩
    b = np.unpackbits(a)
    # 断言解压缩后的数组 b 的数据类型为 uint8
    assert_equal(b.dtype, np.uint8)
    # 断言解压缩后的数组 b 为空数组
    assert_array_equal(b, np.empty((0,)))


def test_unpackbits_empty_with_axis():
    # 不同轴上的打包形状列表和解包形状列表
    shapes = [
        ([(0,)], (0,)),
        ([(2, 24, 0), (16, 3, 0), (16, 24, 0)], (16, 24, 0)),
        ([(2, 0, 24), (16, 0, 24), (16, 0, 3)], (16, 0, 24)),
        ([(0, 16, 24), (0, 2, 24), (0, 16, 3)], (0, 16, 24)),
        ([(3, 0, 0), (24, 0, 0), (24, 0, 0)], (24, 0, 0)),
        ([(0, 24, 0), (0, 3, 0), (0, 24, 0)], (0, 24, 0)),
        ([(0, 0, 24), (0, 0, 24), (0, 0, 3)], (0, 0, 24)),
        ([(0, 0, 0), (0, 0, 0), (0, 0, 0)], (0, 0, 0)),
    ]
    # 遍历不同的形状对,并测试对应的解包操作
    for in_shapes, out_shape in shapes:
        for ax, in_shape in enumerate(in_shapes):
            # 创建一个空的 numpy 数组,形状为 in_shape,数据类型为 uint8
            a = np.empty(in_shape, dtype=np.uint8)
            # 对数组 a 进行按位解压缩,沿着指定的轴 ax
            b = np.unpackbits(a, axis=ax)
            # 断言解压缩后的数组 b 的数据类型为 uint8
            assert_equal(b.dtype, np.uint8)
            # 断言解压缩后的数组 b 的形状与预期的 out_shape 相等
            assert_equal(b.shape, out_shape)


def test_unpackbits_large():
    # 对所有可能的数字进行测试,通过与已经测试过的 packbits 进行比较
    d = np.arange(277, dtype=np.uint8)
    # 断言解压缩后再压缩的结果与原数组 d 相等
    assert_array_equal(np.packbits(np.unpackbits(d)), d)
    # 断言解压缩后再压缩的结果与 d[::2] 相等
    assert_array_equal(np.packbits(np.unpackbits(d[::2])), d[::2])
    # 将数组 d 在行方向上重复三次
    d = np.tile(d, (3, 1))
    # 断言解压缩后再压缩的结果与原始数组 d 相等,沿着 axis=1 的方向
    assert_array_equal(np.packbits(np.unpackbits(d, axis=1), axis=1), d)
    # 将数组 d 进行转置,并创建其副本
    d = d.T.copy()
    # 断言解压缩后再压缩的结果与原始数组 d 相等,沿着 axis=0 的方向
    assert_array_equal(np.packbits(np.unpackbits(d, axis=0), axis=0), d)


class TestCount():
    # 创建一个 7x7 的二维数组,元素为 0 或 1,数据类型为 uint8
    x = np.array([
        [1, 0, 1, 0, 0, 1, 0],
        [0, 1, 1, 1, 0, 0, 0],
        [0, 0, 1, 0, 0, 1, 1],
        [1, 1, 0, 0, 0, 1, 1],
        [1, 0, 1, 0, 1, 0, 1],
        [0, 0, 1, 1, 1, 0, 0],
        [0, 1, 0, 1, 0, 1, 0],
    ], dtype=np.uint8)
    # 创建一个长度为 57 的一维数组,元素为 0,数据类型为 uint8
    padded1 = np.zeros(57, dtype=np.uint8)
    # 将二维数组 x 按行展开并填充到 padded1 的前 49 个位置
    padded1[:49] = x.ravel()
    # 创建一个长度为 57 的一维数组,元素为 0,数据类型为 uint8
    padded1b = np.zeros(57, dtype=np.uint8)
    # 将二维数组 x 沿水平翻转后再按行展开并填充到 padded1b 的前 49 个位置
    padded1b[:49] = x[::-1].copy().ravel()
    # 创建一个 9x9 的二维数组,元素为 0,数据类型为 uint8
    padded2 = np.zeros((9, 9), dtype=np.uint8)
    # 将二维数组 x 按行、列展开并填充到 padded2 的前 7x7 个位置
    padded2[:7, :7] = x

    # 使用 pytest 的参数化装饰器,指定测试函数的参数化条件
    @pytest.mark.parametrize('bitorder', ('little', 'big'))
    @pytest.mark.parametrize('count', chain(range(58), range(-1, -57, -1)))
    # 定义测试函数 test_roundtrip,参数包括 bitorder 和 count
    def test_roundtrip(self, bitorder, count):
        # 如果 count 小于 0,则设定 cutoff 为 count - 1;否则设定为 count
        if count < 0:
            # 添加额外的零填充
            cutoff = count - 1
        else:
            cutoff = count
        # 对二维数组 x 进行位压缩(packbits),根据指定的 bitorder
        packed = np.packbits(self.x, bitorder=bitorder)
        # 对压缩后的数据进行位解压缩(unpackbits),根据指定的 count 和 bitorder
        unpacked = np.unpackbits(packed, count=count, bitorder=bitorder)
        # 断言解压缩后数据类型为 uint8
        assert_equal(unpacked.dtype, np.uint8)
        # 断言解压缩后的数组与 padded1 的前 cutoff 个元素相等
        assert_array_equal(unpacked, self.padded1[:cutoff])

    # 使用 pytest 的参数化装饰器,指定测试函数的参数化条件
    @pytest.mark.parametrize('kwargs', [
                    {}, {'count': None},
                    ])
    # 定义测试函数 test_count,参数为 kwargs
    def test_count(self, kwargs):
        # 对二维数组 x 进行位压缩
        packed = np.packbits(self.x)
        # 对压缩后的数据进行位解压缩,根据 kwargs 中的参数
        unpacked = np.unpackbits(packed, **kwargs)
        # 断言解压缩后数据类型为 uint8
        assert_equal(unpacked.dtype, np.uint8)
        # 断言解压缩后的数组与 padded1 的除最后一个元素外的所有元素相等
        assert_array_equal(unpacked, self.padded1[:-1])

    # 使用 pytest 的参数化装饰器,指定测试函数的参数化条件
    @pytest.mark.parametrize('bitorder', ('little', 'big'))
    # delta==-1 when count<0 because one extra zero of padding
    @pytest.mark.parametrize('count', chain(range(8), range(-1, -9, -1)))
    # 定义测试函数 test_roundtrip_axis,参数包括 bitorder 和 count
    def test_roundtrip_axis(self, bitorder, count):
        # 如果 count 小于 0,则设定 cutoff 为 count - 1;否则设定为 count
        if count < 0:
            # 添加额外的零填充
            cutoff = count - 1
        else:
            cutoff = count
        # 对二维数组 x 按指定轴进行位压缩,根据指定的 bitorder
        packed0 = np.packbits(self.x, axis=0, bitorder=bitorder)
        # 对压缩后的数据进行位解压缩,根据指定的轴、count 和 bitorder
        unpacked0 = np.unpackbits(packed0, axis=0, count=count,
                                  bitorder=bitorder)
        # 断言解压缩后数据类型为 uint8
        assert_equal(unpacked0.dtype, np.uint8)
        # 断言解压缩后的数组与 padded2 的前 cutoff 行和 x 的列数相等
        assert_array_equal(unpacked0, self.padded2[:cutoff, :self.x.shape[1]])

        # 对二维数组 x 按指定轴进行位压缩,根据指定的 bitorder
        packed1 = np.packbits(self.x, axis=1, bitorder=bitorder)
        # 对压缩后的数据进行位解压缩,根据指定的轴、count 和 bitorder
        unpacked1 = np.unpackbits(packed1, axis=1, count=count,
                                  bitorder=bitorder)
        # 断言解压缩后数据类型为 uint8
        assert_equal(unpacked1.dtype, np.uint8)
        # 断言解压缩后的数组与 padded2 的前 x 的行数和 cutoff 列相等
        assert_array_equal(unpacked1, self.padded2[:self.x.shape[0], :cutoff])

    # 使用 pytest 的参数化装饰器,指定测试函数的参数化条件
    @pytest.mark.parametrize('kwargs', [
                    {}, {'count': None},
                    {'bitorder' : 'little'},
                    {'bitorder': 'little', 'count': None},
                    {'bitorder' : 'big'},
                    {'bitorder': 'big', 'count': None},
                    ])
    # 测试函数,用于验证 np.packbits 和 np.unpackbits 的行为
    def test_axis_count(self, kwargs):
        # 在 axis=0 上对数组 self.x 进行位打包
        packed0 = np.packbits(self.x, axis=0)
        # 在 axis=0 上对 packed0 进行位解包,并根据 kwargs 参数进行额外配置
        unpacked0 = np.unpackbits(packed0, axis=0, **kwargs)
        # 断言解包后的数据类型为 np.uint8
        assert_equal(unpacked0.dtype, np.uint8)
        # 根据 bitorder 参数的配置,进行不同的数组比较和断言
        if kwargs.get('bitorder', 'big') == 'big':
            assert_array_equal(unpacked0, self.padded2[:-1, :self.x.shape[1]])
        else:
            assert_array_equal(unpacked0[::-1, :], self.padded2[:-1, :self.x.shape[1]])

        # 在 axis=1 上对数组 self.x 进行位打包
        packed1 = np.packbits(self.x, axis=1)
        # 在 axis=1 上对 packed1 进行位解包,并根据 kwargs 参数进行额外配置
        unpacked1 = np.unpackbits(packed1, axis=1, **kwargs)
        # 断言解包后的数据类型为 np.uint8
        assert_equal(unpacked1.dtype, np.uint8)
        # 根据 bitorder 参数的配置,进行不同的数组比较和断言
        if kwargs.get('bitorder', 'big') == 'big':
            assert_array_equal(unpacked1, self.padded2[:self.x.shape[0], :-1])
        else:
            assert_array_equal(unpacked1[:, ::-1], self.padded2[:self.x.shape[0], :-1])

    # 测试异常情况下的 np.unpackbits 函数
    def test_bad_count(self):
        # 在 axis=0 上对数组 self.x 进行位打包
        packed0 = np.packbits(self.x, axis=0)
        # 断言在指定 count=-9 的情况下,调用 np.unpackbits 会抛出 ValueError 异常
        assert_raises(ValueError, np.unpackbits, packed0, axis=0, count=-9)

        # 在 axis=1 上对数组 self.x 进行位打包
        packed1 = np.packbits(self.x, axis=1)
        # 断言在指定 count=-9 的情况下,调用 np.unpackbits 会抛出 ValueError 异常
        assert_raises(ValueError, np.unpackbits, packed1, axis=1, count=-9)

        # 对整个数组 self.x 进行位打包
        packed = np.packbits(self.x)
        # 断言在指定 count=-57 的情况下,调用 np.unpackbits 会抛出 ValueError 异常
        assert_raises(ValueError, np.unpackbits, packed, count=-57)

.\numpy\numpy\lib\tests\test_polynomial.py

import numpy as np
from numpy.testing import (
    assert_, assert_equal, assert_array_equal, assert_almost_equal,
    assert_array_almost_equal, assert_raises, assert_allclose
    )

import pytest

# 定义一个字符串,包含所有整数和浮点数类型码,但不包括布尔和时间类型
TYPE_CODES = np.typecodes["AllInteger"] + np.typecodes["AllFloat"] + "O"

# 定义测试类 TestPolynomial
class TestPolynomial:
    
    # 测试方法:测试 poly1d 对象的字符串表示和表达式
    def test_poly1d_str_and_repr(self):
        p = np.poly1d([1., 2, 3])
        # 断言 poly1d 对象的 repr 结果
        assert_equal(repr(p), 'poly1d([1., 2., 3.])')
        # 断言 poly1d 对象的 str 结果
        assert_equal(str(p),
                     '   2\n'
                     '1 x + 2 x + 3')

        q = np.poly1d([3., 2, 1])
        assert_equal(repr(q), 'poly1d([3., 2., 1.])')
        assert_equal(str(q),
                     '   2\n'
                     '3 x + 2 x + 1')

        r = np.poly1d([1.89999 + 2j, -3j, -5.12345678, 2 + 1j])
        assert_equal(str(r),
                     '            3      2\n'
                     '(1.9 + 2j) x - 3j x - 5.123 x + (2 + 1j)')

        assert_equal(str(np.poly1d([-3, -2, -1])),
                     '    2\n'
                     '-3 x - 2 x - 1')

    # 测试方法:测试 poly1d 对象的计算能力
    def test_poly1d_resolution(self):
        p = np.poly1d([1., 2, 3])
        q = np.poly1d([3., 2, 1])
        assert_equal(p(0), 3.0)
        assert_equal(p(5), 38.0)
        assert_equal(q(0), 1.0)
        assert_equal(q(5), 86.0)

    # 测试方法:测试 poly1d 对象的数学运算
    def test_poly1d_math(self):
        # 使用简单系数进行测试,以便计算更加简单
        p = np.poly1d([1., 2, 4])
        q = np.poly1d([4., 2, 1])
        assert_equal(p/q, (np.poly1d([0.25]), np.poly1d([1.5, 3.75])))
        assert_equal(p.integ(), np.poly1d([1/3, 1., 4., 0.]))
        assert_equal(p.integ(1), np.poly1d([1/3, 1., 4., 0.]))

        p = np.poly1d([1., 2, 3])
        q = np.poly1d([3., 2, 1])
        assert_equal(p * q, np.poly1d([3., 8., 14., 8., 3.]))
        assert_equal(p + q, np.poly1d([4., 4., 4.]))
        assert_equal(p - q, np.poly1d([-2., 0., 2.]))
        assert_equal(p ** 4, np.poly1d([1., 8., 36., 104., 214., 312., 324., 216., 81.]))
        assert_equal(p(q), np.poly1d([9., 12., 16., 8., 6.]))
        assert_equal(q(p), np.poly1d([3., 12., 32., 40., 34.]))
        assert_equal(p.deriv(), np.poly1d([2., 2.]))
        assert_equal(p.deriv(2), np.poly1d([2.]))
        assert_equal(np.polydiv(np.poly1d([1, 0, -1]), np.poly1d([1, 1])),
                     (np.poly1d([1., -1.]), np.poly1d([0.])))

    # 使用 pytest 的参数化装饰器,参数是 TYPE_CODES 中的类型码
    @pytest.mark.parametrize("type_code", TYPE_CODES)
    # 测试多项式对象的杂项功能,接受一个类型码参数
    def test_poly1d_misc(self, type_code: str) -> None:
        # 根据给定的类型码创建 NumPy 数据类型对象
        dtype = np.dtype(type_code)
        # 创建一个 NumPy 数组,使用指定的数据类型
        ar = np.array([1, 2, 3], dtype=dtype)
        # 使用数组创建多项式对象
        p = np.poly1d(ar)

        # 测试多项式对象的相等性 `__eq__`
        assert_equal(np.asarray(p), ar)
        # 断言多项式对象的数据类型与输入的数据类型一致
        assert_equal(np.asarray(p).dtype, dtype)
        # 断言多项式对象的阶数为 2
        assert_equal(len(p), 2)

        # 测试多项式对象的索引访问 `__getitem__`
        # 准备一个预期的索引与值的对应字典
        comparison_dct = {-1: 0, 0: 3, 1: 2, 2: 1, 3: 0}
        # 遍历字典进行断言测试
        for index, ref in comparison_dct.items():
            # 获取多项式对象在当前索引位置的值
            scalar = p[index]
            # 断言获取的值与预期值相等
            assert_equal(scalar, ref)
            # 如果数据类型是 np.object_,则额外断言获取的值是整数类型
            if dtype == np.object_:
                assert isinstance(scalar, int)
            else:
                # 否则,断言获取的值的数据类型与输入的数据类型一致
                assert_equal(scalar.dtype, dtype)

    # 测试多项式对象使用不同变量参数的行为
    def test_poly1d_variable_arg(self):
        # 使用自定义变量名 'y' 创建多项式对象并断言其字符串表示
        q = np.poly1d([1., 2, 3], variable='y')
        assert_equal(str(q),
                     '   2\n'
                     '1 y + 2 y + 3')
        # 使用自定义变量名 'lambda' 创建多项式对象并断言其字符串表示
        q = np.poly1d([1., 2, 3], variable='lambda')
        assert_equal(str(q),
                     '        2\n'
                     '1 lambda + 2 lambda + 3')

    # 测试多项式对象的其他功能
    def test_poly(self):
        # 断言计算给定系数的多项式的根,与预期的根数组相等
        assert_array_almost_equal(np.poly([3, -np.sqrt(2), np.sqrt(2)]),
                                  [1, -3, -2, 6])

        # 从 Matlab 文档中复制的测试用例
        A = [[1, 2, 3], [4, 5, 6], [7, 8, 0]]
        # 断言计算给定系数矩阵的多项式的系数,与预期的系数数组相等
        assert_array_almost_equal(np.poly(A), [1, -6, -72, -27])

        # 测试应该对于完美共轭根产生实数输出
        assert_(np.isrealobj(np.poly([+1.082j, +2.613j, -2.613j, -1.082j])))
        assert_(np.isrealobj(np.poly([0+1j, -0+-1j, 1+2j,
                                      1-2j, 1.+3.5j, 1-3.5j])))
        assert_(np.isrealobj(np.poly([1j, -1j, 1+2j, 1-2j, 1+3j, 1-3.j])))
        assert_(np.isrealobj(np.poly([1j, -1j, 2j, -2j])))
        assert_(np.isrealobj(np.poly([1j, -1j])))
        assert_(np.isrealobj(np.poly([1, -1])))

        # 断言计算给定系数的多项式的根,与预期的结果相符
        assert_(np.iscomplexobj(np.poly([1j, -1.0000001j])))

        # 随机生成一组复数系数,并断言计算多项式的根的实部结果
        np.random.seed(42)
        a = np.random.randn(100) + 1j*np.random.randn(100)
        assert_(np.isrealobj(np.poly(np.concatenate((a, np.conjugate(a))))))

    # 测试计算多项式的根功能
    def test_roots(self):
        # 断言计算给定系数的多项式的根,与预期的根数组相等
        assert_array_equal(np.roots([1, 0, 0]), [0, 0])

    # 测试多项式对象字符串表示中的前导零处理
    def test_str_leading_zeros(self):
        # 创建一个具有给定系数的多项式对象,并设置其中一个系数为零,断言其字符串表示
        p = np.poly1d([4, 3, 2, 1])
        p[3] = 0
        assert_equal(str(p),
                     "   2\n"
                     "3 x + 2 x + 1")

        # 创建一个具有给定系数的多项式对象,并将所有系数设置为零,断言其字符串表示
        p = np.poly1d([1, 2])
        p[0] = 0
        p[1] = 0
        assert_equal(str(p), " \n0")
    # 定义测试函数,验证多项式对象的操作
    def test_objects(self):
        # 导入 Decimal 类用于精确计算
        from decimal import Decimal
        # 创建多项式对象 p,系数为 [4.0, 3.0, 2.0]
        p = np.poly1d([Decimal('4.0'), Decimal('3.0'), Decimal('2.0')])
        # 将 p 乘以 Decimal('1.333333333333333') 得到 p2
        p2 = p * Decimal('1.333333333333333')
        # 断言 p2 的索引 1 的值等于 Decimal("3.9999999999999990")
        assert_(p2[1] == Decimal("3.9999999999999990"))
        # 对 p 求导数得到 p2
        p2 = p.deriv()
        # 断言 p2 的索引 1 的值等于 Decimal('8.0')
        assert_(p2[1] == Decimal('8.0'))
        # 对 p 求积分得到 p2
        p2 = p.integ()
        # 断言 p2 的索引 3 的值等于 Decimal("1.333333333333333333333333333")
        assert_(p2[3] == Decimal("1.333333333333333333333333333"))
        # 断言 p2 的索引 2 的值等于 Decimal('1.5')
        assert_(p2[2] == Decimal('1.5'))
        # 断言 p2 的系数的数据类型是 np.object_
        assert_(np.issubdtype(p2.coeffs.dtype, np.object_))
        # 创建多项式对象 p,系数为 [1, -3, 2],并断言其与给定列表相等
        assert_equal(np.poly([Decimal(1), Decimal(2)]),
                     [1, Decimal(-3), Decimal(2)])

    # 定义测试函数,验证复数系数的多项式对象操作
    def test_complex(self):
        # 创建多项式对象 p,系数为 [3j, 2j, 1j]
        p = np.poly1d([3j, 2j, 1j])
        # 对 p 求积分得到 p2
        p2 = p.integ()
        # 断言 p2 的系数等于 [1j, 1j, 1j, 0] 的所有元素
        assert_((p2.coeffs == [1j, 1j, 1j, 0]).all())
        # 对 p 求导数得到 p2
        p2 = p.deriv()
        # 断言 p2 的系数等于 [6j, 2j] 的所有元素
        assert_((p2.coeffs == [6j, 2j]).all())

    # 定义测试函数,验证多项式积分的系数计算
    def test_integ_coeffs(self):
        # 创建多项式对象 p,系数为 [3, 2, 1]
        p = np.poly1d([3, 2, 1])
        # 对 p 求积分,指定积分常数和系数 k,得到 p2
        p2 = p.integ(3, k=[9, 7, 6])
        # 断言 p2 的系数等于计算结果的所有元素
        assert_(
            (p2.coeffs == [1/4./5., 1/3./4., 1/2./3., 9/1./2., 7, 6]).all())

    # 定义测试函数,验证处理零维多项式的异常
    def test_zero_dims(self):
        # 尝试创建零维多项式,捕获 ValueError 异常
        try:
            np.poly(np.zeros((0, 0)))
        except ValueError:
            pass

    # 定义测试函数,验证多项式整数溢出问题的回归测试
    def test_poly_int_overflow(self):
        """
        Regression test for gh-5096.
        """
        # 创建一个范围为 [1, 20] 的向量 v
        v = np.arange(1, 21)
        # 断言 np.poly(v) 与 np.poly(np.diag(v)) 的近似相等
        assert_almost_equal(np.poly(v), np.poly(np.diag(v)))

    # 定义测试函数,验证多项式零值数据类型问题的回归测试
    def test_zero_poly_dtype(self):
        """
        Regression test for gh-16354.
        """
        # 创建一个值全为零的数组 z
        z = np.array([0, 0, 0])
        # 创建整型系数的多项式对象 p,并断言其系数的数据类型为 np.int64
        p = np.poly1d(z.astype(np.int64))
        assert_equal(p.coeffs.dtype, np.int64)
        # 创建单精度浮点型系数的多项式对象 p,并断言其系数的数据类型为 np.float32
        p = np.poly1d(z.astype(np.float32))
        assert_equal(p.coeffs.dtype, np.float32)
        # 创建复数型系数的多项式对象 p,并断言其系数的数据类型为 np.complex64
        p = np.poly1d(z.astype(np.complex64))
        assert_equal(p.coeffs.dtype, np.complex64)

    # 定义测试函数,验证多项式对象的相等性操作
    def test_poly_eq(self):
        # 创建多项式对象 p,系数为 [1, 2, 3]
        p = np.poly1d([1, 2, 3])
        # 创建多项式对象 p2,系数为 [1, 2, 4]
        p2 = np.poly1d([1, 2, 4])
        # 断言 p 是否等于 None,结果为 False
        assert_equal(p == None, False)
        # 断言 p 是否不等于 None,结果为 True
        assert_equal(p != None, True)
        # 断言 p 是否等于自身,结果为 True
        assert_equal(p == p, True)
        # 断言 p 是否等于 p2,结果为 False
        assert_equal(p == p2, False)
        # 断言 p 是否不等于 p2,结果为 True
        assert_equal(p != p2, True)

    # 定义测试函数,验证多项式除法操作
    def test_polydiv(self):
        # 创建多项式对象 b 和 a
        b = np.poly1d([2, 6, 6, 1])
        a = np.poly1d([-1j, (1+2j), -(2+1j), 1])
        # 对 b 除以 a 得到商 q 和余数 r
        q, r = np.polydiv(b, a)
        # 断言商 q 和余数 r 的系数数据类型为 np.complex128
        assert_equal(q.coeffs.dtype, np.complex128)
        assert_equal(r.coeffs.dtype, np.complex128)
        # 断言 q 乘以 a 再加上 r 等于 b
        assert_equal(q*a + r, b)

        # 创建列表 c 和多项式对象 d
        c = [1, 2, 3]
        d = np.poly1d([1, 2, 3])
        # 对 c 除以 d 得到商 s 和余数 t
        s, t = np.polydiv(c, d)
        # 断言 s 和 t 的类型为 np.poly1d
        assert isinstance(s, np.poly1d)
        assert isinstance(t, np.poly1d)
        # 对 d 除以 c 得到商 u 和余数 v
        u, v = np.polydiv(d, c)
        # 断言 u 和 v 的类型为 np.poly1d
        assert isinstance(u, np.poly1d)
        assert isinstance(v, np.poly1d)
    def test_poly_coeffs_mutable(self):
        """ 测试多项式系数是否可修改 """
        # 创建一个三次多项式对象,系数为 [1, 2, 3]
        p = np.poly1d([1, 2, 3])

        # 修改多项式的系数,应该会使系数增加 1
        p.coeffs += 1
        # 断言修改后的系数应为 [2, 3, 4]
        assert_equal(p.coeffs, [2, 3, 4])

        # 修改系数的第三个元素(系数为 3),增加 10
        p.coeffs[2] += 10
        # 断言修改后的系数应为 [2, 3, 14]
        assert_equal(p.coeffs, [2, 3, 14])

        # 尝试设置系数属性为一个新的 NumPy 数组(这是不允许的,应引发 AttributeError 异常)
        assert_raises(AttributeError, setattr, p, 'coeffs', np.array(1))

.\numpy\numpy\lib\tests\test_recfunctions.py

# 导入pytest模块,用于测试和断言
import pytest

# 导入NumPy库,并将其命名为np,用于数组操作
import numpy as np

# 导入NumPy的masked array模块,用于处理带有掩码的数组
import numpy.ma as ma

# 导入NumPy的masked records模块,用于操作带有掩码的记录数组
from numpy.ma.mrecords import MaskedRecords

# 导入NumPy的测试工具,包括断言函数assert_equal
from numpy.ma.testutils import assert_equal

# 导入NumPy的测试工具,包括assert_和assert_raises函数
from numpy.testing import assert_, assert_raises

# 导入NumPy的记录数组函数模块,包括字段操作和结构操作
from numpy.lib.recfunctions import (
    drop_fields, rename_fields, get_fieldstructure, recursive_fill_fields,
    find_duplicates, merge_arrays, append_fields, stack_arrays, join_by,
    repack_fields, unstructured_to_structured, structured_to_unstructured,
    apply_along_fields, require_fields, assign_fields_by_name)

# 导入NumPy的记录数组字段操作辅助函数
get_fieldspec = np.lib.recfunctions._get_fieldspec
get_names = np.lib.recfunctions.get_names
get_names_flat = np.lib.recfunctions.get_names_flat

# 导入NumPy的记录数组描述和数据类型压缩函数
zip_descr = np.lib.recfunctions._zip_descr
zip_dtype = np.lib.recfunctions._zip_dtype

# 定义测试类TestRecFunctions,用于测试记录数组相关函数
class TestRecFunctions:
    
    # 测试准备方法,初始化测试数据
    def setup_method(self):
        x = np.array([1, 2, ])
        y = np.array([10, 20, 30])
        z = np.array([('A', 1.), ('B', 2.)],
                     dtype=[('A', '|S3'), ('B', float)])
        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
        self.data = (w, x, y, z)

    # 测试zip_descr函数
    def test_zip_descr(self):
        # 从self.data中获取测试数据
        (w, x, y, z) = self.data
        
        # 测试标准数组的zip_descr函数调用,使用flatten=True
        test = zip_descr((x, x), flatten=True)
        assert_equal(test,
                     np.dtype([('', int), ('', int)]))
        
        # 再次测试标准数组的zip_descr函数调用,使用flatten=False
        test = zip_descr((x, x), flatten=False)
        assert_equal(test,
                     np.dtype([('', int), ('', int)]))
        
        # 测试标准数组和灵活数据类型数组的zip_descr函数调用,使用flatten=True
        test = zip_descr((x, z), flatten=True)
        assert_equal(test,
                     np.dtype([('', int), ('A', '|S3'), ('B', float)]))
        
        # 再次测试标准数组和灵活数据类型数组的zip_descr函数调用,使用flatten=False
        test = zip_descr((x, z), flatten=False)
        assert_equal(test,
                     np.dtype([('', int),
                               ('', [('A', '|S3'), ('B', float)])]))
        
        # 测试标准数组和嵌套数据类型数组的zip_descr函数调用,使用flatten=True
        test = zip_descr((x, w), flatten=True)
        assert_equal(test,
                     np.dtype([('', int),
                               ('a', int),
                               ('ba', float), ('bb', int)]))
        
        # 再次测试标准数组和嵌套数据类型数组的zip_descr函数调用,使用flatten=False
        test = zip_descr((x, w), flatten=False)
        assert_equal(test,
                     np.dtype([('', int),
                               ('', [('a', int),
                                     ('b', [('ba', float), ('bb', int)])])]))
    def test_drop_fields(self):
        # Test drop_fields
        # 创建一个 NumPy 数组,包含复合数据类型,具有基本和嵌套字段
        a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])

        # 测试删除指定字段后的结果
        test = drop_fields(a, 'a')
        # 预期的结果数组,仅包含剩余字段
        control = np.array([((2, 3.0),), ((5, 6.0),)],
                           dtype=[('b', [('ba', float), ('bb', int)])])
        assert_equal(test, control)

        # 测试删除另一个字段后的结果(基本字段,但包含两个字段的嵌套)
        test = drop_fields(a, 'b')
        # 预期的结果数组,仅包含剩余字段
        control = np.array([(1,), (4,)], dtype=[('a', int)])
        assert_equal(test, control)

        # 测试删除嵌套子字段后的结果
        test = drop_fields(a, ['ba', ])
        # 预期的结果数组,仅包含剩余字段
        control = np.array([(1, (3.0,)), (4, (6.0,))],
                           dtype=[('a', int), ('b', [('bb', int)])])
        assert_equal(test, control)

        # 测试删除一个字段的所有嵌套子字段后的结果
        test = drop_fields(a, ['ba', 'bb'])
        # 预期的结果数组,仅包含剩余字段
        control = np.array([(1,), (4,)], dtype=[('a', int)])
        assert_equal(test, control)

        # 测试删除所有字段后的结果
        test = drop_fields(a, ['a', 'b'])
        # 预期的结果数组,不包含任何字段
        control = np.array([(), ()], dtype=[])
        assert_equal(test, control)

    def test_rename_fields(self):
        # Test rename fields
        # 创建一个 NumPy 数组,包含复合数据类型,具有基本和嵌套字段
        a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
                     dtype=[('a', int),
                            ('b', [('ba', float), ('bb', (float, 2))])])
        # 测试重命名字段后的结果
        test = rename_fields(a, {'a': 'A', 'bb': 'BB'})
        # 期望的新数据类型
        newdtype = [('A', int), ('b', [('ba', float), ('BB', (float, 2))])]
        # 期望的控制数组视图
        control = a.view(newdtype)
        assert_equal(test.dtype, newdtype)
        assert_equal(test, control)

    def test_get_names(self):
        # Test get_names
        # 创建一个 NumPy 数据类型对象,其中包含命名字段
        ndtype = np.dtype([('A', '|S3'), ('B', float)])
        # 测试获取字段名称的结果
        test = get_names(ndtype)
        assert_equal(test, ('A', 'B'))

        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
        test = get_names(ndtype)
        assert_equal(test, ('a', ('b', ('ba', 'bb'))))

        ndtype = np.dtype([('a', int), ('b', [])])
        test = get_names(ndtype)
        assert_equal(test, ('a', ('b', ())))

        ndtype = np.dtype([])
        test = get_names(ndtype)
        assert_equal(test, ())

    def test_get_names_flat(self):
        # Test get_names_flat
        # 创建一个 NumPy 数据类型对象,其中包含命名字段
        ndtype = np.dtype([('A', '|S3'), ('B', float)])
        # 测试获取扁平化字段名称的结果
        test = get_names_flat(ndtype)
        assert_equal(test, ('A', 'B'))

        ndtype = np.dtype([('a', int), ('b', [('ba', float), ('bb', int)])])
        test = get_names_flat(ndtype)
        assert_equal(test, ('a', 'b', 'ba', 'bb'))

        ndtype = np.dtype([('a', int), ('b', [])])
        test = get_names_flat(ndtype)
        assert_equal(test, ('a', 'b'))

        ndtype = np.dtype([])
        test = get_names_flat(ndtype)
        assert_equal(test, ())
    def test_get_fieldstructure(self):
        # Test get_fieldstructure

        # No nested fields
        ndtype = np.dtype([('A', '|S3'), ('B', float)])
        # 调用 get_fieldstructure 函数,传入 dtype 对象 ndtype
        test = get_fieldstructure(ndtype)
        # 断言结果与预期的空字典匹配
        assert_equal(test, {'A': [], 'B': []})

        # One 1-nested field
        ndtype = np.dtype([('A', int), ('B', [('BA', float), ('BB', '|S1')])])
        # 调用 get_fieldstructure 函数,传入 dtype 对象 ndtype
        test = get_fieldstructure(ndtype)
        # 断言结果与预期的字段结构字典匹配
        assert_equal(test, {'A': [], 'B': [], 'BA': ['B', ], 'BB': ['B']})

        # One 2-nested fields
        ndtype = np.dtype([('A', int),
                           ('B', [('BA', int),
                                  ('BB', [('BBA', int), ('BBB', int)])])])
        # 调用 get_fieldstructure 函数,传入 dtype 对象 ndtype
        test = get_fieldstructure(ndtype)
        # 预期的嵌套字段结构字典
        control = {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'],
                   'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}
        # 断言结果与预期的字段结构字典匹配
        assert_equal(test, control)

        # 0 fields
        ndtype = np.dtype([])
        # 调用 get_fieldstructure 函数,传入 dtype 对象 ndtype
        test = get_fieldstructure(ndtype)
        # 断言结果与预期的空字典匹配
        assert_equal(test, {})

    def test_find_duplicates(self):
        # Test find_duplicates
        # 创建一个结构化数组 a
        a = ma.array([(2, (2., 'B')), (1, (2., 'B')), (2, (2., 'B')),
                      (1, (1., 'B')), (2, (2., 'B')), (2, (2., 'C'))],
                     mask=[(0, (0, 0)), (0, (0, 0)), (0, (0, 0)),
                           (0, (0, 0)), (1, (0, 0)), (0, (1, 0))],
                     dtype=[('A', int), ('B', [('BA', float), ('BB', '|S1')])])
        # 调用 find_duplicates 函数,使用默认参数 ignoremask=False 和 return_index=True
        test = find_duplicates(a, ignoremask=False, return_index=True)
        # 预期的重复索引列表
        control = [0, 2]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

        # 调用 find_duplicates 函数,使用 key='A' 和 return_index=True
        test = find_duplicates(a, key='A', return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 2, 3, 5]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

        # 调用 find_duplicates 函数,使用 key='B' 和 return_index=True
        test = find_duplicates(a, key='B', return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 2, 4]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

        # 调用 find_duplicates 函数,使用 key='BA' 和 return_index=True
        test = find_duplicates(a, key='BA', return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 2, 4]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

        # 调用 find_duplicates 函数,使用 key='BB' 和 return_index=True
        test = find_duplicates(a, key='BB', return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 2, 3, 4]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

    def test_find_duplicates_ignoremask(self):
        # Test the ignoremask option of find_duplicates
        # 创建一个结构化数据类型 ndtype
        ndtype = [('a', int)]
        # 创建一个掩码数组 a
        a = ma.array([1, 1, 1, 2, 2, 3, 3],
                     mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
        # 调用 find_duplicates 函数,使用 ignoremask=True 和 return_index=True
        test = find_duplicates(a, ignoremask=True, return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 3, 4]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])

        # 调用 find_duplicates 函数,使用 ignoremask=False 和 return_index=True
        test = find_duplicates(a, ignoremask=False, return_index=True)
        # 预期的重复索引列表
        control = [0, 1, 2, 3, 4, 6]
        # 断言结果与预期的重复索引列表匹配(排序后)
        assert_equal(sorted(test[-1]), control)
        # 断言返回的值与数组 a 中相应索引位置的值匹配
        assert_equal(test[0], a[test[-1]])
    # 定义测试方法:重组字段
    def test_repack_fields(self):
        # 创建一个自定义的数据类型,包含一个无符号字节、一个单精度浮点数和一个长整型数
        dt = np.dtype('u1,f4,i8', align=True)
        # 创建一个包含两个元素的零数组,数据类型为上述自定义类型
        a = np.zeros(2, dtype=dt)

        # 断言重组字段函数按预期返回了指定的数据类型
        assert_equal(repack_fields(dt), np.dtype('u1,f4,i8'))
        # 断言重组字段函数返回的数据类型的字节大小是13
        assert_equal(repack_fields(a).itemsize, 13)
        # 断言带有对齐参数的重组字段函数能够正确地恢复原始数据类型
        assert_equal(repack_fields(repack_fields(dt), align=True), dt)

        # 确保类型信息得到保留
        # 将数据类型转换为记录数组的数据类型
        dt = np.dtype((np.record, dt))
        # 断言重组字段函数返回的数据类型的类型对象是 np.record
        assert_(repack_fields(dt).type is np.record)

    # 定义测试方法:非结构化数组转为结构化数组
    def test_unstructured_to_structured(self):
        # 创建一个形状为 (20, 2) 的零数组
        a = np.zeros((20, 2))
        # 定义测试用的数据类型参数列表
        test_dtype_args = [('x', float), ('y', float)]
        # 创建一个指定数据类型的数据类型对象
        test_dtype = np.dtype(test_dtype_args)
        # 调用非结构化数组转为结构化数组的函数,使用数据类型参数列表作为 dtype 参数
        field1 = unstructured_to_structured(a, dtype=test_dtype_args)  # now
        # 调用非结构化数组转为结构化数组的函数,使用数据类型对象作为 dtype 参数
        field2 = unstructured_to_structured(a, dtype=test_dtype)  # before
        # 断言两次调用的结果相等
        assert_equal(field1, field2)

    # 定义测试方法:按字段名进行赋值
    def test_field_assignment_by_name(self):
        # 创建一个包含两个元素的数组,数据类型包含三个字段:a (整数), b (双精度浮点数), c (无符号字节)
        a = np.ones(2, dtype=[('a', 'i4'), ('b', 'f8'), ('c', 'u1')])
        # 新的数据类型定义,包含两个字段:b (单精度浮点数), c (无符号字节)
        newdt = [('b', 'f4'), ('c', 'u1')]

        # 断言要求字段函数返回的结果与新数据类型一致
        assert_equal(require_fields(a, newdt), np.ones(2, newdt))

        # 创建一个包含两个元素的数组,数据类型为新定义的数据类型 newdt
        b = np.array([(1,2), (3,4)], dtype=newdt)
        # 调用按字段名赋值函数,保留未赋值字段的原值
        assign_fields_by_name(a, b, zero_unassigned=False)
        # 断言数组 a 的值与预期一致
        assert_equal(a, np.array([(1,1,2),(1,3,4)], dtype=a.dtype))
        # 再次调用按字段名赋值函数,未赋值字段赋值为零
        assign_fields_by_name(a, b)
        # 断言数组 a 的值与预期一致
        assert_equal(a, np.array([(0,1,2),(0,3,4)], dtype=a.dtype))

        # 测试嵌套字段的情况
        # 创建一个包含一个元素的数组,元素是嵌套结构,包含一个双精度浮点数和一个无符号字节
        a = np.ones(2, dtype=[('a', [('b', 'f8'), ('c', 'u1')])])
        # 新的数据类型定义,包含一个元素,元素是嵌套结构,包含一个无符号字节
        newdt = [('a', [('c', 'u1')])]
        # 断言要求字段函数返回的结果与新数据类型一致
        assert_equal(require_fields(a, newdt), np.ones(2, newdt))
        # 创建一个包含一个元素的数组,元素是嵌套结构,包含一个元素 (2,)
        b = np.array([((2,),), ((3,),)], dtype=newdt)
        # 调用按字段名赋值函数,保留未赋值字段的原值
        assign_fields_by_name(a, b, zero_unassigned=False)
        # 断言数组 a 的值与预期一致
        assert_equal(a, np.array([((1,2),), ((1,3),)], dtype=a.dtype))
        # 再次调用按字段名赋值函数,未赋值字段赋值为零
        assign_fields_by_name(a, b)
        # 断言数组 a 的值与预期一致
        assert_equal(a, np.array([((0,2),), ((0,3),)], dtype=a.dtype))

        # 测试针对 0 维数组的非结构化代码路径
        a, b = np.array(3), np.array(0)
        # 调用按字段名赋值函数
        assign_fields_by_name(b, a)
        # 断言数组 b 中索引为 () 的元素的值为 3
        assert_equal(b[()], 3)
class TestRecursiveFillFields:
    # Test recursive_fill_fields.

    def test_simple_flexible(self):
        # Test recursive_fill_fields on flexible-array
        # 创建一个包含两个元组的 NumPy 数组 a,其中元组包含 (1, 10.) 和 (2, 20.),
        # 数据类型为 [('A', int), ('B', float)]
        a = np.array([(1, 10.), (2, 20.)], dtype=[('A', int), ('B', float)])
        
        # 创建一个形状为 (3,) 的零数组 b,数据类型与数组 a 相同
        b = np.zeros((3,), dtype=a.dtype)
        
        # 调用 recursive_fill_fields 函数,将数组 a 填充到数组 b 中
        test = recursive_fill_fields(a, b)
        
        # 创建一个控制数组 control,包含三个元组 (1, 10.), (2, 20.), (0, 0.)
        # 数据类型为 [('A', int), ('B', float)]
        control = np.array([(1, 10.), (2, 20.), (0, 0.)],
                           dtype=[('A', int), ('B', float)])
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)

    def test_masked_flexible(self):
        # Test recursive_fill_fields on masked flexible-array
        # 创建一个包含两个元组的掩码数组 a,元组为 (1, 10.) 和 (2, 20.),掩码为 [(0, 1), (1, 0)]
        # 数据类型为 [('A', int), ('B', float)]
        a = ma.array([(1, 10.), (2, 20.)], mask=[(0, 1), (1, 0)],
                     dtype=[('A', int), ('B', float)])
        
        # 创建一个形状为 (3,) 的掩码数组 b,数据类型与数组 a 相同
        b = ma.zeros((3,), dtype=a.dtype)
        
        # 调用 recursive_fill_fields 函数,将掩码数组 a 填充到掩码数组 b 中
        test = recursive_fill_fields(a, b)
        
        # 创建一个控制数组 control,包含三个元组 (1, 10.), (2, 20.), (0, 0.)
        # 数据类型为 [('A', int), ('B', float)],并且具有相同的掩码
        control = ma.array([(1, 10.), (2, 20.), (0, 0.)],
                           mask=[(0, 1), (1, 0), (0, 0)],
                           dtype=[('A', int), ('B', float)])
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)


class TestMergeArrays:
    # Test merge_arrays

    def setup_method(self):
        # 设置测试方法的初始数据
        x = np.array([1, 2, ])
        y = np.array([10, 20, 30])
        z = np.array(
            [('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
        w = np.array(
            [(1, (2, 3.0, ())), (4, (5, 6.0, ()))],
            dtype=[('a', int), ('b', [('ba', float), ('bb', int), ('bc', [])])])
        self.data = (w, x, y, z)

    def test_solo(self):
        # Test merge_arrays on a single array.
        (_, x, _, z) = self.data
        
        # 调用 merge_arrays 函数,传入数组 x
        test = merge_arrays(x)
        
        # 创建一个控制数组 control,包含两个元组 (1,) 和 (2,)
        # 数据类型为 [('f0', int)]
        control = np.array([(1,), (2,)], dtype=[('f0', int)])
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
        
        # 再次调用 merge_arrays 函数,传入元组 (x,)
        test = merge_arrays((x,))
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
        
        # 调用 merge_arrays 函数,传入数组 z,并指定 flatten=False
        test = merge_arrays(z, flatten=False)
        
        # 断言测试结果与数组 z 相等
        assert_equal(test, z)
        
        # 再次调用 merge_arrays 函数,传入数组 z,并指定 flatten=True
        test = merge_arrays(z, flatten=True)
        
        # 断言测试结果与数组 z 相等
        assert_equal(test, z)

    def test_solo_w_flatten(self):
        # Test merge_arrays on a single array w & w/o flattening
        # 获取初始数据中的数组 w
        w = self.data[0]
        
        # 调用 merge_arrays 函数,传入数组 w,并指定 flatten=False
        test = merge_arrays(w, flatten=False)
        
        # 断言测试结果与数组 w 相等
        assert_equal(test, w)
        
        # 再次调用 merge_arrays 函数,传入数组 w,并指定 flatten=True
        test = merge_arrays(w, flatten=True)
        
        # 创建一个控制数组 control,包含两个元组 (1, 2, 3.0) 和 (4, 5, 6.0)
        # 数据类型为 [('a', int), ('ba', float), ('bb', int)]
        control = np.array([(1, 2, 3.0), (4, 5, 6.0)],
                           dtype=[('a', int), ('ba', float), ('bb', int)])
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)

    def test_standard(self):
        # Test standard & standard
        # Test merge arrays
        (_, x, y, _) = self.data
        
        # 调用 merge_arrays 函数,传入元组 (x, y),并指定 usemask=False
        test = merge_arrays((x, y), usemask=False)
        
        # 创建一个控制数组 control,包含三个元组 (1, 10), (2, 20), (-1, 30)
        # 数据类型为 [('f0', int), ('f1', int)]
        control = np.array([(1, 10), (2, 20), (-1, 30)],
                           dtype=[('f0', int), ('f1', int)])
        
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
        
        # 再次调用 merge_arrays 函数,传入元组 (x, y),并指定 usemask=True
        test = merge_arrays((x, y), usemask=True)
        
        # 创建一个控制数组 control,包含三个元组 (1, 10), (2, 20), (-1, 30)
        # 数据类型为 [('f0', int), ('f1', int)],并且具有相同的掩码
        control = ma.array([(1, 10), (2, 20), (-1, 30)],
                           mask=[(0, 0), (0, 0), (1, 0)],
                           dtype=[('f0', int), ('f1', int)])
        
        # 断言测试结果与控制数组相等,并且掩码也相等
        assert_equal(test, control)
        assert_equal(test.mask, control.mask)
    def test_flatten(self):
        # Test standard & flexible
        (_, x, _, z) = self.data  # 从测试数据中解包出 x 和 z
        test = merge_arrays((x, z), flatten=True)  # 调用 merge_arrays 函数,使用 flatten=True 进行合并测试
        control = np.array([(1, 'A', 1.), (2, 'B', 2.)],
                           dtype=[('f0', int), ('A', '|S3'), ('B', float)])  # 预期的合并结果
        assert_equal(test, control)  # 断言测试结果与预期结果相等

        test = merge_arrays((x, z), flatten=False)  # 再次调用 merge_arrays 函数,使用 flatten=False 进行合并测试
        control = np.array([(1, ('A', 1.)), (2, ('B', 2.))],
                           dtype=[('f0', int),
                                  ('f1', [('A', '|S3'), ('B', float)])])  # 预期的合并结果,包含嵌套结构
        assert_equal(test, control)  # 断言测试结果与预期结果相等

    def test_flatten_wflexible(self):
        # Test flatten standard & nested
        (w, x, _, _) = self.data  # 从测试数据中解包出 w 和 x
        test = merge_arrays((x, w), flatten=True)  # 调用 merge_arrays 函数,使用 flatten=True 进行合并测试
        control = np.array([(1, 1, 2, 3.0), (2, 4, 5, 6.0)],
                           dtype=[('f0', int),
                                  ('a', int), ('ba', float), ('bb', int)])  # 预期的合并结果
        assert_equal(test, control)  # 断言测试结果与预期结果相等

        test = merge_arrays((x, w), flatten=False)  # 再次调用 merge_arrays 函数,使用 flatten=False 进行合并测试
        controldtype = [('f0', int),
                        ('f1', [('a', int),
                                ('b', [('ba', float), ('bb', int), ('bc', [])])])]  # 预期的 dtype 结构
        control = np.array([(1., (1, (2, 3.0, ()))), (2, (4, (5, 6.0, ())))],
                           dtype=controldtype)  # 预期的合并结果,包含复杂的嵌套结构
        assert_equal(test, control)  # 断言测试结果与预期结果相等

    def test_wmasked_arrays(self):
        # Test merge_arrays masked arrays
        (_, x, _, _) = self.data  # 从测试数据中解包出 x
        mx = ma.array([1, 2, 3], mask=[1, 0, 0])  # 创建一个带掩码的 MaskedArray
        test = merge_arrays((x, mx), usemask=True)  # 调用 merge_arrays 函数,使用带掩码的测试
        control = ma.array([(1, 1), (2, 2), (-1, 3)],
                           mask=[(0, 1), (0, 0), (1, 0)],
                           dtype=[('f0', int), ('f1', int)])  # 预期的合并结果,包含掩码信息
        assert_equal(test, control)  # 断言测试结果与预期结果相等
        test = merge_arrays((x, mx), usemask=True, asrecarray=True)  # 再次调用 merge_arrays 函数,使用 asrecarray=True 进行测试
        assert_equal(test, control)  # 断言测试结果与预期结果相等
        assert_(isinstance(test, MaskedRecords))  # 断言返回结果是 MaskedRecords 类的实例

    def test_w_singlefield(self):
        # Test single field
        test = merge_arrays((np.array([1, 2]).view([('a', int)]),
                             np.array([10., 20., 30.])),)  # 调用 merge_arrays 函数,测试单个字段合并
        control = ma.array([(1, 10.), (2, 20.), (-1, 30.)],
                           mask=[(0, 0), (0, 0), (1, 0)],
                           dtype=[('a', int), ('f1', float)])  # 预期的合并结果,包含掩码信息和 dtype
        assert_equal(test, control)  # 断言测试结果与预期结果相等
    def test_w_shorter_flex(self):
        # 测试 merge_arrays 函数,使用较短的 flexndarray 作为输入
        z = self.data[-1]

        # FIXME,这个测试看起来是不完整和有问题的
        #test = merge_arrays((z, np.array([10, 20, 30]).view([('C', int)])))
        #control = np.array([('A', 1., 10), ('B', 2., 20), ('-1', -1, 20)],
        #                   dtype=[('A', '|S3'), ('B', float), ('C', int)])
        #assert_equal(test, control)

        # 使用 merge_arrays 函数合并 z 和一个新的 numpy 数组,以避免 pyflakes 警告未使用的变量
        merge_arrays((z, np.array([10, 20, 30]).view([('C', int)])))
        # 创建一个 numpy 数组作为控制组,用于后续的断言比较
        np.array([('A', 1., 10), ('B', 2., 20), ('-1', -1, 20)],
                 dtype=[('A', '|S3'), ('B', float), ('C', int)])

    def test_singlerecord(self):
        # 从 self.data 中获取 x, y, z 的值
        (_, x, y, z) = self.data
        # 使用 merge_arrays 函数测试单个记录的合并,关闭掩码使用
        test = merge_arrays((x[0], y[0], z[0]), usemask=False)
        # 创建一个 numpy 数组作为控制组,用于断言比较
        control = np.array([(1, 10, ('A', 1))],
                           dtype=[('f0', int),
                                  ('f1', int),
                                  ('f2', [('A', '|S3'), ('B', float)])])
        # 断言测试结果与控制组相等
        assert_equal(test, control)
class TestAppendFields:
    # Test append_fields

    def setup_method(self):
        # 设置测试环境,初始化几个 NumPy 数组
        x = np.array([1, 2, ])
        y = np.array([10, 20, 30])
        z = np.array(
            [('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
        # 将这些数组组成一个元组并赋值给实例变量 self.data
        self.data = (w, x, y, z)

    def test_append_single(self):
        # Test simple case
        # 解包 self.data 中的数组
        (_, x, _, _) = self.data
        # 调用 append_fields 函数,添加单个字段 'A' 到数组 x
        test = append_fields(x, 'A', data=[10, 20, 30])
        # 创建期望的控制结果
        control = ma.array([(1, 10), (2, 20), (-1, 30)],
                           mask=[(0, 0), (0, 0), (1, 0)],
                           dtype=[('f0', int), ('A', int)],)
        # 断言测试结果与期望结果相等
        assert_equal(test, control)

    def test_append_double(self):
        # Test simple case
        # 解包 self.data 中的数组
        (_, x, _, _) = self.data
        # 调用 append_fields 函数,添加两个字段 'A' 和 'B' 到数组 x
        test = append_fields(x, ('A', 'B'), data=[[10, 20, 30], [100, 200]])
        # 创建期望的控制结果
        control = ma.array([(1, 10, 100), (2, 20, 200), (-1, 30, -1)],
                           mask=[(0, 0, 0), (0, 0, 0), (1, 0, 1)],
                           dtype=[('f0', int), ('A', int), ('B', int)],)
        # 断言测试结果与期望结果相等
        assert_equal(test, control)

    def test_append_on_flex(self):
        # Test append_fields on flexible type arrays
        # 获取 self.data 中的最后一个数组 z
        z = self.data[-1]
        # 调用 append_fields 函数,在数组 z 上添加字段 'C'
        test = append_fields(z, 'C', data=[10, 20, 30])
        # 创建期望的控制结果
        control = ma.array([('A', 1., 10), ('B', 2., 20), (-1, -1., 30)],
                           mask=[(0, 0, 0), (0, 0, 0), (1, 1, 0)],
                           dtype=[('A', '|S3'), ('B', float), ('C', int)],)
        # 断言测试结果与期望结果相等
        assert_equal(test, control)

    def test_append_on_nested(self):
        # Test append_fields on nested fields
        # 获取 self.data 中的第一个数组 w
        w = self.data[0]
        # 调用 append_fields 函数,在数组 w 上添加字段 'C'
        test = append_fields(w, 'C', data=[10, 20, 30])
        # 创建期望的控制结果
        control = ma.array([(1, (2, 3.0), 10),
                            (4, (5, 6.0), 20),
                            (-1, (-1, -1.), 30)],
                           mask=[(
                               0, (0, 0), 0), (0, (0, 0), 0), (1, (1, 1), 0)],
                           dtype=[('a', int),
                                  ('b', [('ba', float), ('bb', int)]),
                                  ('C', int)],)
        # 断言测试结果与期望结果相等
        assert_equal(test, control)


class TestStackArrays:
    # Test stack_arrays
    def setup_method(self):
        # 设置测试环境,初始化几个 NumPy 数组
        x = np.array([1, 2, ])
        y = np.array([10, 20, 30])
        z = np.array(
            [('A', 1.), ('B', 2.)], dtype=[('A', '|S3'), ('B', float)])
        w = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
                     dtype=[('a', int), ('b', [('ba', float), ('bb', int)])])
        # 将这些数组组成一个元组并赋值给实例变量 self.data
        self.data = (w, x, y, z)

    def test_solo(self):
        # Test stack_arrays on single arrays
        # 解包 self.data 中的数组
        (_, x, _, _) = self.data
        # 调用 stack_arrays 函数,堆叠单个数组 x
        test = stack_arrays((x,))
        # 断言测试结果与 x 相等
        assert_equal(test, x)
        # 断言测试结果与 x 是同一个对象
        assert_(test is x)

        # 再次调用 stack_arrays 函数,直接传递数组 x
        test = stack_arrays(x)
        # 断言测试结果与 x 相等
        assert_equal(test, x)
        # 断言测试结果与 x 是同一个对象
        assert_(test is x)
    def test_unnamed_fields(self):
        # Tests combinations of arrays w/o named fields
        # 解构 self.data 元组,获取第二个和第三个元素作为 x 和 y
        (_, x, y, _) = self.data

        # 使用 stack_arrays 函数堆叠两个 x 数组,不使用掩码
        test = stack_arrays((x, x), usemask=False)
        # 创建控制数组,预期结果为 [1, 2, 1, 2]
        control = np.array([1, 2, 1, 2])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)

        # 使用 stack_arrays 函数堆叠 x 和 y 数组,不使用掩码
        test = stack_arrays((x, y), usemask=False)
        # 创建控制数组,预期结果为 [1, 2, 10, 20, 30]
        control = np.array([1, 2, 10, 20, 30])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)

        # 使用 stack_arrays 函数堆叠 y 和 x 数组,不使用掩码
        test = stack_arrays((y, x), usemask=False)
        # 创建控制数组,预期结果为 [10, 20, 30, 1, 2]
        control = np.array([10, 20, 30, 1, 2])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)

    def test_unnamed_and_named_fields(self):
        # Test combination of arrays w/ & w/o named fields
        # 解构 self.data 元组,获取第二个和第四个元素作为 x 和 z
        (_, x, _, z) = self.data

        # 使用 stack_arrays 函数堆叠 x 和 z 数组
        test = stack_arrays((x, z))
        # 创建控制数组,包含元组和掩码,具体内容见下文
        control = ma.array([(1, -1, -1), (2, -1, -1),
                            (-1, 'A', 1), (-1, 'B', 2)],
                           mask=[(0, 1, 1), (0, 1, 1),
                                 (1, 0, 0), (1, 0, 0)],
                           dtype=[('f0', int), ('A', '|S3'), ('B', float)])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
        # 断言测试结果的掩码与控制数组的掩码相等
        assert_equal(test.mask, control.mask)

        # 使用 stack_arrays 函数堆叠 z 和 x 数组
        test = stack_arrays((z, x))
        # 创建控制数组,包含元组和掩码,具体内容见下文
        control = ma.array([('A', 1, -1), ('B', 2, -1),
                            (-1, -1, 1), (-1, -1, 2), ],
                           mask=[(0, 0, 1), (0, 0, 1),
                                 (1, 1, 0), (1, 1, 0)],
                           dtype=[('A', '|S3'), ('B', float), ('f2', int)])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
        # 断言测试结果的掩码与控制数组的掩码相等
        assert_equal(test.mask, control.mask)

        # 使用 stack_arrays 函数堆叠 z, z 和 x 数组
        test = stack_arrays((z, z, x))
        # 创建控制数组,包含元组和掩码,具体内容见下文
        control = ma.array([('A', 1, -1), ('B', 2, -1),
                            ('A', 1, -1), ('B', 2, -1),
                            (-1, -1, 1), (-1, -1, 2), ],
                           mask=[(0, 0, 1), (0, 0, 1),
                                 (0, 0, 1), (0, 0, 1),
                                 (1, 1, 0), (1, 1, 0)],
                           dtype=[('A', '|S3'), ('B', float), ('f2', int)])
        # 断言测试结果与控制数组相等
        assert_equal(test, control)
    def test_matching_named_fields(self):
        # Test combination of arrays w/ matching field names

        # 解构赋值,从 self.data 中获取第二个和第四个元素作为 x 和 z
        (_, x, _, z) = self.data
        
        # 创建一个 NumPy 结构化数组 zz,包含三个元组
        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
        
        # 将 z 和 zz 数组堆叠在一起
        test = stack_arrays((z, zz))
        
        # 创建一个控制用的 MaskedArray 控制变量
        control = ma.array([('A', 1, -1), ('B', 2, -1),
                            ('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
                           dtype=[('A', '|S3'), ('B', float), ('C', float)],
                           mask=[(0, 0, 1), (0, 0, 1),
                                 (0, 0, 0), (0, 0, 0), (0, 0, 0)])
        
        # 断言 test 是否等于 control
        assert_equal(test, control)
        
        # 断言 test 的 mask 是否等于 control 的 mask
        assert_equal(test.mask, control.mask)

        # 将 z、zz 和 x 三个数组堆叠在一起
        test = stack_arrays((z, zz, x))
        
        # 创建一个包含 'f3' 字段的结构化数组的 dtype
        ndtype = [('A', '|S3'), ('B', float), ('C', float), ('f3', int)]
        
        # 创建另一个控制用的 MaskedArray 控制变量
        control = ma.array([('A', 1, -1, -1), ('B', 2, -1, -1),
                            ('a', 10., 100., -1), ('b', 20., 200., -1),
                            ('c', 30., 300., -1),
                            (-1, -1, -1, 1), (-1, -1, -1, 2)],
                           dtype=ndtype,
                           mask=[(0, 0, 1, 1), (0, 0, 1, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1), (0, 0, 0, 1),
                                 (1, 1, 1, 0), (1, 1, 1, 0)])
        
        # 断言 test 是否等于 control
        assert_equal(test, control)
        
        # 断言 test 的 mask 是否等于 control 的 mask
        assert_equal(test.mask, control.mask)

    def test_defaults(self):
        # Test defaults: no exception raised if keys of defaults are not fields.

        # 解构赋值,从 self.data 中获取第四个元素作为 z
        (_, _, _, z) = self.data
        
        # 创建一个 NumPy 结构化数组 zz,包含三个元组
        zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
                      dtype=[('A', '|S3'), ('B', float), ('C', float)])
        
        # 定义默认值字典
        defaults = {'A': '???', 'B': -999., 'C': -9999., 'D': -99999.}
        
        # 使用默认值字典将 z 和 zz 数组堆叠在一起
        test = stack_arrays((z, zz), defaults=defaults)
        
        # 创建一个控制用的 MaskedArray 控制变量
        control = ma.array([('A', 1, -9999.), ('B', 2, -9999.),
                            ('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
                           dtype=[('A', '|S3'), ('B', float), ('C', float)],
                           mask=[(0, 0, 1), (0, 0, 1),
                                 (0, 0, 0), (0, 0, 0), (0, 0, 0)])
        
        # 断言 test 是否等于 control
        assert_equal(test, control)
        
        # 断言 test 的数据部分是否等于 control 的数据部分
        assert_equal(test.data, control.data)
        
        # 断言 test 的 mask 是否等于 control 的 mask
        assert_equal(test.mask, control.mask)
    def test_autoconversion(self):
        # Tests autoconversion
        # 定义自动转换的数据类型列表
        adtype = [('A', int), ('B', bool), ('C', float)]
        # 创建数组 a,设置数据和掩码,使用指定的数据类型 adtype
        a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
        # 定义数据类型列表 bdtype
        bdtype = [('A', int), ('B', float), ('C', float)]
        # 创建数组 b,设置数据,使用指定的数据类型 bdtype
        b = ma.array([(4, 5, 6)], dtype=bdtype)
        # 创建控制数组 control,包含数组 a 和 b 的组合数据,设置掩码,使用数据类型 bdtype
        control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
                           dtype=bdtype)
        # 使用函数 stack_arrays 进行数组堆叠,启用自动转换
        test = stack_arrays((a, b), autoconvert=True)
        # 断言 test 和 control 数组相等
        assert_equal(test, control)
        # 断言 test 的掩码和 control 的掩码相等
        assert_equal(test.mask, control.mask)
        # 使用 assert_raises 检查 TypeError 异常是否被触发
        with assert_raises(TypeError):
            stack_arrays((a, b), autoconvert=False)

    def test_checktitles(self):
        # Test using titles in the field names
        # 定义带标题的数据类型列表 adtype
        adtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
        # 创建数组 a,设置数据和掩码,使用指定的数据类型 adtype
        a = ma.array([(1, 2, 3)], mask=[(0, 1, 0)], dtype=adtype)
        # 定义带标题的数据类型列表 bdtype
        bdtype = [(('a', 'A'), int), (('b', 'B'), bool), (('c', 'C'), float)]
        # 创建数组 b,设置数据,使用指定的数据类型 bdtype
        b = ma.array([(4, 5, 6)], dtype=bdtype)
        # 使用函数 stack_arrays 进行数组堆叠
        test = stack_arrays((a, b))
        # 创建控制数组 control,包含数组 a 和 b 的组合数据,设置掩码,使用数据类型 bdtype
        control = ma.array([(1, 2, 3), (4, 5, 6)], mask=[(0, 1, 0), (0, 0, 0)],
                           dtype=bdtype)
        # 断言 test 和 control 数组相等
        assert_equal(test, control)
        # 断言 test 的掩码和 control 的掩码相等
        assert_equal(test.mask, control.mask)

    def test_subdtype(self):
        # 创建数组 z,设置数据和子数据类型,使用指定的数据类型
        z = np.array([
            ('A', 1), ('B', 2)
        ], dtype=[('A', '|S3'), ('B', float, (1,))])
        # 创建数组 zz,设置数据和子数据类型,使用指定的数据类型
        zz = np.array([
            ('a', [10.], 100.), ('b', [20.], 200.), ('c', [30.], 300.)
        ], dtype=[('A', '|S3'), ('B', float, (1,)), ('C', float)])
        # 使用函数 stack_arrays 进行数组堆叠
        res = stack_arrays((z, zz))
        # 创建期望的掩码数组 expected,设置数据、掩码和数据类型
        expected = ma.array(
            data=[
                (b'A', [1.0], 0),
                (b'B', [2.0], 0),
                (b'a', [10.0], 100.0),
                (b'b', [20.0], 200.0),
                (b'c', [30.0], 300.0)],
            mask=[
                (False, [False],  True),
                (False, [False],  True),
                (False, [False], False),
                (False, [False], False),
                (False, [False], False)
            ],
            dtype=zz.dtype
        )
        # 断言 res 的数据类型和 expected 的数据类型相等
        assert_equal(res.dtype, expected.dtype)
        # 断言 res 和 expected 数组相等
        assert_equal(res, expected)
        # 断言 res 的掩码和 expected 的掩码相等
        assert_equal(res.mask, expected.mask)
class TestJoinBy:
    # 在测试类中设置方法的初始状态
    def setup_method(self):
        # 创建包含三个字段的 NumPy 数组 self.a,每个字段都是从不同范围的整数值中创建的
        self.a = np.array(list(zip(np.arange(10), np.arange(50, 60),
                                   np.arange(100, 110))),
                          dtype=[('a', int), ('b', int), ('c', int)])
        # 创建包含三个字段的 NumPy 数组 self.b,与 self.a 结构相同,但字段不完全一致
        self.b = np.array(list(zip(np.arange(5, 15), np.arange(65, 75),
                                   np.arange(100, 110))),
                          dtype=[('a', int), ('b', int), ('d', int)])

    # 测试内连接的基本功能
    def test_inner_join(self):
        # 获取 self.a 和 self.b 的引用
        a, b = self.a, self.b
        # 使用 'a' 字段进行内连接操作,并将结果存储在 test 变量中
        test = join_by('a', a, b, jointype='inner')
        # 创建预期结果的 NumPy 数组 control,包含特定字段和值
        control = np.array([(5, 55, 65, 105, 100), (6, 56, 66, 106, 101),
                            (7, 57, 67, 107, 102), (8, 58, 68, 108, 103),
                            (9, 59, 69, 109, 104)],
                           dtype=[('a', int), ('b1', int), ('b2', int),
                                  ('c', int), ('d', int)])
        # 断言 test 和 control 数组是否相等
        assert_equal(test, control)

    # 测试 join_by 函数的一般连接功能(目前被注释掉)
    def test_join(self):
        # 获取 self.a 和 self.b 的引用
        a, b = self.a, self.b
        
        # Fixme, this test is broken
        #test = join_by(('a', 'b'), a, b)
        #control = np.array([(5, 55, 105, 100), (6, 56, 106, 101),
        #                    (7, 57, 107, 102), (8, 58, 108, 103),
        #                    (9, 59, 109, 104)],
        #                   dtype=[('a', int), ('b', int),
        #                          ('c', int), ('d', int)])
        #assert_equal(test, control)

        # 使用 join_by 函数,连接 ('a', 'b') 字段,但忽略其返回结果(仅为了避免警告)
        join_by(('a', 'b'), a, b)
        # 创建一个 NumPy 数组,但这个数组并没有被用于任何断言或返回操作
        np.array([(5, 55, 105, 100), (6, 56, 106, 101),
                  (7, 57, 107, 102), (8, 58, 108, 103),
                  (9, 59, 109, 104)],
                  dtype=[('a', int), ('b', int),
                         ('c', int), ('d', int)])

    # 测试子数据类型的连接功能(用于验证已知的 bug)
    def test_join_subdtype(self):
        # 创建包含单个键值对的 NumPy 数组 foo,键 'key' 对应一个整数值
        foo = np.array([(1,)],
                       dtype=[('key', int)])
        # 创建包含一个键值对的 NumPy 数组 bar,键 'key' 对应一个整数值,值 'value' 是一个有三个元素的无符号整数数组
        bar = np.array([(1, np.array([1,2,3]))],
                       dtype=[('key', int), ('value', 'uint16', 3)])
        # 使用 'key' 字段进行连接操作,将结果存储在 res 变量中
        res = join_by('key', foo, bar)
        # 断言 res 和 bar 的视图(MaskedArray)是否相等
        assert_equal(res, bar.view(ma.MaskedArray))
    def test_outer_join(self):
        a, b = self.a, self.b
        # 调用自定义函数join_by进行外连接测试,使用字段('a', 'b'),连接a和b两个数组,连接方式为'outer'
        test = join_by(('a', 'b'), a, b, 'outer')
        # 预期的控制结果,一个带有数据和遮罩的MaskedArray对象
        control = ma.array([(0, 50, 100, -1), (1, 51, 101, -1),
                            (2, 52, 102, -1), (3, 53, 103, -1),
                            (4, 54, 104, -1), (5, 55, 105, -1),
                            (5, 65, -1, 100), (6, 56, 106, -1),
                            (6, 66, -1, 101), (7, 57, 107, -1),
                            (7, 67, -1, 102), (8, 58, 108, -1),
                            (8, 68, -1, 103), (9, 59, 109, -1),
                            (9, 69, -1, 104), (10, 70, -1, 105),
                            (11, 71, -1, 106), (12, 72, -1, 107),
                            (13, 73, -1, 108), (14, 74, -1, 109)],
                           mask=[(0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 1, 0), (0, 0, 0, 1),
                                 (0, 0, 1, 0), (0, 0, 0, 1),
                                 (0, 0, 1, 0), (0, 0, 0, 1),
                                 (0, 0, 1, 0), (0, 0, 0, 1),
                                 (0, 0, 1, 0), (0, 0, 1, 0),
                                 (0, 0, 1, 0), (0, 0, 1, 0),
                                 (0, 0, 1, 0), (0, 0, 1, 0)],
                           dtype=[('a', int), ('b', int),
                                  ('c', int), ('d', int)])
        # 使用断言验证test和control是否相等
        assert_equal(test, control)

    def test_leftouter_join(self):
        a, b = self.a, self.b
        # 调用自定义函数join_by进行左外连接测试,使用字段('a', 'b'),连接a和b两个数组,连接方式为'leftouter'
        test = join_by(('a', 'b'), a, b, 'leftouter')
        # 预期的控制结果,一个带有数据和遮罩的MaskedArray对象
        control = ma.array([(0, 50, 100, -1), (1, 51, 101, -1),
                            (2, 52, 102, -1), (3, 53, 103, -1),
                            (4, 54, 104, -1), (5, 55, 105, -1),
                            (6, 56, 106, -1), (7, 57, 107, -1),
                            (8, 58, 108, -1), (9, 59, 109, -1)],
                           mask=[(0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1),
                                 (0, 0, 0, 1), (0, 0, 0, 1)],
                           dtype=[('a', int), ('b', int), ('c', int), ('d', int)])
        # 使用断言验证test和control是否相等
        assert_equal(test, control)

    def test_different_field_order(self):
        # 测试情形gh-8940
        # 创建一个包含3行的全零数组a,dtype为[('a', 'i4'), ('b', 'f4'), ('c', 'u1')]
        a = np.zeros(3, dtype=[('a', 'i4'), ('b', 'f4'), ('c', 'u1')])
        # 创建一个包含3行的全一数组b,dtype为[('c', 'u1'), ('b', 'f4'), ('a', 'i4')]
        b = np.ones(3, dtype=[('c', 'u1'), ('b', 'f4'), ('a', 'i4')])
        # 调用自定义函数join_by进行内连接测试,使用字段['c', 'b'],连接a和b两个数组,连接方式为'inner',不使用遮罩
        j = join_by(['c', 'b'], a, b, jointype='inner', usemask=False)
        # 使用断言验证连接后的dtype的字段名是否为['b', 'c', 'a1', 'a2']
        assert_equal(j.dtype.names, ['b', 'c', 'a1', 'a2'])
    # 测试函数:检查在具有重复键的情况下是否引发 ValueError 异常
    def test_duplicate_keys(self):
        # 创建包含零值的 NumPy 结构化数组 a
        a = np.zeros(3, dtype=[('a', 'i4'), ('b', 'f4'), ('c', 'u1')])
        # 创建包含全一值的 NumPy 结构化数组 b,但是键 'c' 和 'a' 的顺序与 a 中不同
        b = np.ones(3, dtype=[('c', 'u1'), ('b', 'f4'), ('a', 'i4')])
        # 断言调用 join_by 函数时会引发 ValueError 异常
        assert_raises(ValueError, join_by, ['a', 'b', 'b'], a, b)

    # 测试函数:检查具有相同键但数据类型不同的情况下是否正确合并
    def test_same_name_different_dtypes_key(self):
        # 定义 a 的数据类型,包含 'key' 和 '<f4' 类型的元组
        a_dtype = np.dtype([('key', 'S5'), ('value', '<f4')])
        # 定义 b 的数据类型,包含 'key' 和 '<f4' 类型的元组,但 'key' 的长度为 10
        b_dtype = np.dtype([('key', 'S10'), ('value', '<f4')])
        # 期望结果的数据类型,包含 'key'、'value1'(<f4)、'value2'(<f4)的元组
        expected_dtype = np.dtype([
            ('key', 'S10'), ('value1', '<f4'), ('value2', '<f4')])

        # 创建 NumPy 结构化数组 a 和 b
        a = np.array([('Sarah',  8.0), ('John', 6.0)], dtype=a_dtype)
        b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype)
        # 调用 join_by 函数,按照 'key' 合并 a 和 b
        res = join_by('key', a, b)

        # 断言合并结果的数据类型与期望的数据类型相同
        assert_equal(res.dtype, expected_dtype)

    # 测试函数:检查具有相同键但数据类型不同的情况下是否正确合并(另一种情况)
    def test_same_name_different_dtypes(self):
        # 定义 a 的数据类型,包含 'key' 和 '<f4' 类型的元组
        a_dtype = np.dtype([('key', 'S10'), ('value', '<f4')])
        # 定义 b 的数据类型,包含 'key' 和 '<f8' 类型的元组
        b_dtype = np.dtype([('key', 'S10'), ('value', '<f8')])
        # 期望结果的数据类型,包含 'key'、'value1'(<f4)、'value2'(<f8)的元组
        expected_dtype = np.dtype([
            ('key', '|S10'), ('value1', '<f4'), ('value2', '<f8')])

        # 创建 NumPy 结构化数组 a 和 b
        a = np.array([('Sarah',  8.0), ('John', 6.0)], dtype=a_dtype)
        b = np.array([('Sarah', 10.0), ('John', 7.0)], dtype=b_dtype)
        # 调用 join_by 函数,按照 'key' 合并 a 和 b
        res = join_by('key', a, b)

        # 断言合并结果的数据类型与期望的数据类型相同
        assert_equal(res.dtype, expected_dtype)

    # 测试函数:检查包含子数组键的情况下是否正确合并
    def test_subarray_key(self):
        # 定义 a 的数据类型,包含 'pos' 和 '<f4' 类型的元组,'pos' 是一个长度为 3 的整数数组
        a_dtype = np.dtype([('pos', int, 3), ('f', '<f4')])
        # 创建 NumPy 结构化数组 a
        a = np.array([([1, 1, 1], np.pi), ([1, 2, 3], 0.0)], dtype=a_dtype)

        # 定义 b 的数据类型,包含 'pos' 和 '<f4' 类型的元组,'pos' 是一个长度为 3 的整数数组
        b_dtype = np.dtype([('pos', int, 3), ('g', '<f4')])
        # 创建 NumPy 结构化数组 b
        b = np.array([([1, 1, 1], 3), ([3, 2, 1], 0.0)], dtype=b_dtype)

        # 期望结果的数据类型,包含 'pos'、'f'(<f4)、'g'(<f4)的元组
        expected_dtype = np.dtype([('pos', int, 3), ('f', '<f4'), ('g', '<f4')])
        # 期望的合并结果
        expected = np.array([([1, 1, 1], np.pi, 3)], dtype=expected_dtype)

        # 调用 join_by 函数,按照 'pos' 合并 a 和 b
        res = join_by('pos', a, b)

        # 断言合并结果的数据类型与期望的数据类型相同
        assert_equal(res.dtype, expected_dtype)
        # 断言合并结果与期望的结果数组相同
        assert_equal(res, expected)

    # 测试函数:检查带填充字段的数据类型是否正确处理
    def test_padded_dtype(self):
        # 定义数据类型 dt,包含一个 'i1' 和 '<f4' 类型的元组,并指定对齐为真
        dt = np.dtype('i1,f4', align=True)
        dt.names = ('k', 'v')
        # 断言 dt 的描述中包含 3 个元素,表示已插入填充字段
        assert_(len(dt.descr), 3)

        # 创建 NumPy 结构化数组 a 和 b,使用定义好的数据类型 dt
        a = np.array([(1, 3), (3, 2)], dt)
        b = np.array([(1, 1), (2, 2)], dt)
        # 调用 join_by 函数,按照 'k' 合并 a 和 b
        res = join_by('k', a, b)

        # 期望结果的数据类型,包含 'k'、'v1'('f4')、'v2'('f4')的元组
        expected_dtype = np.dtype([
            ('k', 'i1'), ('v1', 'f4'), ('v2', 'f4')
        ])

        # 断言合并结果的数据类型与期望的数据类型相同
        assert_equal(res.dtype, expected_dtype)
class TestJoinBy2:
    @classmethod
    def setup_method(cls):
        # 设置测试环境方法,创建两个结构化数组 a 和 b
        cls.a = np.array(list(zip(np.arange(10), np.arange(50, 60),
                                  np.arange(100, 110))),
                         dtype=[('a', int), ('b', int), ('c', int)])
        cls.b = np.array(list(zip(np.arange(10), np.arange(65, 75),
                                  np.arange(100, 110))),
                         dtype=[('a', int), ('b', int), ('d', int)])

    def test_no_r1postfix(self):
        # join_by 函数的基本测试,使用 r1postfix='' 和 r2postfix='2' 进行内连接
        a, b = self.a, self.b

        test = join_by(
            'a', a, b, r1postfix='', r2postfix='2', jointype='inner')
        # 预期的结果数组
        control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
                            (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
                            (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
                            (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
                            (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
                           dtype=[('a', int), ('b', int), ('b2', int),
                                  ('c', int), ('d', int)])
        assert_equal(test, control)

    def test_no_postfix(self):
        # 测试异常情况:r1postfix='' 和 r2postfix='' 都为空时,应引发 ValueError
        assert_raises(ValueError, join_by, 'a', self.a, self.b,
                      r1postfix='', r2postfix='')

    def test_no_r2postfix(self):
        # join_by 函数的基本测试,使用 r1postfix='1' 和 r2postfix='' 进行内连接
        a, b = self.a, self.b

        test = join_by(
            'a', a, b, r1postfix='1', r2postfix='', jointype='inner')
        # 预期的结果数组
        control = np.array([(0, 50, 65, 100, 100), (1, 51, 66, 101, 101),
                            (2, 52, 67, 102, 102), (3, 53, 68, 103, 103),
                            (4, 54, 69, 104, 104), (5, 55, 70, 105, 105),
                            (6, 56, 71, 106, 106), (7, 57, 72, 107, 107),
                            (8, 58, 73, 108, 108), (9, 59, 74, 109, 109)],
                           dtype=[('a', int), ('b1', int), ('b', int),
                                  ('c', int), ('d', int)])
        assert_equal(test, control)
    # 定义一个测试函数,测试两个键和两个变量的情况
    def test_two_keys_two_vars(self):
        # 创建一个 NumPy 数组 `a`,其中每个元素是一个元组,元组中包含四个字段
        # 第一个字段 `k` 是重复的数组 [10, 11],重复 5 次
        # 第二个字段 `a` 是重复的数组 [0, 1, 2, 3, 4],每个元素重复 2 次
        # 第三个字段 `b` 是从 50 到 59 的连续数组
        # 第四个字段 `c` 是从 10 到 19 的连续数组
        a = np.array(list(zip(np.tile([10, 11], 5), np.repeat(np.arange(5), 2),
                              np.arange(50, 60), np.arange(10, 20))),
                     dtype=[('k', int), ('a', int), ('b', int), ('c', int)])

        # 创建另一个 NumPy 数组 `b`,结构与 `a` 相同
        # 但第三个字段 `b` 的值是从 65 到 74 的连续数组
        # 第四个字段 `c` 的值是从 0 到 9 的连续数组
        b = np.array(list(zip(np.tile([10, 11], 5), np.repeat(np.arange(5), 2),
                              np.arange(65, 75), np.arange(0, 10))),
                     dtype=[('k', int), ('a', int), ('b', int), ('c', int)])

        # 创建一个控制用的 NumPy 数组 `control`
        # 包含了预期的合并结果,有六个字段
        control = np.array([(10, 0, 50, 65, 10, 0), (11, 0, 51, 66, 11, 1),
                            (10, 1, 52, 67, 12, 2), (11, 1, 53, 68, 13, 3),
                            (10, 2, 54, 69, 14, 4), (11, 2, 55, 70, 15, 5),
                            (10, 3, 56, 71, 16, 6), (11, 3, 57, 72, 17, 7),
                            (10, 4, 58, 73, 18, 8), (11, 4, 59, 74, 19, 9)],
                           dtype=[('k', int), ('a', int), ('b1', int),
                                  ('b2', int), ('c1', int), ('c2', int)])

        # 调用 `join_by` 函数进行数组合并测试
        # 以 ['a', 'k'] 为键,将数组 `a` 和 `b` 进行内连接
        # 设置后缀 '1' 和 '2',合并后的结果命名为 `test`
        test = join_by(
            ['a', 'k'], a, b, r1postfix='1', r2postfix='2', jointype='inner')

        # 断言合并后结果 `test` 的数据类型与 `control` 的数据类型相同
        assert_equal(test.dtype, control.dtype)
        
        # 断言合并后的结果 `test` 与预期的 `control` 结果相等
        assert_equal(test, control)
class TestAppendFieldsObj:
    """
    Test append_fields with arrays containing objects
    """
    # https://github.com/numpy/numpy/issues/2346

    def setup_method(self):
        # 在测试方法执行前设置数据对象,使用了日期对象作为示例数据
        from datetime import date
        self.data = dict(obj=date(2000, 1, 1))

    def test_append_to_objects(self):
        # 测试在基础数组包含对象时使用 append_fields 函数
        obj = self.data['obj']
        # 创建一个包含对象和浮点数的 NumPy 数组,定义了自定义数据类型
        x = np.array([(obj, 1.), (obj, 2.)],
                     dtype=[('A', object), ('B', float)])
        # 创建一个整数类型的 NumPy 数组
        y = np.array([10, 20], dtype=int)
        # 调用 append_fields 函数将整数数组 y 添加为 x 数组的新字段 'C'
        test = append_fields(x, 'C', data=y, usemask=False)
        # 创建一个控制组数组,验证是否正确追加了字段 'C'
        control = np.array([(obj, 1.0, 10), (obj, 2.0, 20)],
                           dtype=[('A', object), ('B', float), ('C', int)])
        # 使用断言检查测试结果是否与控制组一致
        assert_equal(test, control)