NumPy 源码解析(八十五)
.\numpy\numpy\_core\tests\test_numerictypes.py
# 导入必要的模块和库
import sys # 导入 sys 模块
import itertools # 导入 itertools 模块
import pytest # 导入 pytest 模块
import numpy as np # 导入 NumPy 库并重命名为 np
import numpy._core.numerictypes as nt # 导入 NumPy 的内部模块 numerictypes,并重命名为 nt
from numpy._core.numerictypes import ( # 从 numerictypes 模块导入多个函数和类
issctype, sctype2char, maximum_sctype, sctypes
)
from numpy.testing import ( # 从 testing 模块导入多个函数和类
assert_, assert_equal, assert_raises, assert_raises_regex, IS_PYPY
)
# This is the structure of the table used for plain objects:
#
# +-+-+-+
# |x|y|z|
# +-+-+-+
# Structure of a plain array description:
# 定义一个描述普通对象的结构
Pdescr = [
('x', 'i4', (2,)), # 第一个字段 x,整数类型,形状为 (2,)
('y', 'f8', (2, 2)), # 第二个字段 y,浮点数类型,形状为 (2, 2)
('z', 'u1') # 第三个字段 z,无符号整数类型
]
# A plain list of tuples with values for testing:
# 用于测试的普通元组列表,包含了各个字段的数值
PbufferT = [
# x y z
([3, 2], [[6., 4.], [6., 4.]], 8), # 第一组数据
([4, 3], [[7., 5.], [7., 5.]], 9) # 第二组数据
]
# This is the structure of the table used for nested objects (DON'T PANIC!):
#
# +-+---------------------------------+-----+----------+-+-+
# |x|Info |color|info |y|z|
# | +-----+--+----------------+----+--+ +----+-----+ | |
# | |value|y2|Info2 |name|z2| |Name|Value| | |
# | | | +----+-----+--+--+ | | | | | | |
# | | | |name|value|y3|z3| | | | | | | |
# +-+-----+--+----+-----+--+--+----+--+-----+----+-----+-+-+
#
# The corresponding nested array description:
# 定义一个描述嵌套对象的结构
Ndescr = [
('x', 'i4', (2,)), # 第一个字段 x,整数类型,形状为 (2,)
('Info', [ # 第二个字段 Info,包含多个子字段
('value', 'c16'), # value,复数类型
('y2', 'f8'), # y2,浮点数类型
('Info2', [ # Info2 包含多个子字段
('name', 'S2'), # name,字符串类型,长度为 2
('value', 'c16', (2,)), # value,复数类型,形状为 (2,)
('y3', 'f8', (2,)), # y3,浮点数类型,形状为 (2,)
('z3', 'u4', (2,)) # z3,无符号整数类型,形状为 (2,)
]),
('name', 'S2'), # name,字符串类型,长度为 2
('z2', 'b1') # z2,布尔类型
]),
('color', 'S2'), # 第三个字段 color,字符串类型,长度为 2
('info', [ # 第四个字段 info,包含多个子字段
('Name', 'U8'), # Name,Unicode 字符串类型,长度为 8
('Value', 'c16') # Value,复数类型
]),
('y', 'f8', (2, 2)), # 第五个字段 y,浮点数类型,形状为 (2, 2)
('z', 'u1') # 第六个字段 z,无符号整数类型
]
NbufferT = [
# x Info color info y z
# value y2 Info2 name z2 Name Value
# name value y3 z3
([3, 2], (6j, 6., (b'nn', [6j, 4j], [6., 4.], [1, 2]), b'NN', True),
b'cc', ('NN', 6j), [[6., 4.], [6., 4.]], 8), # 第一组数据
([4, 3], (7j, 7., (b'oo', [7j, 5j], [7., 5.], [2, 1]), b'OO', False),
b'dd', ('OO', 7j), [[7., 5.], [7., 5.]], 9) # 第二组数据
]
byteorder = {'little': '<', 'big': '>'}[sys.byteorder] # 根据系统字节顺序选择对应的符号
def normalize_descr(descr):
"Normalize a description adding the platform byteorder."
# 将描述符描述中的字节顺序标准化为当前平台的字节顺序
out = []
# 遍历描述符列表中的每一项
for item in descr:
# 提取当前项的数据类型
dtype = item[1]
# 如果数据类型是字符串类型
if isinstance(dtype, str):
# 检查数据类型字符串的第一个字符,判断是否需要调整字节顺序
if dtype[0] not in ['|', '<', '>']:
# 检查是否为单字节数据或特定数据类型,若是则调整数据类型格式
onebyte = dtype[1:] == "1"
if onebyte or dtype[0] in ['S', 'V', 'b']:
dtype = "|" + dtype
else:
dtype = byteorder + dtype
# 检查当前项是否包含附加信息,并且附加信息的元素数量大于1
if len(item) > 2 and np.prod(item[2]) > 1:
nitem = (item[0], dtype, item[2]) # 构建新的元组项
else:
nitem = (item[0], dtype) # 构建新的元组项
# 将新的元组项添加到输出列表中
out.append(nitem)
# 如果数据类型是列表类型
elif isinstance(dtype, list):
# 递归调用normalize_descr函数处理列表数据类型,并将处理结果作为元组项添加到输出列表中
l = normalize_descr(dtype)
out.append((item[0], l))
# 如果数据类型既不是字符串类型也不是列表类型,则抛出异常
else:
raise ValueError("Expected a str or list and got %s" % (type(item)))
# 返回处理后的输出列表
return out
############################################################
# Creation tests
############################################################
class CreateZeros:
"""Check the creation of heterogeneous arrays zero-valued"""
def test_zeros0D(self):
"""Check creation of 0-dimensional objects"""
# 创建一个零维数组对象,使用给定的描述符类型
h = np.zeros((), dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 断言字段 'x' 的名称的前四个字符为 'void'
assert_(h.dtype.fields['x'][0].name[:4] == 'void')
# 断言字段 'x' 的字符为 'V',表示它是 void 类型
assert_(h.dtype.fields['x'][0].char == 'V')
# 断言字段 'x' 的类型为 np.void
assert_(h.dtype.fields['x'][0].type == np.void)
# 对数据的一个小检查,确保数据正确初始化为零
assert_equal(h['z'], np.zeros((), dtype='u1'))
def test_zerosSD(self):
"""Check creation of single-dimensional objects"""
# 创建一个单维数组对象,使用给定的描述符类型
h = np.zeros((2,), dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 断言字段 'y' 的名称的前四个字符为 'void'
assert_(h.dtype['y'].name[:4] == 'void')
# 断言字段 'y' 的字符为 'V',表示它是 void 类型
assert_(h.dtype['y'].char == 'V')
# 断言字段 'y' 的类型为 np.void
assert_(h.dtype['y'].type == np.void)
# 对数据的一个小检查,确保数据正确初始化为零
assert_equal(h['z'], np.zeros((2,), dtype='u1'))
def test_zerosMD(self):
"""Check creation of multi-dimensional objects"""
# 创建一个多维数组对象,使用给定的描述符类型
h = np.zeros((2, 3), dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 断言字段 'z' 的名称为 'uint8'
assert_(h.dtype['z'].name == 'uint8')
# 断言字段 'z' 的字符为 'B',表示它是无符号字节类型
assert_(h.dtype['z'].char == 'B')
# 断言字段 'z' 的类型为 np.uint8
assert_(h.dtype['z'].type == np.uint8)
# 对数据的一个小检查,确保数据正确初始化为零
assert_equal(h['z'], np.zeros((2, 3), dtype='u1'))
class TestCreateZerosPlain(CreateZeros):
"""Check the creation of heterogeneous arrays zero-valued (plain)"""
_descr = Pdescr
class TestCreateZerosNested(CreateZeros):
"""Check the creation of heterogeneous arrays zero-valued (nested)"""
_descr = Ndescr
class CreateValues:
"""Check the creation of heterogeneous arrays with values"""
def test_tuple(self):
"""Check creation from tuples"""
# 从元组创建数组对象,使用给定的描述符类型
h = np.array(self._buffer, dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 如果存在多行数据,断言数组的形状为 (2,)
if self.multiple_rows:
assert_(h.shape == (2,))
else:
assert_(h.shape == ())
def test_list_of_tuple(self):
"""Check creation from list of tuples"""
# 从元组列表创建数组对象,使用给定的描述符类型
h = np.array([self._buffer], dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 如果存在多行数据,断言数组的形状为 (1, 2)
if self.multiple_rows:
assert_(h.shape == (1, 2))
else:
assert_(h.shape == (1,))
def test_list_of_list_of_tuple(self):
"""Check creation from list of list of tuples"""
# 从元组列表的列表创建数组对象,使用给定的描述符类型
h = np.array([[self._buffer]], dtype=self._descr)
# 断言描述符已被规范化,并与数组的数据类型描述符相匹配
assert_(normalize_descr(self._descr) == h.dtype.descr)
# 如果存在多行数据,断言数组的形状为 (1, 1, 2)
if self.multiple_rows:
assert_(h.shape == (1, 1, 2))
else:
assert_(h.shape == (1, 1))
class TestCreateValuesPlainSingle(CreateValues):
"""Check the creation of heterogeneous arrays (plain, single row)"""
# 将 Pdescr 赋值给 _descr 变量
_descr = Pdescr
# 初始化 multiple_rows 变量为 0
multiple_rows = 0
# 从 PbufferT 列表中取出第一个元素,并赋值给 _buffer 变量
_buffer = PbufferT[0]
class TestCreateValuesPlainMultiple(CreateValues):
"""Check the creation of heterogeneous arrays (plain, multiple rows)"""
_descr = Pdescr # 设置描述符为 Pdescr
multiple_rows = 1 # 多行标志设为 1
_buffer = PbufferT # 缓冲区设为 PbufferT
class TestCreateValuesNestedSingle(CreateValues):
"""Check the creation of heterogeneous arrays (nested, single row)"""
_descr = Ndescr # 设置描述符为 Ndescr
multiple_rows = 0 # 多行标志设为 0
_buffer = NbufferT[0] # 缓冲区设为 NbufferT 的第一个元素
class TestCreateValuesNestedMultiple(CreateValues):
"""Check the creation of heterogeneous arrays (nested, multiple rows)"""
_descr = Ndescr # 设置描述符为 Ndescr
multiple_rows = 1 # 多行标志设为 1
_buffer = NbufferT # 缓冲区设为 NbufferT
############################################################
# Reading tests
############################################################
class ReadValuesPlain:
"""Check the reading of values in heterogeneous arrays (plain)"""
def test_access_fields(self):
# 使用描述符和缓冲区创建 NumPy 数组 h
h = np.array(self._buffer, dtype=self._descr)
if not self.multiple_rows:
assert_(h.shape == ()) # 断言 h 的形状为空
assert_equal(h['x'], np.array(self._buffer[0], dtype='i4')) # 断言 h 的 'x' 值等于缓冲区第一个元素的 'i4' 类型数组
assert_equal(h['y'], np.array(self._buffer[1], dtype='f8')) # 断言 h 的 'y' 值等于缓冲区第二个元素的 'f8' 类型数组
assert_equal(h['z'], np.array(self._buffer[2], dtype='u1')) # 断言 h 的 'z' 值等于缓冲区第三个元素的 'u1' 类型数组
else:
assert_(len(h) == 2) # 断言 h 的长度为 2
assert_equal(h['x'], np.array([self._buffer[0][0],
self._buffer[1][0]], dtype='i4')) # 断言 h 的 'x' 值等于缓冲区第一列 'i4' 类型数组
assert_equal(h['y'], np.array([self._buffer[0][1],
self._buffer[1][1]], dtype='f8')) # 断言 h 的 'y' 值等于缓冲区第二列 'f8' 类型数组
assert_equal(h['z'], np.array([self._buffer[0][2],
self._buffer[1][2]], dtype='u1')) # 断言 h 的 'z' 值等于缓冲区第三列 'u1' 类型数组
class TestReadValuesPlainSingle(ReadValuesPlain):
"""Check the creation of heterogeneous arrays (plain, single row)"""
_descr = Pdescr # 设置描述符为 Pdescr
multiple_rows = 0 # 多行标志设为 0
_buffer = PbufferT[0] # 缓冲区设为 PbufferT 的第一个元素
class TestReadValuesPlainMultiple(ReadValuesPlain):
"""Check the values of heterogeneous arrays (plain, multiple rows)"""
_descr = Pdescr # 设置描述符为 Pdescr
multiple_rows = 1 # 多行标志设为 1
_buffer = PbufferT # 缓冲区设为 PbufferT
class ReadValuesNested:
"""Check the reading of values in heterogeneous arrays (nested)"""
def test_access_top_fields(self):
"""检查访问嵌套数组的顶层字段"""
# 使用 self._buffer 和 self._descr 创建一个 NumPy 数组 h
h = np.array(self._buffer, dtype=self._descr)
# 如果不是多行数据
if not self.multiple_rows:
# 断言 h 的形状为空元组
assert_(h.shape == ())
# 断言 h 中字段 'x' 的值等于 self._buffer[0] 的整数数组
assert_equal(h['x'], np.array(self._buffer[0], dtype='i4'))
# 断言 h 中字段 'y' 的值等于 self._buffer[4] 的浮点数数组
assert_equal(h['y'], np.array(self._buffer[4], dtype='f8'))
# 断言 h 中字段 'z' 的值等于 self._buffer[5] 的无符号整数数组
assert_equal(h['z'], np.array(self._buffer[5], dtype='u1'))
else:
# 断言 h 的长度为 2
assert_(len(h) == 2)
# 断言 h 中字段 'x' 的值等于 self._buffer 中每行的第一个元素的整数数组
assert_equal(h['x'], np.array([self._buffer[0][0],
self._buffer[1][0]], dtype='i4'))
# 断言 h 中字段 'y' 的值等于 self._buffer 中每行的第五个元素的浮点数数组
assert_equal(h['y'], np.array([self._buffer[0][4],
self._buffer[1][4]], dtype='f8'))
# 断言 h 中字段 'z' 的值等于 self._buffer 中每行的第六个元素的无符号整数数组
assert_equal(h['z'], np.array([self._buffer[0][5],
self._buffer[1][5]], dtype='u1'))
def test_nested1_acessors(self):
"""检查访问嵌套数组的嵌套字段(第一层级)"""
# 使用 self._buffer 和 self._descr 创建一个 NumPy 数组 h
h = np.array(self._buffer, dtype=self._descr)
# 如果不是多行数据
if not self.multiple_rows:
# 断言 h 中字段 'Info' 下的 'value' 等于 self._buffer[1][0] 的长度为 16 的字符数组
assert_equal(h['Info']['value'],
np.array(self._buffer[1][0], dtype='c16'))
# 断言 h 中字段 'Info' 下的 'y2' 等于 self._buffer[1][1] 的浮点数数组
assert_equal(h['Info']['y2'],
np.array(self._buffer[1][1], dtype='f8'))
# 断言 h 中字段 'info' 下的 'Name' 等于 self._buffer[3][0] 的长度为 2 的 Unicode 字符数组
assert_equal(h['info']['Name'],
np.array(self._buffer[3][0], dtype='U2'))
# 断言 h 中字段 'info' 下的 'Value' 等于 self._buffer[3][1] 的长度为 16 的字符数组
assert_equal(h['info']['Value'],
np.array(self._buffer[3][1], dtype='c16'))
else:
# 断言 h 中字段 'Info' 下的 'value' 等于 self._buffer 中每行第二个元素的长度为 16 的字符数组
assert_equal(h['Info']['value'],
np.array([self._buffer[0][1][0],
self._buffer[1][1][0]],
dtype='c16'))
# 断言 h 中字段 'Info' 下的 'y2' 等于 self._buffer 中每行第二个元素的浮点数数组
assert_equal(h['Info']['y2'],
np.array([self._buffer[0][1][1],
self._buffer[1][1][1]],
dtype='f8'))
# 断言 h 中字段 'info' 下的 'Name' 等于 self._buffer 中每行第四个元素的长度为 2 的 Unicode 字符数组
assert_equal(h['info']['Name'],
np.array([self._buffer[0][3][0],
self._buffer[1][3][0]],
dtype='U2'))
# 断言 h 中字段 'info' 下的 'Value' 等于 self._buffer 中每行第四个元素的长度为 16 的字符数组
assert_equal(h['info']['Value'],
np.array([self._buffer[0][3][1],
self._buffer[1][3][1]],
dtype='c16'))
def test_nested2_acessors(self):
"""Check reading the nested fields of a nested array (2nd level)"""
# 将 self._buffer 转换为 NumPy 数组 h,使用给定的数据类型 self._descr
h = np.array(self._buffer, dtype=self._descr)
# 如果不是多行数据:
if not self.multiple_rows:
# 断言验证 h['Info']['Info2']['value'] 是否等于指定的复数数组
assert_equal(h['Info']['Info2']['value'],
np.array(self._buffer[1][2][1], dtype='c16'))
# 断言验证 h['Info']['Info2']['z3'] 是否等于指定的无符号整数数组
assert_equal(h['Info']['Info2']['z3'],
np.array(self._buffer[1][2][3], dtype='u4'))
else:
# 断言验证 h['Info']['Info2']['value'] 是否等于指定的复数数组列表
assert_equal(h['Info']['Info2']['value'],
np.array([self._buffer[0][1][2][1],
self._buffer[1][1][2][1]],
dtype='c16'))
# 断言验证 h['Info']['Info2']['z3'] 是否等于指定的无符号整数数组列表
assert_equal(h['Info']['Info2']['z3'],
np.array([self._buffer[0][1][2][3],
self._buffer[1][1][2][3]],
dtype='u4'))
def test_nested1_descriptor(self):
"""Check access nested descriptors of a nested array (1st level)"""
# 将 self._buffer 转换为 NumPy 数组 h,使用给定的数据类型 self._descr
h = np.array(self._buffer, dtype=self._descr)
# 断言验证 h.dtype['Info']['value'] 的名称是否为 'complex128'
assert_(h.dtype['Info']['value'].name == 'complex128')
# 断言验证 h.dtype['Info']['y2'] 的名称是否为 'float64'
assert_(h.dtype['Info']['y2'].name == 'float64')
# 断言验证 h.dtype['info']['Name'] 的名称是否为 'str256'
assert_(h.dtype['info']['Name'].name == 'str256')
# 断言验证 h.dtype['info']['Value'] 的名称是否为 'complex128'
assert_(h.dtype['info']['Value'].name == 'complex128')
def test_nested2_descriptor(self):
"""Check access nested descriptors of a nested array (2nd level)"""
# 将 self._buffer 转换为 NumPy 数组 h,使用给定的数据类型 self._descr
h = np.array(self._buffer, dtype=self._descr)
# 断言验证 h.dtype['Info']['Info2']['value'] 的名称是否为 'void256'
assert_(h.dtype['Info']['Info2']['value'].name == 'void256')
# 断言验证 h.dtype['Info']['Info2']['z3'] 的名称是否为 'void64'
assert_(h.dtype['Info']['Info2']['z3'].name == 'void64')
class TestReadValuesNestedSingle(ReadValuesNested):
"""Check the values of heterogeneous arrays (nested, single row)"""
# 设置描述符为Ndescr
_descr = Ndescr
# 多行标志设为False
multiple_rows = False
# 缓冲区为NbufferT的第一个元素
_buffer = NbufferT[0]
class TestReadValuesNestedMultiple(ReadValuesNested):
"""Check the values of heterogeneous arrays (nested, multiple rows)"""
# 设置描述符为Ndescr
_descr = Ndescr
# 多行标志设为True
multiple_rows = True
# 缓冲区为NbufferT
_buffer = NbufferT
class TestEmptyField:
def test_assign(self):
# 创建一个包含10个浮点数的数组a
a = np.arange(10, dtype=np.float32)
# 重新设定数组a的数据类型为一个结构化数组,包含一个int字段和两个float字段
a.dtype = [("int", "<0i4"), ("float", "<2f4")]
# 断言int字段的形状为(5, 0)
assert_(a['int'].shape == (5, 0))
# 断言float字段的形状为(5, 2)
assert_(a['float'].shape == (5, 2))
class TestMultipleFields:
def setup_method(self):
# 创建一个结构化数组ary,包含两行四列的数据,每列的类型分别为i4, f4, i2, c8
self.ary = np.array([(1, 2, 3, 4), (5, 6, 7, 8)], dtype='i4,f4,i2,c8')
def _bad_call(self):
# 尝试返回数组ary中的'f0'和'f1'字段,这会触发IndexError
return self.ary['f0', 'f1']
def test_no_tuple(self):
# 断言调用_bad_call方法会引发IndexError异常
assert_raises(IndexError, self._bad_call)
def test_return(self):
# 返回包含ary中'f0'和'f2'字段数据的列表
res = self.ary[['f0', 'f2']].tolist()
# 断言返回的列表与预期结果[(1, 3), (5, 7)]相等
assert_(res == [(1, 3), (5, 7)])
class TestIsSubDType:
# 标量类型可以被提升为dtype
wrappers = [np.dtype, lambda x: x]
def test_both_abstract(self):
# 断言np.floating是np.inexact的子类型
assert_(np.issubdtype(np.floating, np.inexact))
# 断言np.inexact不是np.floating的子类型
assert_(not np.issubdtype(np.inexact, np.floating))
def test_same(self):
# 针对np.float32和np.int32,通过不同的包装器组合,断言类型之间的子类型关系
for cls in (np.float32, np.int32):
for w1, w2 in itertools.product(self.wrappers, repeat=2):
assert_(np.issubdtype(w1(cls), w2(cls)))
def test_subclass(self):
# 使用包装器,断言np.float32和np.float64是np.floating的子类型
for w in self.wrappers:
assert_(np.issubdtype(w(np.float32), np.floating))
assert_(np.issubdtype(w(np.float64), np.floating))
def test_subclass_backwards(self):
# 使用包装器,断言np.floating不是np.float32或np.float64的子类型
for w in self.wrappers:
assert_(not np.issubdtype(np.floating, w(np.float32)))
assert_(not np.issubdtype(np.floating, w(np.float64)))
def test_sibling_class(self):
# 使用不同的包装器组合,断言np.float32不是np.float64的子类型,反之亦然
for w1, w2 in itertools.product(self.wrappers, repeat=2):
assert_(not np.issubdtype(w1(np.float32), w2(np.float64)))
assert_(not np.issubdtype(w1(np.float64), w2(np.float32)))
def test_nondtype_nonscalartype(self):
# 以下测试用例源自于 gh-9505,该问题由 gh-14619 和 gh-9505 引入了废弃修复。
# 检查 np.float32 是否不是 'float64' 的子类型
assert not np.issubdtype(np.float32, 'float64')
# 检查 np.float32 是否不是 'f8' 的子类型
assert not np.issubdtype(np.float32, 'f8')
# 检查 np.int32 是否不是 str 的子类型
assert not np.issubdtype(np.int32, str)
# 检查 np.int32 是否不是 'int64' 的子类型
assert not np.issubdtype(np.int32, 'int64')
# 检查 np.str_ 是否不是 'void' 的子类型
assert not np.issubdtype(np.str_, 'void')
# 检查 np.int8 是否不是 int 的子类型,正确的是 np.integer
assert not np.issubdtype(np.int8, int)
# 检查 np.float32 是否不是 float 的子类型
assert not np.issubdtype(np.float32, float)
# 检查 np.complex64 是否不是 complex 的子类型
assert not np.issubdtype(np.complex64, complex)
# 检查 np.float32 是否不是 "float" 的子类型
assert not np.issubdtype(np.float32, "float")
# 检查 np.float64 是否是 "float" 的子类型
assert np.issubdtype(np.float64, "float")
# 测试 int、float、complex 的正确数据类型和抽象类型
# 对于 int、float、complex,正确的数据类型应为 np.integer、np.floating、np.complexfloating
assert np.issubdtype(np.float64, 'float64')
assert np.issubdtype(np.float64, 'f8')
assert np.issubdtype(np.str_, str)
assert np.issubdtype(np.int64, 'int64')
assert np.issubdtype(np.void, 'void')
assert np.issubdtype(np.int8, np.integer)
assert np.issubdtype(np.float32, np.floating)
assert np.issubdtype(np.complex64, np.complexfloating)
# 检查 np.float64 是否是 "float" 的子类型
assert np.issubdtype(np.float64, "float")
# 检查 np.float32 是否是 "f" 的子类型
assert np.issubdtype(np.float32, "f")
class TestIsDType:
"""
Check correctness of `np.isdtype`. The test considers different argument
configurations: `np.isdtype(dtype, k1)` and `np.isdtype(dtype, (k1, k2))`
with concrete dtypes and dtype groups.
"""
# 定义 dtype 组的字典,映射到具体的 NumPy 类型数组
dtype_group_dict = {
"signed integer": sctypes["int"],
"unsigned integer": sctypes["uint"],
"integral": sctypes["int"] + sctypes["uint"],
"real floating": sctypes["float"],
"complex floating": sctypes["complex"],
"numeric": (
sctypes["int"] + sctypes["uint"] + sctypes["float"] +
sctypes["complex"]
)
}
@pytest.mark.parametrize(
"dtype,close_dtype",
[
(np.int64, np.int32), (np.uint64, np.uint32),
(np.float64, np.float32), (np.complex128, np.complex64)
]
)
@pytest.mark.parametrize(
"dtype_group",
[
None, "signed integer", "unsigned integer", "integral",
"real floating", "complex floating", "numeric"
]
)
# 测试 np.isdtype 函数的不同参数组合
def test_isdtype(self, dtype, close_dtype, dtype_group):
# 首先检查相同的 dtype 应返回 `True`,不同的应返回 `False`
# 即使它们在 dtype 层次结构中接近也是如此!
if dtype_group is None:
assert np.isdtype(dtype, dtype)
assert not np.isdtype(dtype, close_dtype)
assert np.isdtype(dtype, (dtype, close_dtype))
# 检查 dtype 是否属于指定的 dtype 组,应返回 `True`,否则返回 `False`
elif dtype in self.dtype_group_dict[dtype_group]:
assert np.isdtype(dtype, dtype_group)
assert np.isdtype(dtype, (close_dtype, dtype_group))
else:
assert not np.isdtype(dtype, dtype_group)
# 测试 np.isdtype 函数的无效参数
def test_isdtype_invalid_args(self):
with assert_raises_regex(TypeError, r".*must be a NumPy dtype.*"):
np.isdtype("int64", np.int64)
with assert_raises_regex(TypeError, r".*kind argument must.*"):
np.isdtype(np.int64, 1)
with assert_raises_regex(ValueError, r".*not a known kind name.*"):
np.isdtype(np.int64, "int64")
# 检查 sctypes 字典的完整性,特别是在 32 位系统上
def test_sctypes_complete(self):
assert np.int32 in sctypes['int']
assert np.intc in sctypes['int']
assert np.int64 in sctypes['int']
assert np.uint32 in sctypes['uint']
assert np.uintc in sctypes['uint']
assert np.uint64 in sctypes['uint']
class TestSctypeDict:
# 测试 np._core.sctypeDict 字典中的 longdouble 类型
def test_longdouble(self):
assert_(np._core.sctypeDict['float64'] is not np.longdouble)
assert_(np._core.sctypeDict['complex128'] is not np.clongdouble)
# 测试 np._core.sctypeDict 字典中的 ulong 类型
def test_ulong(self):
assert np._core.sctypeDict['ulong'] is np.ulong
assert np.dtype(np.ulong) is np.dtype("ulong")
assert np.dtype(np.ulong).itemsize == np.dtype(np.long).itemsize
@pytest.mark.filterwarnings("ignore:.*maximum_sctype.*:DeprecationWarning")
class TestMaximumSctype:
# 定义一个测试类 TestMaximumSctype,用于测试 maximum_sctype 函数的各种数据类型情况
# note that parametrizing with sctype['int'] and similar would skip types
# with the same size (gh-11923)
# 注意,使用 sctype['int'] 等进行参数化会跳过相同大小的类型 (gh-11923)
@pytest.mark.parametrize(
't', [np.byte, np.short, np.intc, np.long, np.longlong]
)
def test_int(self, t):
# 测试整数类型的最大数据类型
assert_equal(maximum_sctype(t), np._core.sctypes['int'][-1])
@pytest.mark.parametrize(
't', [np.ubyte, np.ushort, np.uintc, np.ulong, np.ulonglong]
)
def test_uint(self, t):
# 测试无符号整数类型的最大数据类型
assert_equal(maximum_sctype(t), np._core.sctypes['uint'][-1])
@pytest.mark.parametrize('t', [np.half, np.single, np.double, np.longdouble])
def test_float(self, t):
# 测试浮点数类型的最大数据类型
assert_equal(maximum_sctype(t), np._core.sctypes['float'][-1])
@pytest.mark.parametrize('t', [np.csingle, np.cdouble, np.clongdouble])
def test_complex(self, t):
# 测试复数类型的最大数据类型
assert_equal(maximum_sctype(t), np._core.sctypes['complex'][-1])
@pytest.mark.parametrize('t', [np.bool, np.object_, np.str_, np.bytes_,
np.void])
def test_other(self, t):
# 测试其他类型的最大数据类型
assert_equal(maximum_sctype(t), t)
class Test_sctype2char:
# 定义一个测试类 Test_sctype2char,用于测试 sctype2char 函数的各种数据类型情况
# This function is old enough that we're really just documenting the quirks
# at this point.
# 这个函数已经足够老了,我们只是在记录当前阶段的一些怪异行为。
def test_scalar_type(self):
# 测试标量类型的 sctype2char 转换
assert_equal(sctype2char(np.double), 'd')
assert_equal(sctype2char(np.long), 'l')
assert_equal(sctype2char(np.int_), np.array(0).dtype.char)
assert_equal(sctype2char(np.str_), 'U')
assert_equal(sctype2char(np.bytes_), 'S')
def test_other_type(self):
# 测试其他类型的 sctype2char 转换
assert_equal(sctype2char(float), 'd')
assert_equal(sctype2char(list), 'O')
assert_equal(sctype2char(np.ndarray), 'O')
def test_third_party_scalar_type(self):
# 测试第三方标量类型的 sctype2char 转换
from numpy._core._rational_tests import rational
assert_raises(KeyError, sctype2char, rational)
assert_raises(KeyError, sctype2char, rational(1))
def test_array_instance(self):
# 测试数组实例的 sctype2char 转换
assert_equal(sctype2char(np.array([1.0, 2.0])), 'd')
def test_abstract_type(self):
# 测试抽象类型的 sctype2char 转换
assert_raises(KeyError, sctype2char, np.floating)
def test_non_type(self):
# 测试非类型的 sctype2char 转换
assert_raises(ValueError, sctype2char, 1)
@pytest.mark.parametrize("rep, expected", [
(np.int32, True),
(list, False),
(1.1, False),
(str, True),
(np.dtype(np.float64), True),
(np.dtype((np.int16, (3, 4))), True),
(np.dtype([('a', np.int8)]), True),
])
def test_issctype(rep, expected):
# ensure proper identification of scalar
# data-types by issctype()
# 确保 issctype() 正确识别标量数据类型
actual = issctype(rep)
assert_equal(actual, expected)
@pytest.mark.skipif(sys.flags.optimize > 1,
reason="no docstrings present to inspect when PYTHONOPTIMIZE/Py_OptimizeFlag > 1")
@pytest.mark.xfail(IS_PYPY,
reason="PyPy cannot modify tp_doc after PyType_Ready")
class TestDocStrings:
# 定义一个测试类 TestDocStrings,用于测试在不同条件下文档字符串的处理行为
def test_platform_dependent_aliases(self):
# 检查 numpy 库中的 np.int64 是否和 np.int_ 相同
if np.int64 is np.int_:
# 如果相同,断言 'int64' 在 np.int_ 对象的文档字符串中出现
assert_('int64' in np.int_.__doc__)
# 如果 np.int64 和 np.longlong 相同
elif np.int64 is np.longlong:
# 断言 'int64' 在 np.longlong 对象的文档字符串中出现
assert_('int64' in np.longlong.__doc__)
# 定义一个测试类 TestScalarTypeNames,用于测试数值类型的命名是否正确
class TestScalarTypeNames:
# gh-9799
# 定义一个包含各种数值类型的列表
numeric_types = [
np.byte, np.short, np.intc, np.long, np.longlong,
np.ubyte, np.ushort, np.uintc, np.ulong, np.ulonglong,
np.half, np.single, np.double, np.longdouble,
np.csingle, np.cdouble, np.clongdouble,
]
# 测试数值类型的名称是否唯一
def test_names_are_unique(self):
# 检查列表中的元素是否都不是别名
assert len(set(self.numeric_types)) == len(self.numeric_types)
# 检查数值类型的名称是否唯一
names = [t.__name__ for t in self.numeric_types]
assert len(set(names)) == len(names)
# 使用 pytest 的 parametrize 装饰器,为每种数值类型 t 运行以下测试
@pytest.mark.parametrize('t', numeric_types)
def test_names_reflect_attributes(self, t):
""" Test that names correspond to where the type is under ``np.`` """
# 检查数值类型的名称是否对应于 np 下的类型
assert getattr(np, t.__name__) is t
# 使用 pytest 的 parametrize 装饰器,为每种数值类型 t 运行以下测试
@pytest.mark.parametrize('t', numeric_types)
def test_names_are_undersood_by_dtype(self, t):
""" Test the dtype constructor maps names back to the type """
# 检查 dtype 构造函数是否能将名称映射回类型
assert np.dtype(t.__name__).type is t
# 定义一个测试类 TestBoolDefinition,用于测试布尔类型的定义
class TestBoolDefinition:
# 测试布尔类型的定义
def test_bool_definition(self):
# 检查 nt.bool 是否等于 np.bool
assert nt.bool is np.bool
.\numpy\numpy\_core\tests\test_overrides.py
# 导入inspect模块,用于获取对象的信息
import inspect
# 导入sys模块,提供与解释器交互的功能
import sys
# 导入os模块,提供与操作系统交互的功能
import os
# 导入tempfile模块,用于创建临时文件和目录
import tempfile
# 从io模块导入StringIO类,用于在内存中读写str对象
from io import StringIO
# 导入unittest模块中的mock类,用于模拟对象行为
from unittest import mock
# 导入pickle模块,用于序列化和反序列化Python对象
import pickle
# 导入pytest模块,用于编写和运行测试
import pytest
# 导入numpy模块,并将其重命名为np,提供对数组和数学函数的支持
import numpy as np
# 从numpy.testing模块导入断言函数,用于测试numpy数组的期望行为
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
# 从numpy._core.overrides模块导入特定函数和类
from numpy._core.overrides import (
_get_implementing_args, array_function_dispatch,
verify_matching_signatures)
# 定义一个函数,在调用数组函数时返回NotImplemented
def _return_not_implemented(self, *args, **kwargs):
return NotImplemented
# 使用array_function_dispatch装饰器注册一个函数,指定一个参数的数组函数分发
@array_function_dispatch(lambda array: (array,))
def dispatched_one_arg(array):
"""Docstring."""
return 'original'
# 使用array_function_dispatch装饰器注册一个函数,指定两个参数的数组函数分发
@array_function_dispatch(lambda array1, array2: (array1, array2))
def dispatched_two_arg(array1, array2):
"""Docstring."""
return 'original'
# 测试类TestGetImplementingArgs,用于测试_get_implementing_args函数
class TestGetImplementingArgs:
# 测试_ndarray方法,用于测试数组类型的行为
def test_ndarray(self):
array = np.array(1)
# 调用_get_implementing_args函数,检查返回的参数是否与数组一致
args = _get_implementing_args([array])
assert_equal(list(args), [array])
# 调用_get_implementing_args函数,检查返回的参数是否与数组一致
args = _get_implementing_args([array, array])
assert_equal(list(args), [array])
# 调用_get_implementing_args函数,检查返回的参数是否与数组一致
args = _get_implementing_args([array, 1])
assert_equal(list(args), [array])
# 调用_get_implementing_args函数,检查返回的参数是否与数组一致
args = _get_implementing_args([1, array])
assert_equal(list(args), [array])
# 测试_ndarray_subclasses方法,用于测试数组的子类行为
def test_ndarray_subclasses(self):
# 定义一个继承自np.ndarray的类OverrideSub,覆盖__array_function__方法
class OverrideSub(np.ndarray):
__array_function__ = _return_not_implemented
# 定义一个继承自np.ndarray的类NoOverrideSub,未覆盖__array_function__方法
class NoOverrideSub(np.ndarray):
pass
array = np.array(1).view(np.ndarray)
override_sub = np.array(1).view(OverrideSub)
no_override_sub = np.array(1).view(NoOverrideSub)
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
args = _get_implementing_args([array, override_sub])
assert_equal(list(args), [override_sub, array])
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
args = _get_implementing_args([array, no_override_sub])
assert_equal(list(args), [no_override_sub, array])
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
args = _get_implementing_args([override_sub, no_override_sub])
assert_equal(list(args), [override_sub, no_override_sub])
# 测试_ndarray_and_duck_array方法,用于测试数组和鸭子类型对象的行为
def test_ndarray_and_duck_array(self):
# 定义一个类Other,覆盖__array_function__方法
class Other:
__array_function__ = _return_not_implemented
array = np.array(1)
other = Other()
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
args = _get_implementing_args([other, array])
assert_equal(list(args), [other, array])
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
args = _get_implementing_args([array, other])
assert_equal(list(args), [array, other])
# 测试_ndarray_subclass_and_duck_array方法,用于测试数组子类和鸭子类型对象的行为
def test_ndarray_subclass_and_duck_array(self):
# 定义一个继承自np.ndarray的类OverrideSub,覆盖__array_function__方法
class OverrideSub(np.ndarray):
__array_function__ = _return_not_implemented
# 定义一个类Other,覆盖__array_function__方法
class Other:
__array_function__ = _return_not_implemented
array = np.array(1)
subarray = np.array(1).view(OverrideSub)
other = Other()
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
assert_equal(_get_implementing_args([array, subarray, other]),
[subarray, array, other])
# 调用_get_implementing_args函数,检查返回的参数是否按预期排列
assert_equal(_get_implementing_args([array, other, subarray]),
[subarray, array, other])
# 定义一个测试方法,用于测试多个“鸭子类型”数组的情况
def test_many_duck_arrays(self):
# 定义类 A,设置 __array_function__ 属性为 _return_not_implemented 函数
class A:
__array_function__ = _return_not_implemented
# 定义类 B,继承自 A,设置 __array_function__ 属性为 _return_not_implemented 函数
class B(A):
__array_function__ = _return_not_implemented
# 定义类 C,继承自 A,设置 __array_function__ 属性为 _return_not_implemented 函数
class C(A):
__array_function__ = _return_not_implemented
# 定义类 D,没有继承自 A,设置 __array_function__ 属性为 _return_not_implemented 函数
class D:
__array_function__ = _return_not_implemented
# 创建类 A、B、C 和 D 的实例
a = A()
b = B()
c = C()
d = D()
# 使用 _get_implementing_args 函数测试不同参数组合下的返回结果是否符合预期
assert_equal(_get_implementing_args([1]), []) # 测试传入整数 1 的情况
assert_equal(_get_implementing_args([a]), [a]) # 测试传入实例 a 的情况
assert_equal(_get_implementing_args([a, 1]), [a]) # 测试传入实例 a 和整数 1 的情况
assert_equal(_get_implementing_args([a, a, a]), [a]) # 测试传入多个实例 a 的情况
assert_equal(_get_implementing_args([a, d, a]), [a, d]) # 测试传入实例 a、d 和 a 的情况
assert_equal(_get_implementing_args([a, b]), [b, a]) # 测试传入实例 a 和 b 的情况
assert_equal(_get_implementing_args([b, a]), [b, a]) # 测试传入实例 b 和 a 的情况
assert_equal(_get_implementing_args([a, b, c]), [b, c, a]) # 测试传入实例 a、b 和 c 的情况
assert_equal(_get_implementing_args([a, c, b]), [c, b, a]) # 测试传入实例 a、c 和 b 的情况
# 定义一个测试方法,用于测试过多“鸭子类型”数组的情况
def test_too_many_duck_arrays(self):
# 创建一个命名空间字典,其中 __array_function__ 属性为 _return_not_implemented 函数
namespace = dict(__array_function__=_return_not_implemented)
# 创建 65 个动态生成的类,每个类都有相同的命名空间
types = [type('A' + str(i), (object,), namespace) for i in range(65)]
# 创建一个包含 65 个实例的列表
relevant_args = [t() for t in types]
# 使用 _get_implementing_args 函数测试前 64 个实例是否返回预期的结果
actual = _get_implementing_args(relevant_args[:64])
assert_equal(actual, relevant_args[:64])
# 使用 assert_raises_regex 断言捕获 TypeError 异常,检查当传入超过 64 个不同类型实例时是否引发异常
with assert_raises_regex(TypeError, 'distinct argument types'):
_get_implementing_args(relevant_args)
class TestNDArrayArrayFunction:
def test_method(self):
class Other:
__array_function__ = _return_not_implemented # 定义一个没有实现的特殊方法 __array_function__
class NoOverrideSub(np.ndarray): # 定义一个继承自 np.ndarray 的子类,没有覆盖 __array_function__
pass
class OverrideSub(np.ndarray):
__array_function__ = _return_not_implemented # 定义一个继承自 np.ndarray 的子类,并覆盖 __array_function__
array = np.array([1]) # 创建一个包含单个元素的 numpy 数组
other = Other() # 创建 Other 类的实例
no_override_sub = array.view(NoOverrideSub) # 将 array 视图转换为 NoOverrideSub 类的实例
override_sub = array.view(OverrideSub) # 将 array 视图转换为 OverrideSub 类的实例
# 调用 array 对象的 __array_function__ 方法,传入函数、类型元组、参数和关键字参数,期望返回 'original'
result = array.__array_function__(func=dispatched_two_arg,
types=(np.ndarray,),
args=(array, 1.), kwargs={})
assert_equal(result, 'original') # 断言结果为 'original'
# 调用 array 对象的 __array_function__ 方法,传入函数、类型元组、参数和关键字参数,期望返回 NotImplemented
result = array.__array_function__(func=dispatched_two_arg,
types=(np.ndarray, Other),
args=(array, other), kwargs={})
assert_(result is NotImplemented) # 断言结果为 NotImplemented
# 调用 array 对象的 __array_function__ 方法,传入函数、类型元组、参数和关键字参数,期望返回 'original'
result = array.__array_function__(func=dispatched_two_arg,
types=(np.ndarray, NoOverrideSub),
args=(array, no_override_sub),
kwargs={})
assert_equal(result, 'original') # 断言结果为 'original'
# 调用 array 对象的 __array_function__ 方法,传入函数、类型元组、参数和关键字参数,期望返回 'original'
result = array.__array_function__(func=dispatched_two_arg,
types=(np.ndarray, OverrideSub),
args=(array, override_sub),
kwargs={})
assert_equal(result, 'original') # 断言结果为 'original'
# 使用 assert_raises_regex 上下文管理器断言 TypeError,错误信息中包含 'no implementation found'
with assert_raises_regex(TypeError, 'no implementation found'):
np.concatenate((array, other)) # 尝试使用 np.concatenate() 合并 array 和 other
expected = np.concatenate((array, array)) # 预期的合并结果,两个相同的 array 数组
result = np.concatenate((array, no_override_sub)) # 使用 np.concatenate() 合并 array 和 no_override_sub
assert_equal(result, expected.view(NoOverrideSub)) # 断言结果与预期的 NoOverrideSub 类型的视图相等
result = np.concatenate((array, override_sub)) # 使用 np.concatenate() 合并 array 和 override_sub
assert_equal(result, expected.view(OverrideSub)) # 断言结果与预期的 OverrideSub 类型的视图相等
def test_no_wrapper(self):
# 这段代码不应该执行,除非用户有意调用具有无效参数的 __array_function__ 方法,
# 但我们仍然检查确保能够正确引发适当的错误。
array = np.array(1) # 创建一个包含单个元素的 numpy 数组
func = lambda x: x # 定义一个简单的 lambda 函数
with assert_raises_regex(AttributeError, '_implementation'):
array.__array_function__(func=func, types=(np.ndarray,),
args=(array,), kwargs={}) # 尝试调用 array 对象的 __array_function__ 方法
def test_interface(self):
# 定义一个内部类 MyArray,用于测试 __array_function__ 方法的调用
class MyArray:
# 自定义 __array_function__ 方法,返回传入的参数元组
def __array_function__(self, func, types, args, kwargs):
return (self, func, types, args, kwargs)
# 创建 MyArray 的实例 original
original = MyArray()
# 调用 dispatched_one_arg 函数,并解构返回的元组
(obj, func, types, args, kwargs) = dispatched_one_arg(original)
# 断言 original 和 obj 是同一个对象
assert_(obj is original)
# 断言 func 是 dispatched_one_arg 函数本身
assert_(func is dispatched_one_arg)
# 断言 types 是包含 MyArray 类型的集合
assert_equal(set(types), {MyArray})
# args 应包含 original 这个实例
# assert_equal 内部使用了重载的 np.iscomplexobj() 方法
assert_(args == (original,))
# kwargs 应该为空字典
assert_equal(kwargs, {})
def test_not_implemented(self):
# 定义一个内部类 MyArray,覆盖 __array_function__ 方法返回 NotImplemented
class MyArray:
def __array_function__(self, func, types, args, kwargs):
return NotImplemented
# 创建 MyArray 的实例 array
array = MyArray()
# 使用 assert_raises_regex 断言抛出 TypeError 异常,且异常信息包含 'no implementation found'
with assert_raises_regex(TypeError, 'no implementation found'):
# 调用 dispatched_one_arg 函数
dispatched_one_arg(array)
def test_where_dispatch(self):
# 定义一个内部类 DuckArray,覆盖 __array_function__ 方法返回固定字符串 "overridden"
class DuckArray:
def __array_function__(self, ufunc, method, *inputs, **kwargs):
return "overridden"
# 创建一个 numpy 数组 array
array = np.array(1)
# 创建 DuckArray 的实例 duck_array
duck_array = DuckArray()
# 调用 numpy 的 std 函数,传入 where 参数为 duck_array
result = np.std(array, where=duck_array)
# 断言 result 等于 "overridden"
assert_equal(result, "overridden")
class TestVerifyMatchingSignatures:
def test_verify_matching_signatures(self):
# 调用 verify_matching_signatures 函数,传入两个相同的匿名函数,预期不会引发异常
verify_matching_signatures(lambda x: 0, lambda x: 0)
verify_matching_signatures(lambda x=None: 0, lambda x=None: 0)
verify_matching_signatures(lambda x=1: 0, lambda x=None: 0)
# 使用 assert_raises 断言预期运行时错误将被引发
with assert_raises(RuntimeError):
# 传入两个参数名称不同的匿名函数,预期引发 RuntimeError
verify_matching_signatures(lambda a: 0, lambda b: 0)
with assert_raises(RuntimeError):
# 传入一个参数有默认值,另一个参数无默认值的匿名函数,预期引发 RuntimeError
verify_matching_signatures(lambda x: 0, lambda x=None: 0)
with assert_raises(RuntimeError):
# 传入两个参数名称相同但默认值类型不同的匿名函数,预期引发 RuntimeError
verify_matching_signatures(lambda x=None: 0, lambda y=None: 0)
with assert_raises(RuntimeError):
# 传入两个参数名称相同且默认值类型相同但值不同的匿名函数,预期引发 RuntimeError
verify_matching_signatures(lambda x=1: 0, lambda y=1: 0)
def test_array_function_dispatch(self):
# 使用 assert_raises 断言预期运行时错误将被引发
with assert_raises(RuntimeError):
# 带有 array_function_dispatch 装饰器的函数,传入的参数类型与声明的不匹配,预期引发 RuntimeError
@array_function_dispatch(lambda x: (x,))
def f(y):
pass
# 不应该引发异常
@array_function_dispatch(lambda x: (x,), verify=False)
def f(y):
pass
def _new_duck_type_and_implements():
"""创建一个鸭子类型的数组和实现函数。"""
HANDLED_FUNCTIONS = {}
class MyArray:
def __array_function__(self, func, types, args, kwargs):
if func not in HANDLED_FUNCTIONS:
return NotImplemented
if not all(issubclass(t, MyArray) for t in types):
return NotImplemented
return HANDLED_FUNCTIONS[func](*args, **kwargs)
def implements(numpy_function):
"""注册一个 __array_function__ 实现函数。"""
def decorator(func):
HANDLED_FUNCTIONS[numpy_function] = func
return func
return decorator
return (MyArray, implements)
class TestArrayFunctionImplementation:
def test_one_arg(self):
MyArray, implements = _new_duck_type_and_implements()
@implements(dispatched_one_arg)
def _(array):
return 'myarray'
# 断言 dispatch_one_arg 函数对不同类型参数的行为
assert_equal(dispatched_one_arg(1), 'original')
assert_equal(dispatched_one_arg(MyArray()), 'myarray')
def test_optional_args(self):
# 获取 MyArray 类型和其实现的 duck-typing 函数
MyArray, implements = _new_duck_type_and_implements()
# 定义带可选参数的函数 func_with_option,通过装饰器进行数组函数分派
@array_function_dispatch(lambda array, option=None: (array,))
def func_with_option(array, option='default'):
return option
# 将 func_with_option 函数实现为 my_array_func_with_option 函数
@implements(func_with_option)
def my_array_func_with_option(array, new_option='myarray'):
return new_option
# 断言 func_with_option 函数的不同调用方式返回预期结果
assert_equal(func_with_option(1), 'default')
assert_equal(func_with_option(1, option='extra'), 'extra')
assert_equal(func_with_option(MyArray()), 'myarray')
# 使用未实现的选项调用 func_with_option 应抛出 TypeError 异常
with assert_raises(TypeError):
func_with_option(MyArray(), option='extra')
# 断言 my_array_func_with_option 函数的调用结果
result = my_array_func_with_option(MyArray(), new_option='yes')
assert_equal(result, 'yes')
# 使用未实现的选项调用 func_with_option 应抛出 TypeError 异常
with assert_raises(TypeError):
func_with_option(MyArray(), new_option='no')
def test_not_implemented(self):
# 获取 MyArray 类型和其实现的 duck-typing 函数
MyArray, implements = _new_duck_type_and_implements()
# 定义函数 func,通过装饰器指定数组函数分派的模块为 'my'
@array_function_dispatch(lambda array: (array,), module='my')
def func(array):
return array
# 创建一个数组 array
array = np.array(1)
# 断言 func 函数返回的是输入的数组 array
assert_(func(array) is array)
# 断言 func 函数的模块名为 'my'
assert_equal(func.__module__, 'my')
# 使用未实现的 MyArray 类型调用 func 应抛出 TypeError 异常
with assert_raises_regex(
TypeError, "no implementation found for 'my.func'"):
func(MyArray())
@pytest.mark.parametrize("name", ["concatenate", "mean", "asarray"])
def test_signature_error_message_simple(self, name):
# 根据参数 name 获取 numpy 模块中对应的函数 func
func = getattr(np, name)
try:
# 调用 func 函数,期望抛出 TypeError 异常
func()
except TypeError as e:
exc = e
# 断言异常信息以 name + '()' 开头
assert exc.args[0].startswith(f"{name}()")
def test_signature_error_message(self):
# 定义一个分派函数 _dispatcher,返回空元组
# func 函数的分派器使用 _dispatcher
def _dispatcher():
return ()
# 定义函数 func,并使用 _dispatcher 作为其数组函数分派器
@array_function_dispatch(_dispatcher)
def func():
pass
try:
# 调用 func 的 _implementation 方法,传入错误的参数名 bad_arg
func._implementation(bad_arg=3)
except TypeError as e:
expected_exception = e
try:
# 调用 func 函数,传入错误的参数名 bad_arg
func(bad_arg=3)
raise AssertionError("must fail")
except TypeError as exc:
# 如果异常信息以 '_dispatcher' 开头,跳过此测试(Python 版本不使用 __qualname__ 格式化 TypeError)
if exc.args[0].startswith("_dispatcher"):
pytest.skip("Python version is not using __qualname__ for "
"TypeError formatting.")
# 断言异常信息与预期异常相同
assert exc.args == expected_exception.args
@pytest.mark.parametrize("value", [234, "this func is not replaced"])
def test_dispatcher_error(self, value):
# 如果调度程序引发错误,则不应尝试进行变异
# 创建一个 TypeError 的实例,用于后续引发异常
error = TypeError(value)
# 定义一个内部函数 dispatcher,用于引发之前创建的错误
def dispatcher():
raise error
# 使用 array_function_dispatch 装饰器将 func 关联到 dispatcher 函数
@array_function_dispatch(dispatcher)
def func():
return 3
try:
# 调用 func 函数,期望它引发之前定义的 TypeError 异常
func()
# 如果没有引发异常,则抛出 AssertionError
raise AssertionError("must fail")
except TypeError as exc:
# 断言捕获的异常对象和之前创建的 error 对象相同
assert exc is error # unmodified exception
def test_properties(self):
# 检查 str 和 repr 方法的行为是否合理
# 获取 dispatched_two_arg 函数的引用
func = dispatched_two_arg
# 断言函数的字符串表示应该与其实现对象的字符串表示相同
assert str(func) == str(func._implementation)
# 获取函数的 repr,去除其中的对象地址信息
repr_no_id = repr(func).split("at ")[0]
repr_no_id_impl = repr(func._implementation).split("at ")[0]
# 断言函数的 repr 和其实现对象的 repr 应该相同(去除对象地址信息后)
assert repr_no_id == repr_no_id_impl
@pytest.mark.parametrize("func", [
lambda x, y: 0, # 没有 like 参数
lambda like=None: 0, # 不是仅关键字参数
lambda *, like=None, a=3: 0, # 不是最后一个参数(尽管不影响测试)
])
def test_bad_like_sig(self, func):
# 我们对函数签名进行合理性检查,这些应该失败
# 使用 pytest.raises 检查 array_function_dispatch 装饰器对于 func 的应用是否引发 RuntimeError
with pytest.raises(RuntimeError):
array_function_dispatch()(func)
def test_bad_like_passing(self):
# 测试通过位置参数将 like 传递给装饰函数的内部一致性检查
# 定义一个函数 func,接受一个关键字参数 like
def func(*, like=None):
pass
# 使用 array_function_dispatch 装饰 func 函数
func_with_like = array_function_dispatch()(func)
# 使用 pytest.raises 检查调用 func_with_like 函数时传递 like 参数是否引发 TypeError
with pytest.raises(TypeError):
func_with_like()
with pytest.raises(TypeError):
func_with_like(like=234)
def test_too_many_args(self):
# 主要是为了增加代码覆盖率的单元测试
# 创建一个包含多个 MyArr 类的实例对象列表
objs = []
for i in range(80):
class MyArr:
def __array_function__(self, *args, **kwargs):
return NotImplemented
objs.append(MyArr())
# 定义一个 _dispatch 函数,用于作为 array_function_dispatch 的参数
def _dispatch(*args):
return args
# 使用 array_function_dispatch 装饰 func 函数
@array_function_dispatch(_dispatch)
def func(*args):
pass
# 使用 pytest.raises 检查调用 func 函数时传递太多参数是否引发 TypeError,并匹配指定的错误信息
with pytest.raises(TypeError, match="maximum number"):
func(*objs)
class TestNDArrayMethods:
def test_repr(self):
# 用于测试:即使 __array_function__ 没有实现 np.array_repr(),
# 这个方法仍应该被定义
class MyArray(np.ndarray):
def __array_function__(*args, **kwargs):
return NotImplemented
# 创建一个视图为 MyArray 的数组对象
array = np.array(1).view(MyArray)
# 断言该数组的 repr 应为 'MyArray(1)'
assert_equal(repr(array), 'MyArray(1)')
# 断言该数组的 str 应为 '1'
assert_equal(str(array), '1')
class TestNumPyFunctions:
def test_set_module(self):
# 断言 np.sum 函数的模块应为 'numpy'
assert_equal(np.sum.__module__, 'numpy')
# 断言 np.char.equal 函数的模块应为 'numpy.char'
assert_equal(np.char.equal.__module__, 'numpy.char')
# 断言 np.fft.fft 函数的模块应为 'numpy.fft'
assert_equal(np.fft.fft.__module__, 'numpy.fft')
# 断言 np.linalg.solve 函数的模块应为 'numpy.linalg'
assert_equal(np.linalg.solve.__module__, 'numpy.linalg')
def test_inspect_sum(self):
# 获取 np.sum 函数的签名
signature = inspect.signature(np.sum)
# 断言签名中包含 'axis' 参数
assert_('axis' in signature.parameters)
def test_override_sum(self):
# 创建一个自定义数组类 MyArray 和其对应的接口实现
MyArray, implements = _new_duck_type_and_implements()
# 使用装饰器实现对 np.sum 的重载
@implements(np.sum)
def _(array):
return 'yes'
# 断言对 MyArray 类型的数组调用 np.sum 应返回 'yes'
assert_equal(np.sum(MyArray()), 'yes')
def test_sum_on_mock_array(self):
# 由于 __array_function__ 只在类字典中查找,因此需要为模拟对象创建一个代理
class ArrayProxy:
def __init__(self, value):
self.value = value
# 重定向 __array_function__ 至实际值对象
def __array_function__(self, *args, **kwargs):
return self.value.__array_function__(*args, **kwargs)
# 重定向 __array__ 至实际值对象
def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs)
# 创建一个使用 ArrayProxy 的模拟对象
proxy = ArrayProxy(mock.Mock(spec=ArrayProxy))
# 设置模拟对象的 __array_function__ 返回值为 1
proxy.value.__array_function__.return_value = 1
# 调用 np.sum(proxy),期望结果为 1
result = np.sum(proxy)
assert_equal(result, 1)
# 断言 np.sum 调用时传入的参数符合预期
proxy.value.__array_function__.assert_called_once_with(
np.sum, (ArrayProxy,), (proxy,), {})
# 断言未调用 proxy.value.__array__ 方法
proxy.value.__array__.assert_not_called()
def test_sum_forwarding_implementation(self):
# 创建一个继承自 np.ndarray 的自定义数组类 MyArray
class MyArray(np.ndarray):
# 自定义 sum 方法
def sum(self, axis, out):
return 'summed'
# 重定向 __array_function__ 至父类实现
def __array_function__(self, func, types, args, kwargs):
return super().__array_function__(func, types, args, kwargs)
# 创建一个视图为 MyArray 的数组对象
array = np.array(1).view(MyArray)
# 断言 np.sum(array) 的结果为 'summed'
assert_equal(np.sum(array), 'summed')
class TestArrayLike:
def setup_method(self):
# 定义内部类 MyArray
class MyArray():
# MyArray 类的初始化函数
def __init__(self, function=None):
self.function = function
# 定义 __array_function__ 方法,用于处理数组函数
def __array_function__(self, func, types, args, kwargs):
# 确保 func 是 numpy 中的函数
assert func is getattr(np, func.__name__)
try:
# 获取 MyArray 实例中与 func 同名的方法
my_func = getattr(self, func.__name__)
except AttributeError:
# 如果找不到同名方法,则返回 NotImplemented
return NotImplemented
# 调用找到的方法并返回结果
return my_func(*args, **kwargs)
# 将 MyArray 类赋值给 self.MyArray
self.MyArray = MyArray
# 定义内部类 MyNoArrayFunctionArray
class MyNoArrayFunctionArray():
# MyNoArrayFunctionArray 类的初始化函数
def __init__(self, function=None):
self.function = function
# 将 MyNoArrayFunctionArray 类赋值给 self.MyNoArrayFunctionArray
self.MyNoArrayFunctionArray = MyNoArrayFunctionArray
def add_method(self, name, arr_class, enable_value_error=False):
# 定义内部函数 _definition,用于向 arr_class 动态添加方法
def _definition(*args, **kwargs):
# 检查 kwargs 中是否包含 'like' 键
assert 'like' not in kwargs
# 如果 enable_value_error 为 True 并且 kwargs 中包含 'value_error' 键,则抛出 ValueError
if enable_value_error and 'value_error' in kwargs:
raise ValueError
# 调用 arr_class 中 name 方法,并返回结果
return arr_class(getattr(arr_class, name))
# 将 _definition 方法添加为 arr_class 的 name 方法
setattr(arr_class, name, _definition)
# 定义函数 func_args,返回传入的 args 和 kwargs
def func_args(*args, **kwargs):
return args, kwargs
# 定义测试方法 test_array_like_not_implemented
def test_array_like_not_implemented(self):
# 向 self.MyArray 添加 'array' 方法
self.add_method('array', self.MyArray)
# 创建一个 self.MyArray 的实例 ref
ref = self.MyArray.array()
# 使用 assert_raises_regex 断言捕获 TypeError 异常,并检查异常消息中是否包含 'no implementation found'
with assert_raises_regex(TypeError, 'no implementation found'):
# 调用 np.asarray 函数,使用 ref 作为参数,但不应包含 'like' 参数
array_like = np.asarray(1, like=ref)
# 定义一个测试用例参数列表 _array_tests
_array_tests = [
('array', *func_args((1,))),
('asarray', *func_args((1,))),
('asanyarray', *func_args((1,))),
('ascontiguousarray', *func_args((2, 3))),
('asfortranarray', *func_args((2, 3))),
('require', *func_args((np.arange(6).reshape(2, 3),),
requirements=['A', 'F'])),
('empty', *func_args((1,))),
('full', *func_args((1,), 2)),
('ones', *func_args((1,))),
('zeros', *func_args((1,))),
('arange', *func_args(3)),
('frombuffer', *func_args(b'\x00' * 8, dtype=int)),
('fromiter', *func_args(range(3), dtype=int)),
('fromstring', *func_args('1,2', dtype=int, sep=',')),
('loadtxt', *func_args(lambda: StringIO('0 1\n2 3'))),
('genfromtxt', *func_args(lambda: StringIO('1,2.1'),
dtype=[('int', 'i8'), ('float', 'f8')],
delimiter=',')),
]
# 使用 pytest.mark.parametrize 装饰器为 test_array_like_not_implemented 方法参数化
@pytest.mark.parametrize('function, args, kwargs', _array_tests)
@pytest.mark.parametrize('numpy_ref', [True, False])
# 定义一个测试方法,用于测试类似数组的行为
def test_array_like(self, function, args, kwargs, numpy_ref):
# 添加自定义方法'array',将其与self.MyArray关联
self.add_method('array', self.MyArray)
# 添加指定的函数作为自定义方法,并将其与self.MyArray关联
self.add_method(function, self.MyArray)
# 从numpy中获取指定函数的引用
np_func = getattr(np, function)
# 从self.MyArray中获取指定函数的引用
my_func = getattr(self.MyArray, function)
# 如果numpy_ref为True,则将ref设为numpy数组,否则为self.MyArray的array方法返回值
if numpy_ref is True:
ref = np.array(1)
else:
ref = self.MyArray.array()
# 处理参数,将可调用的参数进行调用,保证参数为元组形式
like_args = tuple(a() if callable(a) else a for a in args)
# 使用指定的参数和kwargs,以及like参数为ref,调用np_func函数
array_like = np_func(*like_args, **kwargs, like=ref)
# 如果numpy_ref为True
if numpy_ref is True:
# 断言array_like的类型为np.ndarray
assert type(array_like) is np.ndarray
# 重新处理参数,保证参数为元组形式
np_args = tuple(a() if callable(a) else a for a in args)
# 使用指定的参数和kwargs,调用np_func函数
np_arr = np_func(*np_args, **kwargs)
# 特殊情况处理np.empty以确保数值匹配
if function == "empty":
np_arr.fill(1)
array_like.fill(1)
# 断言array_like与np_arr相等
assert_equal(array_like, np_arr)
else:
# 断言array_like的类型为self.MyArray
assert type(array_like) is self.MyArray
# 断言array_like的函数属性为my_func
assert array_like.function is my_func
# 使用pytest标记参数化测试
@pytest.mark.parametrize('function, args, kwargs', _array_tests)
# 使用pytest标记参数化测试
@pytest.mark.parametrize('ref', [1, [1], "MyNoArrayFunctionArray"])
# 定义一个测试方法,用于测试不支持数组函数like行为的情况
def test_no_array_function_like(self, function, args, kwargs, ref):
# 添加自定义方法'array',将其与self.MyNoArrayFunctionArray关联
self.add_method('array', self.MyNoArrayFunctionArray)
# 添加指定的函数作为自定义方法,并将其与self.MyNoArrayFunctionArray关联
self.add_method(function, self.MyNoArrayFunctionArray)
# 从numpy中获取指定函数的引用
np_func = getattr(np, function)
# 如果ref为"MyNoArrayFunctionArray"字符串,则将ref设为self.MyNoArrayFunctionArray的array方法返回值
if ref == "MyNoArrayFunctionArray":
ref = self.MyNoArrayFunctionArray.array()
# 处理参数,将可调用的参数进行调用,保证参数为元组形式
like_args = tuple(a() if callable(a) else a for a in args)
# 使用指定的参数和kwargs,以及like参数为ref,调用np_func函数,期望抛出TypeError异常
with assert_raises_regex(TypeError,
'The `like` argument must be an array-like that implements'):
np_func(*like_args, **kwargs, like=ref)
# 使用pytest标记参数化测试
@pytest.mark.parametrize('numpy_ref', [True, False])
# 定义一个测试方法,用于测试从文件中读取类似数组的行为
def test_array_like_fromfile(self, numpy_ref):
# 添加自定义方法'array',将其与self.MyArray关联
self.add_method('array', self.MyArray)
# 添加自定义方法'fromfile',将其与self.MyArray关联
self.add_method("fromfile", self.MyArray)
# 如果numpy_ref为True,则将ref设为numpy数组,否则为self.MyArray的array方法返回值
if numpy_ref is True:
ref = np.array(1)
else:
ref = self.MyArray.array()
# 生成一个包含5个随机数的numpy数组
data = np.random.random(5)
# 使用临时目录作为文件存储的位置
with tempfile.TemporaryDirectory() as tmpdir:
# 构造文件名,文件位于临时目录下,名为"testfile"
fname = os.path.join(tmpdir, "testfile")
# 将数据写入文件
data.tofile(fname)
# 使用np.fromfile函数从文件中读取数据,like参数为ref
array_like = np.fromfile(fname, like=ref)
# 如果numpy_ref为True
if numpy_ref is True:
# 断言array_like的类型为np.ndarray
assert type(array_like) is np.ndarray
# 使用np.fromfile函数再次从文件中读取数据,like参数为ref
np_res = np.fromfile(fname, like=ref)
# 断言np_res与data相等
assert_equal(np_res, data)
# 断言array_like与np_res相等
assert_equal(array_like, np_res)
else:
# 断言array_like的类型为self.MyArray
assert type(array_like) is self.MyArray
# 断言array_like的函数属性为self.MyArray的fromfile方法
assert array_like.function is self.MyArray.fromfile
# 定义一个测试方法,用于测试异常处理
def test_exception_handling(self):
# 向自定义数组类添加一个名为 'array' 的方法,并启用值错误异常处理选项
self.add_method('array', self.MyArray, enable_value_error=True)
# 调用自定义数组类的 array() 方法,返回一个引用
ref = self.MyArray.array()
# 使用 assert_raises 来断言会抛出 TypeError 异常
with assert_raises(TypeError):
# 预期此处会首先引发关于 'value_error' 参数无效的错误
np.array(1, value_error=True, like=ref)
# 使用 pytest 的参数化标记来定义一个测试方法,测试函数为 _array_tests 中指定的函数
@pytest.mark.parametrize('function, args, kwargs', _array_tests)
def test_like_as_none(self, function, args, kwargs):
# 向自定义数组类添加一个名为 'array' 的方法
self.add_method('array', self.MyArray)
# 向自定义数组类添加指定的 function 方法
self.add_method(function, self.MyArray)
# 获取 numpy 模块中的 function 函数
np_func = getattr(np, function)
# 将 args 中的每个元素如果是可调用的则调用,形成新的元组 like_args
like_args = tuple(a() if callable(a) else a for a in args)
# 对于 loadtxt 和 genfromtxt 函数,初始化时需要这样做以避免错误
like_args_exp = tuple(a() if callable(a) else a for a in args)
# 调用 numpy 模块中的 function 函数,like 参数设为 None
array_like = np_func(*like_args, **kwargs, like=None)
# 调用 numpy 模块中的 function 函数,使用 like_args_exp 和 kwargs
expected = np_func(*like_args_exp, **kwargs)
# 特殊处理 np.empty 函数,确保其生成的数组值与期望值一致
if function == "empty":
array_like.fill(1)
expected.fill(1)
# 使用 assert_equal 断言 array_like 与 expected 相等
assert_equal(array_like, expected)
def test_function_like():
# 检查 np.mean 的类型是否为特定的 ArrayFunctionDispatcher 类型
assert type(np.mean) is np._core._multiarray_umath._ArrayFunctionDispatcher
# 定义一个 MyClass 类
class MyClass:
# 实现 __array__ 方法以支持 mean 函数
def __array__(self, dtype=None, copy=None):
# 返回一个 ndarray 以供 np.mean 使用
return np.arange(3)
# 静态方法 func1 使用 np.mean
func1 = staticmethod(np.mean)
# 实例方法 func2 使用 np.mean
func2 = np.mean
# 类方法 func3 使用 np.mean
func3 = classmethod(np.mean)
# 创建 MyClass 实例 m
m = MyClass()
# 断言 func1 在 m 实例上调用时返回 10
assert m.func1([10]) == 10
# 断言 func2 在 m 实例上调用时返回 1(np.arange(3) 的均值)
assert m.func2() == 1
# 使用 pytest 断言 func3 在类上调用时引发 TypeError 异常,匹配给定的错误信息
with pytest.raises(TypeError, match="unsupported operand type"):
m.func3()
# 手动绑定 np.mean 方法也可以正常工作(上面的方法可能是一个捷径):
# 绑定 np.mean 方法到 MyClass 实例 m 上
bound = np.mean.__get__(m, MyClass)
# 断言绑定后调用返回 1
assert bound() == 1
# 未绑定的方法实际上是静态调用,这里传入 None
bound = np.mean.__get__(None, MyClass)
# 断言未绑定调用时返回 10
assert bound([10]) == 10
# 绑定 np.mean 方法作为类方法
bound = np.mean.__get__(MyClass)
# 使用 pytest 断言调用时引发 TypeError 异常,匹配给定的错误信息
with pytest.raises(TypeError, match="unsupported operand type"):
bound()
.\numpy\numpy\_core\tests\test_print.py
# 导入系统模块 sys
import sys
# 导入 pytest 模块
import pytest
# 导入 numpy 库,并从中导入 assert_, assert_equal, IS_MUSL
import numpy as np
from numpy.testing import assert_, assert_equal, IS_MUSL
# 从 numpy._core.tests._locales 中导入 CommaDecimalPointLocale
from numpy._core.tests._locales import CommaDecimalPointLocale
# 导入 io 模块中的 StringIO 类
from io import StringIO
# 定义一个全局变量 _REF,用于存储特定浮点数值与字符串表示的对应关系
_REF = {np.inf: 'inf', -np.inf: '-inf', np.nan: 'nan'}
# 使用 pytest 的 parametrize 装饰器,测试不同浮点数类型的字符串格式化输出
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_types(tp):
""" Check formatting.
This is only for the str function, and only for simple types.
The precision of np.float32 and np.longdouble aren't the same as the
python float precision.
"""
# 遍历不同的浮点数值,并进行字符串格式化输出的测试
for x in [0, 1, -1, 1e20]:
# 断言字符串格式化输出结果与预期的 Python 浮点数格式化输出一致
assert_equal(str(tp(x)), str(float(x)),
err_msg='Failed str formatting for type %s' % tp)
# 对于 itemsize 大于 4 的浮点数类型,继续进行字符串格式化输出测试
if tp(1e16).itemsize > 4:
assert_equal(str(tp(1e16)), str(float('1e16')),
err_msg='Failed str formatting for type %s' % tp)
else:
# 否则,使用预定义的字符串格式化输出进行断言
ref = '1e+16'
assert_equal(str(tp(1e16)), ref,
err_msg='Failed str formatting for type %s' % tp)
# 使用 pytest 的 parametrize 装饰器,测试 nan 和 inf 的字符串格式化输出
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_nan_inf_float(tp):
""" Check formatting of nan & inf.
This is only for the str function, and only for simple types.
The precision of np.float32 and np.longdouble aren't the same as the
python float precision.
"""
# 遍历 nan 和 inf 值,并进行字符串格式化输出的测试
for x in [np.inf, -np.inf, np.nan]:
# 断言字符串格式化输出结果与预期的特定值对应
assert_equal(str(tp(x)), _REF[x],
err_msg='Failed str formatting for type %s' % tp)
# 使用 pytest 的 parametrize 装饰器,测试复数类型的字符串格式化输出
@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_types(tp):
"""Check formatting of complex types.
This is only for the str function, and only for simple types.
The precision of np.float32 and np.longdouble aren't the same as the
python float precision.
"""
# 遍历不同的实数值和虚数值,并进行复数类型的字符串格式化输出测试
for x in [0, 1, -1, 1e20]:
# 断言字符串格式化输出结果与预期的复数格式化输出一致
assert_equal(str(tp(x)), str(complex(x)),
err_msg='Failed str formatting for type %s' % tp)
assert_equal(str(tp(x*1j)), str(complex(x*1j)),
err_msg='Failed str formatting for type %s' % tp)
assert_equal(str(tp(x + x*1j)), str(complex(x + x*1j)),
err_msg='Failed str formatting for type %s' % tp)
# 对于 itemsize 大于 8 的复数类型,继续进行字符串格式化输出测试
if tp(1e16).itemsize > 8:
assert_equal(str(tp(1e16)), str(complex(1e16)),
err_msg='Failed str formatting for type %s' % tp)
else:
# 否则,使用预定义的字符串格式化输出进行断言
ref = '(1e+16+0j)'
assert_equal(str(tp(1e16)), ref,
err_msg='Failed str formatting for type %s' % tp)
# 使用 pytest 的 parametrize 装饰器,测试复数类型中的 inf 和 nan 的字符串格式化输出
@pytest.mark.parametrize('dtype', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_inf_nan(dtype):
"""Check inf/nan formatting of complex types."""
# 定义一个字典 TESTS,包含各种复数值和对应的字符串表示
TESTS = {
complex(np.inf, 0): "(inf+0j)", # 复数值为正无穷大加零,期望结果是字符串 "(inf+0j)"
complex(0, np.inf): "infj", # 复数值为零加正无穷大,期望结果是字符串 "infj"
complex(-np.inf, 0): "(-inf+0j)", # 复数值为负无穷大加零,期望结果是字符串 "(-inf+0j)"
complex(0, -np.inf): "-infj", # 复数值为零加负无穷大,期望结果是字符串 "-infj"
complex(np.inf, 1): "(inf+1j)", # 复数值为正无穷大加一,期望结果是字符串 "(inf+1j)"
complex(1, np.inf): "(1+infj)", # 复数值为一加正无穷大,期望结果是字符串 "(1+infj)"
complex(-np.inf, 1): "(-inf+1j)", # 复数值为负无穷大加一,期望结果是字符串 "(-inf+1j)"
complex(1, -np.inf): "(1-infj)", # 复数值为一加负无穷大,期望结果是字符串 "(1-infj)"
complex(np.nan, 0): "(nan+0j)", # 复数值为 NaN 加零,期望结果是字符串 "(nan+0j)"
complex(0, np.nan): "nanj", # 复数值为零加 NaN,期望结果是字符串 "nanj"
complex(-np.nan, 0): "(nan+0j)", # 复数值为负 NaN 加零,期望结果是字符串 "(nan+0j)"
complex(0, -np.nan): "nanj", # 复数值为零加负 NaN,期望结果是字符串 "nanj"
complex(np.nan, 1): "(nan+1j)", # 复数值为 NaN 加一,期望结果是字符串 "(nan+1j)"
complex(1, np.nan): "(1+nanj)", # 复数值为一加 NaN,期望结果是字符串 "(1+nanj)"
complex(-np.nan, 1): "(nan+1j)", # 复数值为负 NaN 加一,期望结果是字符串 "(nan+1j)"
complex(1, -np.nan): "(1+nanj)", # 复数值为一加负 NaN,期望结果是字符串 "(1+nanj)"
}
# 遍历 TESTS 字典中的每个复数和对应的字符串表示
for c, s in TESTS.items():
# 使用 assert_equal 函数检查复数 c 的数据类型的字符串表示是否等于预期字符串 s
assert_equal(str(dtype(c)), s)
# print tests
# 定义用于测试重定向打印输出的函数,接受参数 x (输入值)、tp (类型)、ref (参考值,默认为 None)
def _test_redirected_print(x, tp, ref=None):
# 创建一个内存中的字符串流对象用于捕获标准输出
file = StringIO()
# 创建另一个字符串流对象,用于比较标准输出
file_tp = StringIO()
# 保存当前标准输出对象
stdout = sys.stdout
try:
# 重定向标准输出到 file_tp 对象
sys.stdout = file_tp
# 打印 tp(x) 的结果到重定向的标准输出
print(tp(x))
# 再次重定向标准输出到 file 对象
sys.stdout = file
# 如果有提供参考值 ref,则打印 ref;否则打印 x
if ref:
print(ref)
else:
print(x)
finally:
# 恢复原始的标准输出
sys.stdout = stdout
# 断言重定向后的输出与预期一致,用于测试目的
assert_equal(file.getvalue(), file_tp.getvalue(),
err_msg='print failed for type%s' % tp)
# 使用 pytest 的 parametrize 装饰器,对不同的浮点类型进行格式化输出测试
@pytest.mark.parametrize('tp', [np.float32, np.double, np.longdouble])
def test_float_type_print(tp):
"""Check formatting when using print """
# 对一些常见的浮点数进行测试
for x in [0, 1, -1, 1e20]:
_test_redirected_print(float(x), tp)
# 对无穷大、负无穷大和 NaN 进行特殊测试
for x in [np.inf, -np.inf, np.nan]:
_test_redirected_print(float(x), tp, _REF[x])
# 如果浮点类型 tp 的字节大小大于 4,则继续测试 1e16 的输出
if tp(1e16).itemsize > 4:
_test_redirected_print(float(1e16), tp)
else:
# 否则使用参考值 '1e+16'
ref = '1e+16'
_test_redirected_print(float(1e16), tp, ref)
# 使用 pytest 的 parametrize 装饰器,对不同的复数类型进行格式化输出测试
@pytest.mark.parametrize('tp', [np.complex64, np.cdouble, np.clongdouble])
def test_complex_type_print(tp):
"""Check formatting when using print """
# 对一些常见的复数进行测试
# 因为 Python < 2.6 不支持直接创建具有无穷大和 NaN 的复数,所以避免直接使用它们
for x in [0, 1, -1, 1e20]:
_test_redirected_print(complex(x), tp)
# 如果复数类型 tp 的字节大小大于 8,则继续测试 1e16 的输出
if tp(1e16).itemsize > 8:
_test_redirected_print(complex(1e16), tp)
else:
# 否则使用参考值 '(1e+16+0j)'
ref = '(1e+16+0j)'
_test_redirected_print(complex(1e16), tp, ref)
# 测试特殊的复数情况:无穷大、负无穷大和 NaN
_test_redirected_print(complex(np.inf, 1), tp, '(inf+1j)')
_test_redirected_print(complex(-np.inf, 1), tp, '(-inf+1j)')
_test_redirected_print(complex(-np.nan, 1), tp, '(nan+1j)')
# 测试 NumPy 标量类型的字符串格式化输出方法
def test_scalar_format():
"""Test the str.format method with NumPy scalar types"""
# 定义一系列测试用例,包括格式字符串、值和值的类型
tests = [('{0}', True, np.bool),
('{0}', False, np.bool),
('{0:d}', 130, np.uint8),
('{0:d}', 50000, np.uint16),
('{0:d}', 3000000000, np.uint32),
('{0:d}', 15000000000000000000, np.uint64),
('{0:d}', -120, np.int8),
('{0:d}', -30000, np.int16),
('{0:d}', -2000000000, np.int32),
('{0:d}', -7000000000000000000, np.int64),
('{0:g}', 1.5, np.float16),
('{0:g}', 1.5, np.float32),
('{0:g}', 1.5, np.float64),
('{0:g}', 1.5, np.longdouble),
('{0:g}', 1.5+0.5j, np.complex64),
('{0:g}', 1.5+0.5j, np.complex128),
('{0:g}', 1.5+0.5j, np.clongdouble)]
# 遍历测试用例并执行测试
for (fmat, val, valtype) in tests:
try:
# 使用格式化字符串 fmat 对 val 和 valtype(val) 进行格式化
assert_equal(fmat.format(val), fmat.format(valtype(val)),
"failed with val %s, type %s" % (val, valtype))
except ValueError as e:
assert_(False,
"format raised exception (fmt='%s', val=%s, type=%s, exc='%s')" %
(fmat, repr(val), repr(valtype), str(e)))
#
# Locale tests: scalar types formatting should be independent of the locale
#
# 定义一个测试类 TestCommaDecimalPointLocale,继承自 CommaDecimalPointLocale 类
class TestCommaDecimalPointLocale(CommaDecimalPointLocale):
# 测试单精度浮点数在本地化环境下的字符串表示是否正确
def test_locale_single(self):
# 断言单精度浮点数的字符串表示应与普通浮点数相同
assert_equal(str(np.float32(1.2)), str(float(1.2)))
# 测试双精度浮点数在本地化环境下的字符串表示是否正确
def test_locale_double(self):
# 断言双精度浮点数的字符串表示应与普通浮点数相同
assert_equal(str(np.double(1.2)), str(float(1.2)))
# 根据条件标记,跳过测试长双精度浮点数在本地化环境下的字符串表示
@pytest.mark.skipif(IS_MUSL,
reason="test flaky on musllinux")
def test_locale_longdouble(self):
# 断言长双精度浮点数的字符串表示应与普通浮点数相同
assert_equal(str(np.longdouble('1.2')), str(float(1.2)))
.\numpy\numpy\_core\tests\test_protocols.py
# 导入 pytest 库,用于编写和运行测试用例
import pytest
# 导入 warnings 库,用于处理警告信息
import warnings
# 导入 numpy 库,并使用 np 别名引用
import numpy as np
# 使用 pytest.mark.filterwarnings("error") 装饰器标记此测试用例,将警告信息转换为错误
@pytest.mark.filterwarnings("error")
# 定义一个测试函数 test_getattr_warning,用于测试 __getattr__ 方法的警告行为
def test_getattr_warning():
# 解决问题 gh-14735:确保仅清除 getattr 错误,并允许警告通过
class Wrapper:
# 初始化方法,接受一个数组作为参数
def __init__(self, array):
self.array = array
# 定义 __len__ 方法,返回数组的长度
def __len__(self):
return len(self.array)
# 定义 __getitem__ 方法,用于获取数组的元素
def __getitem__(self, item):
return type(self)(self.array[item])
# 定义 __getattr__ 方法,处理动态属性访问
def __getattr__(self, name):
# 如果属性名以 "__array_" 开头,发出警告
if name.startswith("__array_"):
warnings.warn("object got converted", UserWarning, stacklevel=1)
# 否则委托给数组对象的相应属性
return getattr(self.array, name)
# 定义 __repr__ 方法,返回对象的字符串表示形式
def __repr__(self):
return "<Wrapper({self.array})>".format(self=self)
# 创建一个 Wrapper 实例,包装一个包含 0 到 9 的 numpy 数组
array = Wrapper(np.arange(10))
# 使用 pytest.raises 检测是否抛出特定类型和消息的警告
with pytest.raises(UserWarning, match="object got converted"):
np.asarray(array)
# 定义一个测试函数 test_array_called,测试自定义类的 __array__ 方法
def test_array_called():
# 定义一个名为 Wrapper 的类
class Wrapper:
# 类属性,包含一个长度为 100 的字符串 '0'
val = '0' * 100
# 定义 __array__ 方法,返回一个包含 self.val 的 numpy 数组
def __array__(self, dtype=None, copy=None):
return np.array([self.val], dtype=dtype, copy=copy)
# 创建一个 Wrapper 实例
wrapped = Wrapper()
# 将 Wrapper 实例转换为 numpy 数组,指定 dtype 为字符串
arr = np.array(wrapped, dtype=str)
# 断言 numpy 数组的 dtype 是否符合预期
assert arr.dtype == 'U100'
# 断言 numpy 数组的第一个元素是否与 Wrapper 类的 val 属性相等
assert arr[0] == Wrapper.val
.\numpy\numpy\_core\tests\test_records.py
# 导入 collections.abc 模块,这是 Python 中抽象基类集合的一部分
import collections.abc
# 导入 textwrap 模块,用于文本的换行和填充
import textwrap
# 导入 BytesIO 类,用于在内存中读写二进制数据
from io import BytesIO
# 导入 path 函数,提供对文件路径操作的函数
from os import path
# 导入 Path 类,用于高级路径操作
from pathlib import Path
# 导入 pickle 模块,用于序列化和反序列化 Python 对象
import pickle
# 导入 pytest 模块,用于编写和运行测试用例
import pytest
# 导入 numpy 库
import numpy as np
# 从 numpy.testing 模块导入一些用于测试的函数和类
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_array_almost_equal,
assert_raises, temppath,
)
# 定义一个测试类 TestFromrecords
class TestFromrecords:
# 定义测试方法 test_fromrecords
def test_fromrecords(self):
# 使用 np.rec.fromrecords 方法创建一个结构化数组 r
r = np.rec.fromrecords([[456, 'dbe', 1.2], [2, 'de', 1.3]],
names='col1,col2,col3')
# 断言结构化数组 r 的第一个元素
assert_equal(r[0].item(), (456, 'dbe', 1.2))
# 断言结构化数组 r 的 'col1' 列的数据类型是整数
assert_equal(r['col1'].dtype.kind, 'i')
# 断言结构化数组 r 的 'col2' 列的数据类型是 Unicode 字符串
assert_equal(r['col2'].dtype.kind, 'U')
# 断言结构化数组 r 的 'col2' 列的每个元素的大小为 12
assert_equal(r['col2'].dtype.itemsize, 12)
# 断言结构化数组 r 的 'col3' 列的数据类型是浮点数
assert_equal(r['col3'].dtype.kind, 'f')
# 定义测试方法 test_fromrecords_0len,测试空输入情况下 np.rec.fromrecords 方法的表现
def test_fromrecords_0len(self):
""" Verify fromrecords works with a 0-length input """
# 定义结构化数组的 dtype
dtype = [('a', float), ('b', float)]
# 使用空列表创建结构化数组 r
r = np.rec.fromrecords([], dtype=dtype)
# 断言结构化数组 r 的形状为 (0,)
assert_equal(r.shape, (0,))
# 定义测试方法 test_fromrecords_2d,测试处理二维数据情况下 np.rec.fromrecords 方法的表现
def test_fromrecords_2d(self):
# 定义输入数据
data = [
[(1, 2), (3, 4), (5, 6)],
[(6, 5), (4, 3), (2, 1)]
]
# 期望的 'a' 列数据
expected_a = [[1, 3, 5], [6, 4, 2]]
# 期望的 'b' 列数据
expected_b = [[2, 4, 6], [5, 3, 1]]
# 使用指定 dtype 创建结构化数组 r1
r1 = np.rec.fromrecords(data, dtype=[('a', int), ('b', int)])
# 断言结构化数组 r1 的 'a' 列数据符合预期
assert_equal(r1['a'], expected_a)
# 断言结构化数组 r1 的 'b' 列数据符合预期
assert_equal(r1['b'], expected_b)
# 使用列名创建结构化数组 r2
r2 = np.rec.fromrecords(data, names=['a', 'b'])
# 断言结构化数组 r2 的 'a' 列数据符合预期
assert_equal(r2['a'], expected_a)
# 断言结构化数组 r2 的 'b' 列数据符合预期
assert_equal(r2['b'], expected_b)
# 断言结构化数组 r1 等于结构化数组 r2
assert_equal(r1, r2)
# 定义测试方法 test_method_array,测试 np.rec.array 方法的表现
def test_method_array(self):
# 使用 np.rec.array 方法创建结构化数组 r
r = np.rec.array(
b'abcdefg' * 100, formats='i2,S3,i4', shape=3, byteorder='big'
)
# 断言结构化数组 r 的第二个元素
assert_equal(r[1].item(), (25444, b'efg', 1633837924))
# 定义测试方法 test_method_array2,测试处理特定数据情况下 np.rec.array 方法的表现
def test_method_array2(self):
# 使用 np.rec.array 方法创建结构化数组 r
r = np.rec.array(
[
(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'),
(5, 55, 'ex'), (6, 66, 'f'), (7, 77, 'g')
],
formats='u1,f4,S1'
)
# 断言结构化数组 r 的第二个元素
assert_equal(r[1].item(), (2, 22.0, b'b'))
# 定义测试方法 test_recarray_slices,测试对结构化数组切片的行为
def test_recarray_slices(self):
# 使用 np.rec.array 方法创建结构化数组 r
r = np.rec.array(
[
(1, 11, 'a'), (2, 22, 'b'), (3, 33, 'c'), (4, 44, 'd'),
(5, 55, 'ex'), (6, 66, 'f'), (7, 77, 'g')
],
formats='u1,f4,S1'
)
# 断言对结构化数组 r 进行切片后的第二行元素
assert_equal(r[1::2][1].item(), (4, 44.0, b'd'))
# 定义测试方法 test_recarray_fromarrays,测试从数组创建结构化数组的表现
def test_recarray_fromarrays(self):
# 创建三个数组
x1 = np.array([1, 2, 3, 4])
x2 = np.array(['a', 'dd', 'xyz', '12'])
x3 = np.array([1.1, 2, 3, 4])
# 使用 np.rec.fromarrays 方法创建结构化数组 r
r = np.rec.fromarrays([x1, x2, x3], names='a,b,c')
# 断言结构化数组 r 的第二个元素
assert_equal(r[1].item(), (2, 'dd', 2.0))
# 修改 x1 的第二个元素
x1[1] = 34
# 断言结构化数组 r 的 'a' 列与原始数组 x1 相同
assert_equal(r.a, np.array([1, 2, 3, 4]))
# 定义测试函数,测试从文件中创建记录数组
def test_recarray_fromfile(self):
# 获取数据文件目录路径
data_dir = path.join(path.dirname(__file__), 'data')
# 获取文件完整路径
filename = path.join(data_dir, 'recarray_from_file.fits')
# 以二进制方式打开文件
fd = open(filename, 'rb')
# 移动文件指针到指定位置
fd.seek(2880 * 2)
# 从文件中读取数据创建记录数组
r1 = np.rec.fromfile(fd, formats='f8,i4,S5', shape=3, byteorder='big')
# 再次移动文件指针到指定位置
fd.seek(2880 * 2)
# 从文件中读取数据创建记录数组
r2 = np.rec.array(fd, formats='f8,i4,S5', shape=3, byteorder='big')
# 再次移动文件指针到指定位置
fd.seek(2880 * 2)
# 创建一个字节流对象
bytes_array = BytesIO()
# 将文件内容写入字节流
bytes_array.write(fd.read())
# 将字节流指针移动到起始位置
bytes_array.seek(0)
# 从字节流中读取数据创建记录数组
r3 = np.rec.fromfile(
bytes_array, formats='f8,i4,S5', shape=3, byteorder='big'
)
# 关闭文件
fd.close()
# 断言r1和r2的内容相等
assert_equal(r1, r2)
# 断言r2和r3的内容相等
assert_equal(r2, r3)
# 定义测试函数,测试从对象中创建记录数组
def test_recarray_from_obj(self):
# 计数值
count = 10
# 创建dtype为对象类型、值全为0的数组a
a = np.zeros(count, dtype='O')
# 创建dtype为浮点数类型、值全为0的数组b
b = np.zeros(count, dtype='f8')
# 创建dtype为浮点数类型、值全为0的数组c
c = np.zeros(count, dtype='f8')
# 循环将a数组的每个元素赋值为从1到9的列表
for i in range(len(a)):
a[i] = list(range(1, 10))
# 从数组a、b和c创建记录数组mine
mine = np.rec.fromarrays([a, b, c], names='date,data1,data2')
# 循环验证记录数组mine中的数据
for i in range(len(a)):
assert_((mine.date[i] == list(range(1, 10))))
assert_((mine.data1[i] == 0.0))
assert_((mine.data2[i] == 0.0))
# 定义测试函数,测试记录数组的显示
def test_recarray_repr(self):
# 创建元组数组a
a = np.array([(1, 0.1), (2, 0.2)],
dtype=[('foo', '<i4'), ('bar', '<f8')])
# 转换为记录数组
a = np.rec.array(a)
# 断言记录数组a的显示格式
assert_equal(
repr(a),
textwrap.dedent("""\
rec.array([(1, 0.1), (2, 0.2)],
dtype=[('foo', '<i4'), ('bar', '<f8')])""")
)
# 创建dtype为浮点数类型、值全为1的数组a
a = np.array(np.ones(4, dtype='f8'))
# 断言将数组a转换为记录数组的显示格式
assert_(repr(np.rec.array(a)).startswith('rec.array'))
# 创建dtype为整数和整数类型的数组a
a = np.rec.array(np.ones(3, dtype='i4,i4'))
# 断言将数组a转换为记录数组的显示格式
assert_equal(repr(a).find('numpy.record'), -1)
# 创建dtype为整数类型的数组a
a = np.rec.array(np.ones(3, dtype='i4'))
# 断言将数组a转换为记录数组的显示格式
assert_(repr(a).find('dtype=int32') != -1)
# 定义测试函数,测试0维记录数组的显示
def test_0d_recarray_repr(self):
# 创建0维记录数组arr_0d
arr_0d = np.rec.array((1, 2.0, '2003'), dtype='<i4,<f8,<M8[Y]')
# 断言0维记录数组arr_0d的显示格式
assert_equal(repr(arr_0d), textwrap.dedent("""\
rec.array((1, 2., '2003'),
dtype=[('f0', '<i4'), ('f1', '<f8'), ('f2', '<M8[Y]')])"""))
# 获取0维记录数组的记录
record = arr_0d[()]
# 断言获取的记录的显示格式
assert_equal(repr(record),
"np.record((1, 2.0, '2003'), "
"dtype=[('f0', '<i4'), ('f1', '<f8'), ('f2', '<M8[Y]')])")
# 断言设置为1.13时的记录显示格式
try:
np.set_printoptions(legacy='1.13')
assert_equal(repr(record), '(1, 2.0, datetime.date(2003, 1, 1))')
finally:
np.set_printoptions(legacy=False)
# 定义一个测试方法,测试从字符串表示形式创建的结构化数组
def test_recarray_from_repr(self):
# 创建一个包含两个元组的 numpy 数组 a,每个元组包含一个整数和一个字符串
a = np.array([(1,'ABC'), (2, "DEF")],
dtype=[('foo', int), ('bar', 'S4')])
# 使用 np.rec.array 将数组 a 转换为结构化数组 recordarr
recordarr = np.rec.array(a)
# 将数组 a 视图转换为结构化数组 recarr
recarr = a.view(np.recarray)
# 将数组 a 视图转换为具有记录 dtype 的 numpy 数组 recordview
recordview = a.view(np.dtype((np.record, a.dtype)))
# 通过 eval 函数将字符串表示形式的 recordarr 转换回 numpy 结构化数组
recordarr_r = eval("np." + repr(recordarr), {'np': np})
# 通过 eval 函数将字符串表示形式的 recarr 转换回 numpy 结构化数组
recarr_r = eval("np." + repr(recarr), {'np': np})
# 通过 eval 函数将字符串表示形式的 recordview 转换回 numpy 数组,保留记录 dtype
recordview_r = eval("np." + repr(recordview), {'np': np, 'numpy': np})
# 断言转换后的 recordarr_r 的类型为 np.recarray
assert_equal(type(recordarr_r), np.recarray)
# 断言转换后的 recordarr_r 的 dtype 的类型为 np.record
assert_equal(recordarr_r.dtype.type, np.record)
# 断言转换后的 recordarr_r 等于原始的 recordarr
# 断言转换后的 recarr_r 的类型为 np.recarray
assert_equal(type(recarr_r), np.recarray)
# 断言转换后的 recarr_r 的 dtype 的类型为 np.record
assert_equal(recarr_r.dtype.type, np.record)
# 断言转换后的 recarr_r 等于原始的 recarr
# 断言转换后的 recordview_r 的类型为 np.ndarray
assert_equal(type(recordview_r), np.ndarray)
# 断言转换后的 recordview 的 dtype 的类型为 np.record
assert_equal(recordview.dtype.type, np.record)
# 断言转换后的 recordview_r 等于原始的 recordview
# 定义一个测试方法,用于测试记录数组的视图操作
def test_recarray_views(self):
# 创建一个结构化数组a,包含两个元组
a = np.array([(1,'ABC'), (2, "DEF")],
dtype=[('foo', int), ('bar', 'S4')])
# 创建一个普通数组b,包含整数值
b = np.array([1,2,3,4,5], dtype=np.int64)
# 检查 np.rec.array 是否返回正确的数据类型
assert_equal(np.rec.array(a).dtype.type, np.record)
# 检查 np.rec.array 是否返回 np.recarray 类型
assert_equal(type(np.rec.array(a)), np.recarray)
# 检查 np.rec.array 是否返回正确的数据类型
assert_equal(np.rec.array(b).dtype.type, np.int64)
# 检查 np.rec.array 是否返回 np.recarray 类型
assert_equal(type(np.rec.array(b)), np.recarray)
# 检查将数组a视图为 np.recarray 是否产生相同的结果
assert_equal(a.view(np.recarray).dtype.type, np.record)
# 检查将数组a视图为 np.recarray 是否返回 np.recarray 类型
assert_equal(type(a.view(np.recarray)), np.recarray)
# 检查将数组b视图为 np.recarray 是否产生相同的结果
assert_equal(b.view(np.recarray).dtype.type, np.int64)
# 检查将数组b视图为 np.recarray 是否返回 np.recarray 类型
assert_equal(type(b.view(np.recarray)), np.recarray)
# 检查视图转换到非结构化数据类型时是否保留 np.recarray 类型
r = np.rec.array(np.ones(4, dtype="f4,i4"))
rv = r.view('f8').view('f4,i4')
assert_equal(type(rv), np.recarray)
assert_equal(rv.dtype.type, np.record)
# 检查通过索引访问是否保留 np.recarray 和 np.record 类型
r = np.rec.array(np.ones(4, dtype=[('a', 'i4'), ('b', 'i4'),
('c', 'i4,i4')]))
assert_equal(r['c'].dtype.type, np.record)
assert_equal(type(r['c']), np.recarray)
# 检查是否保留子类 (gh-6949)
class C(np.recarray):
pass
c = r.view(C)
assert_equal(type(c['c']), C)
# 检查访问嵌套结构时是否保留记录类型,但对于子数组、非空结构、非结构化空类型不保留
test_dtype = [('a', 'f4,f4'), ('b', 'V8'), ('c', ('f4',2)),
('d', ('i8', 'i4,i4'))]
r = np.rec.array([((1,1), b'11111111', [1,1], 1),
((1,1), b'11111111', [1,1], 1)], dtype=test_dtype)
assert_equal(r.a.dtype.type, np.record)
assert_equal(r.b.dtype.type, np.void)
assert_equal(r.c.dtype.type, np.float32)
assert_equal(r.d.dtype.type, np.int64)
# 检查相同操作在视图中是否保持一致
r = np.rec.array(np.ones(4, dtype='i4,i4'))
assert_equal(r.view('f4,f4').dtype.type, np.record)
assert_equal(r.view(('i4',2)).dtype.type, np.int32)
assert_equal(r.view('V8').dtype.type, np.void)
assert_equal(r.view(('i8', 'i4,i4')).dtype.type, np.int64)
# 检查能否撤销视图操作
arrs = [np.ones(4, dtype='f4,i4'), np.ones(4, dtype='f8')]
for arr in arrs:
rec = np.rec.array(arr)
# 建议的视图为 ndarray 的方法:
arr2 = rec.view(rec.dtype.fields or rec.dtype, np.ndarray)
assert_equal(arr2.dtype.type, arr.dtype.type)
assert_equal(type(arr2), type(arr))
# 定义测试函数,用于测试从给定字段名创建的记录数组
def test_recarray_from_names(self):
# 创建具有指定字段名的记录数组 ra
ra = np.rec.array([
(1, 'abc', 3.7000002861022949, 0),
(2, 'xy', 6.6999998092651367, 1),
(0, ' ', 0.40000000596046448, 0)],
names='c1, c2, c3, c4')
# 使用给定记录列表创建记录数组 pa
pa = np.rec.fromrecords([
(1, 'abc', 3.7000002861022949, 0),
(2, 'xy', 6.6999998092651367, 1),
(0, ' ', 0.40000000596046448, 0)],
names='c1, c2, c3, c4')
# 断言 ra 的数据类型与 pa 的数据类型相等
assert_(ra.dtype == pa.dtype)
# 断言 ra 的形状与 pa 的形状相等
assert_(ra.shape == pa.shape)
# 遍历 ra 的每一行,断言每一行的条目与 pa 中对应行的条目相等
for k in range(len(ra)):
assert_(ra[k].item() == pa[k].item())
# 定义测试函数,用于测试记录数组中字段冲突的情况
def test_recarray_conflict_fields(self):
# 创建具有字段 'field', 'shape', 'mean' 的记录数组 ra
ra = np.rec.array([(1, 'abc', 2.3), (2, 'xyz', 4.2),
(3, 'wrs', 1.3)],
names='field, shape, mean')
# 将 ra 的 mean 属性设置为指定列表
ra.mean = [1.1, 2.2, 3.3]
# 断言 ra 的 'mean' 字段几乎等于指定列表
assert_array_almost_equal(ra['mean'], [1.1, 2.2, 3.3])
# 断言 ra.mean 的类型与 ra.var 的类型相同
assert_(type(ra.mean) is type(ra.var))
# 修改 ra 的形状为 (1, 3)
ra.shape = (1, 3)
# 断言 ra 的形状确实为 (1, 3)
assert_(ra.shape == (1, 3))
# 修改 ra 的形状为指定列表
ra.shape = ['A', 'B', 'C']
# 断言 ra 的 'shape' 字段等于指定二维列表
assert_array_equal(ra['shape'], [['A', 'B', 'C']])
# 将 ra 的 'field' 字段设置为指定值
ra.field = 5
# 断言 ra 的 'field' 字段等于指定二维数组
assert_array_equal(ra['field'], [[5, 5, 5]])
# 断言 ra.field 的类型为 collections.abc.Callable 的实例
assert_(isinstance(ra.field, collections.abc.Callable))
# 定义测试函数,用于测试从记录列表创建记录数组并指定数据类型的情况
def test_fromrecords_with_explicit_dtype(self):
# 使用指定数据类型创建记录数组 a
a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')],
dtype=[('a', int), ('b', object)])
# 断言记录数组 a 的 'a' 字段等于指定列表
assert_equal(a.a, [1, 2])
# 断言记录数组 a 的第一个元素的 'a' 字段等于指定值
assert_equal(a[0].a, 1)
# 断言记录数组 a 的 'b' 字段等于指定列表
assert_equal(a.b, ['a', 'bbb'])
# 断言记录数组 a 的最后一个元素的 'b' 字段等于指定值
assert_equal(a[-1].b, 'bbb')
#
# 使用指定的 numpy 数据类型创建记录数组 a
ndtype = np.dtype([('a', int), ('b', object)])
a = np.rec.fromrecords([(1, 'a'), (2, 'bbb')], dtype=ndtype)
# 断言记录数组 a 的 'a' 字段等于指定列表
assert_equal(a.a, [1, 2])
# 断言记录数组 a 的第一个元素的 'a' 字段等于指定值
assert_equal(a[0].a, 1)
# 断言记录数组 a 的 'b' 字段等于指定列表
assert_equal(a.b, ['a', 'bbb'])
# 断言记录数组 a 的最后一个元素的 'b' 字段等于指定值
assert_equal(a[-1].b, 'bbb')
# 定义测试函数,用于测试字符串类型的记录数组创建
def test_recarray_stringtypes(self):
# 创建具有字段 'foo' 和 'bar' 的字符串类型的记录数组 a
a = np.array([('abc ', 1), ('abc', 2)],
dtype=[('foo', 'S4'), ('bar', int)])
# 将 ndarray 转换为 recarray 类型
a = a.view(np.recarray)
# 断言记录数组 a 的 'foo' 字段的第一个元素不等于 'foo' 字段的第二个元素
assert_equal(a.foo[0] == a.foo[1], False)
def test_recarray_returntypes(self):
# 定义一个包含结构化数据的 numpy 递归数组
qux_fields = {'C': (np.dtype('S5'), 0), 'D': (np.dtype('S5'), 6)}
a = np.rec.array([('abc ', (1,1), 1, ('abcde', 'fgehi')),
('abc', (2,3), 1, ('abcde', 'jklmn'))],
dtype=[('foo', 'S4'),
('bar', [('A', int), ('B', int)]),
('baz', int), ('qux', qux_fields)])
# 断言各属性的返回类型
assert_equal(type(a.foo), np.ndarray)
assert_equal(type(a['foo']), np.ndarray)
assert_equal(type(a.bar), np.recarray)
assert_equal(type(a['bar']), np.recarray)
# 检查 bar 属性的数据类型是记录数组
assert_equal(a.bar.dtype.type, np.record)
# 检查 qux 属性及其字段是否与预定义的 qux_fields 字典匹配
assert_equal(type(a['qux']), np.recarray)
assert_equal(a.qux.dtype.type, np.record)
assert_equal(dict(a.qux.dtype.fields), qux_fields)
# 检查 baz 属性的返回类型
assert_equal(type(a.baz), np.ndarray)
assert_equal(type(a['baz']), np.ndarray)
# 检查索引访问方式下 bar 属性的数据类型是记录
assert_equal(type(a[0].bar), np.record)
assert_equal(type(a[0]['bar']), np.record)
# 访问特定字段及其值,确保值的正确性
assert_equal(a[0].bar.A, 1)
assert_equal(a[0].bar['A'], 1)
assert_equal(a[0]['bar'].A, 1)
assert_equal(a[0]['bar']['A'], 1)
# 访问 qux 属性中特定字段 D,并检查其值是否正确
assert_equal(a[0].qux.D, b'fgehi')
assert_equal(a[0].qux['D'], b'fgehi')
assert_equal(a[0]['qux'].D, b'fgehi')
assert_equal(a[0]['qux']['D'], b'fgehi')
# 定义一个测试类 TestPathUsage,用于测试 pathlib.Path 的使用
class TestPathUsage:
# 定义测试方法 test_tofile_fromfile,测试将数组写入文件并从文件读取
def test_tofile_fromfile(self):
# 使用 temppath 创建临时文件路径,文件名后缀为 '.bin'
with temppath(suffix='.bin') as path:
# 将路径对象转换为 pathlib.Path 对象
path = Path(path)
# 设定随机数种子为 123
np.random.seed(123)
# 创建一个包含随机数据的数组,数据类型包括 float64、int32、字符串长度为 5
a = np.random.rand(10).astype('f8,i4,S5')
a[5] = (0.5, 10, 'abcde')
# 以二进制写入方式打开路径对应的文件
with path.open("wb") as fd:
# 将数组 a 写入文件
a.tofile(fd)
# 从文件中读取指定格式的数据为数组 x
x = np._core.records.fromfile(
path, formats='f8,i4,S5', shape=10
)
# 断言数组 x 的内容与数组 a 相等
assert_array_equal(x, a)
# 定义测试类 TestRecord,用于测试记录数组的操作
class TestRecord:
# 在每个测试方法运行前设置测试数据
def setup_method(self):
self.data = np.rec.fromrecords([(1, 2, 3), (4, 5, 6)],
dtype=[("col1", "<i4"),
("col2", "<i4"),
("col3", "<i4")])
# 测试记录数组的元素赋值操作
def test_assignment1(self):
a = self.data
# 断言记录数组的第一个元素的 col1 字段为 1
assert_equal(a.col1[0], 1)
# 修改记录数组的第一个元素的 col1 字段为 0
a[0].col1 = 0
# 断言记录数组的第一个元素的 col1 字段已修改为 0
assert_equal(a.col1[0], 0)
# 测试记录数组的字段直接赋值操作
def test_assignment2(self):
a = self.data
# 断言记录数组的第一个元素的 col1 字段为 1
assert_equal(a.col1[0], 1)
# 直接修改记录数组的第一个元素的 col1 字段为 0
a.col1[0] = 0
# 断言记录数组的第一个元素的 col1 字段已修改为 0
assert_equal(a.col1[0], 0)
# 测试尝试对记录数组中不存在的字段进行赋值操作
def test_invalid_assignment(self):
a = self.data
def assign_invalid_column(x):
# 尝试给记录数组中不存在的 col5 字段赋值
x[0].col5 = 1
# 断言会抛出 AttributeError 异常
assert_raises(AttributeError, assign_invalid_column, a)
# 测试记录数组字段不可写时的 setfield 操作
def test_nonwriteable_setfield(self):
# 创建一个记录数组 r,包含两个元素,每个元素有一个整数字段 'f'
r = np.rec.array([(0,), (1,)], dtype=[('f', 'i4')])
# 设置记录数组为不可写
r.flags.writeable = False
# 断言使用赋值方式修改字段值会引发 ValueError 异常
with assert_raises(ValueError):
r.f = [2, 3]
# 断言使用 setfield 方法修改字段值会引发 ValueError 异常
with assert_raises(ValueError):
r.setfield([2,3], *r.dtype.fields['f'])
# 测试记录数组选择字段时的 dtype 结构
def test_out_of_order_fields(self):
# 选择记录数组中的 'col1' 和 'col2' 字段,创建新的数组 x
x = self.data[['col1', 'col2']]
# 断言新数组 x 的字段名为 ('col1', 'col2')
assert_equal(x.dtype.names, ('col1', 'col2'))
# 断言新数组 x 的 dtype 描述结构为 [('col1', '<i4'), ('col2', '<i4'), ('', '|V4')]
assert_equal(x.dtype.descr,
[('col1', '<i4'), ('col2', '<i4'), ('', '|V4')])
# 将字段名顺序调整为与索引顺序一致,作为 1.14 版本后的索引,预期引发 ValueError 异常
y = self.data[['col2', 'col1']]
assert_equal(y.dtype.names, ('col2', 'col1'))
assert_raises(ValueError, lambda: y.dtype.descr)
# 测试记录数组的序列化和反序列化
def test_pickle_1(self):
# 创建一个具有一个整数字段 'a' 和一个空数组字段 'b' 的数组 a
a = np.array([(1, [])], dtype=[('a', np.int32), ('b', np.int32, 0)])
# 针对所有协议版本进行序列化和反序列化的测试
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
assert_equal(a, pickle.loads(pickle.dumps(a, protocol=proto)))
assert_equal(a[0], pickle.loads(pickle.dumps(a[0],
protocol=proto)))
# 测试记录数组对象的序列化和反序列化
def test_pickle_2(self):
a = self.data
# 针对所有协议版本进行记录数组对象的序列化和反序列化的测试
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
assert_equal(a, pickle.loads(pickle.dumps(a, protocol=proto)))
assert_equal(a[0], pickle.loads(pickle.dumps(a[0],
protocol=proto)))
def test_pickle_3(self):
# Issue #7140
# 从 self.data 中获取数据
a = self.data
# 遍历从协议版本 2 到最高协议版本的范围
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
# 使用指定协议序列化和反序列化对象 a[0]
pa = pickle.loads(pickle.dumps(a[0], protocol=proto))
# 断言 pa 是 C 连续存储的
assert_(pa.flags.c_contiguous)
# 断言 pa 是 Fortran 连续存储的
assert_(pa.flags.f_contiguous)
# 断言 pa 是可写的
assert_(pa.flags.writeable)
# 断言 pa 是对齐的
assert_(pa.flags.aligned)
def test_pickle_void(self):
# issue gh-13593
# 定义一个复杂的数据类型 dt
dt = np.dtype([('obj', 'O'), ('int', 'i')])
# 创建一个空数组 a,使用复杂数据类型 dt
a = np.empty(1, dtype=dt)
# 创建一个数据对象 data,是一个元组,包含一个字节数组
data = (bytearray(b'eman'),)
# 将 data 赋值给 a 的 'obj' 字段
a['obj'] = data
# 将整数值 42 赋值给 a 的 'int' 字段
a['int'] = 42
# 获取 a[0] 的 __reduce__ 方法的返回值
ctor, args = a[0].__reduce__()
# 断言构造函数是 np._core.multiarray.scalar
assert ctor is np._core.multiarray.scalar
# 解析参数之前检查构造函数是否符合预期
dtype, obj = args
# 确保未序列化地址信息
assert not isinstance(obj, bytes)
# 断言调用构造函数会引发 RuntimeError 异常
assert_raises(RuntimeError, ctor, dtype, 13)
# 测试序列化和反序列化的往返
dump = pickle.dumps(a[0])
unpickled = pickle.loads(dump)
assert a[0] == unpickled
# 还要检查类似但不可能的 "object scalar" 路径
with pytest.warns(DeprecationWarning):
assert ctor(np.dtype("O"), data) is data
def test_objview_record(self):
# https://github.com/numpy/numpy/issues/2599
# 定义一个复杂的数据类型 dt
dt = np.dtype([('foo', 'i8'), ('bar', 'O')])
# 创建一个形状为 (1, 3) 的全零数组,使用复杂数据类型 dt,再视图为结构化数组
r = np.zeros((1, 3), dtype=dt).view(np.recarray)
# 设置 r 的 'foo' 字段为数组 [1, 2, 3],可能会引发 TypeError
r.foo = np.array([1, 2, 3]) # TypeError?
# https://github.com/numpy/numpy/issues/3256
# 创建一个结构化数组 ra,包含 2 个元素,数据类型是 [('x', object), ('y', float), ('z', int)]
ra = np.recarray(
(2,), dtype=[('x', object), ('y', float), ('z', int)]
)
# 尝试访问 ra 的特定字段数组 ['x','y'],可能会引发 TypeError
ra[['x','y']] # TypeError?
def test_record_scalar_setitem(self):
# https://github.com/numpy/numpy/issues/3561
# 创建一个结构化数组 rec,包含 1 个元素,数据类型是 [('x', float, 5)]
rec = np.recarray(1, dtype=[('x', float, 5)])
# 设置 rec[0] 的 'x' 字段为 1,期望其等于全为 1 的数组
rec[0].x = 1
assert_equal(rec[0].x, np.ones(5))
def test_missing_field(self):
# https://github.com/numpy/numpy/issues/4806
# 创建一个全零数组 arr,形状为 (3,),数据类型为 [('x', int), ('y', int)]
arr = np.zeros((3,), dtype=[('x', int), ('y', int)])
# 尝试访问 arr 的不存在的字段 'nofield',期望引发 KeyError 异常
assert_raises(KeyError, lambda: arr[['nofield']])
def test_fromarrays_nested_structured_arrays(self):
# 创建一个数组列表 arrays,包含两个元素:一个范围为 [0, 9] 的整数数组和一个结构化数组
arrays = [
np.arange(10),
np.ones(10, dtype=[('a', '<u2'), ('b', '<f4')]),
]
# 使用 fromarrays 方法创建一个结构化数组 arr,可能会引发 ValueError
arr = np.rec.fromarrays(arrays) # ValueError?
@pytest.mark.parametrize('nfields', [0, 1, 2])
def test_assign_dtype_attribute(self, nfields):
# 创建一个复杂的数据类型 dt,包含从 [('a', np.uint8), ('b', np.uint8), ('c', np.uint8)] 切片得到的字段
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
# 创建一个全零数组 data,形状为 (3,),使用复杂数据类型 dt,再视图为结构化数组
data = np.zeros(3, dt).view(np.recarray)
# 确保原始和结果数据类型的类型都是记录型
assert data.dtype.type == np.record
assert dt.type != np.record
# 确保即使被赋予不同的数据类型,data 的数据类型仍然是记录型
data.dtype = dt
assert data.dtype.type == np.record
def test_nested_fields_are_records(self, nfields):
""" Test that nested structured types are treated as records too """
# 创建一个包含指定字段的数据类型描述符
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)][:nfields])
# 创建一个外层数据类型描述符,其中包含一个内部结构类型
dt_outer = np.dtype([('inner', dt)])
# 使用外层数据类型描述符创建一个全零数组,并转换为记录数组
data = np.zeros(3, dt_outer).view(np.recarray)
# 断言数据对象是 np.recarray 类型
assert isinstance(data, np.recarray)
# 断言数据对象的 'inner' 字段是 np.recarray 类型
assert isinstance(data['inner'], np.recarray)
# 获取数组中的第一个元素
data0 = data[0]
# 断言第一个元素是 np.record 类型
assert isinstance(data0, np.record)
# 断言第一个元素的 'inner' 字段是 np.record 类型
assert isinstance(data0['inner'], np.record)
def test_nested_dtype_padding(self):
""" test that trailing padding is preserved """
# 构建一个具有尾部填充的数据类型描述符
dt = np.dtype([('a', np.uint8), ('b', np.uint8), ('c', np.uint8)])
dt_padded_end = dt[['a', 'b']]
# 断言尾部填充后的数据类型的字节大小与原始数据类型相同
assert dt_padded_end.itemsize == dt.itemsize
# 创建一个外层数据类型描述符,其中包含一个尾部填充后的数据类型
dt_outer = np.dtype([('inner', dt_padded_end)])
# 使用外层数据类型描述符创建一个全零数组,并转换为记录数组
data = np.zeros(3, dt_outer).view(np.recarray)
# 断言 'inner' 字段的数据类型与尾部填充后的数据类型相同
assert_equal(data['inner'].dtype, dt_padded_end)
# 获取数组中的第一个元素
data0 = data[0]
# 断言第一个元素的 'inner' 字段的数据类型与尾部填充后的数据类型相同
assert_equal(data0['inner'].dtype, dt_padded_end)
# 定义一个测试函数,用于测试查找列表中重复元素的函数 np.rec.find_duplicate()
def test_find_duplicate():
# 创建一个包含整数的列表 l1
l1 = [1, 2, 3, 4, 5, 6]
# 断言调用 np.rec.find_duplicate(l1) 返回一个空列表
assert_(np.rec.find_duplicate(l1) == [])
# 创建一个包含整数的列表 l2,其中有重复元素 1
l2 = [1, 2, 1, 4, 5, 6]
# 断言调用 np.rec.find_duplicate(l2) 返回包含元素 1 的列表
assert_(np.rec.find_duplicate(l2) == [1])
# 创建一个包含整数的列表 l3,其中有多个重复元素 1 和 2
l3 = [1, 2, 1, 4, 1, 6, 2, 3]
# 断言调用 np.rec.find_duplicate(l3) 返回包含元素 1 和 2 的列表
assert_(np.rec.find_duplicate(l3) == [1, 2])
# 修改列表 l3,使其包含多个重复元素 2 和 1
l3 = [2, 2, 1, 4, 1, 6, 2, 3]
# 断言调用 np.rec.find_duplicate(l3) 返回包含元素 2 和 1 的列表
assert_(np.rec.find_duplicate(l3) == [2, 1])
.\numpy\numpy\_core\tests\test_regression.py
import copy # 导入copy模块,用于深拷贝和浅拷贝操作
import sys # 导入sys模块,用于访问系统相关的变量和函数
import gc # 导入gc模块,用于垃圾回收
import tempfile # 导入tempfile模块,用于创建临时文件和目录
import pytest # 导入pytest模块,用于编写和运行测试用例
from os import path # 从os模块中导入path子模块,用于路径操作
from io import BytesIO # 从io模块中导入BytesIO类,用于处理二进制数据流
from itertools import chain # 导入itertools模块中的chain函数,用于迭代器操作
import pickle # 导入pickle模块,用于序列化和反序列化Python对象
import numpy as np # 导入NumPy库,使用别名np
from numpy.exceptions import AxisError, ComplexWarning # 导入NumPy的异常类
from numpy.testing import ( # 导入NumPy的测试工具函数
assert_, assert_equal, IS_PYPY, assert_almost_equal,
assert_array_equal, assert_array_almost_equal, assert_raises,
assert_raises_regex, assert_warns, suppress_warnings,
_assert_valid_refcount, HAS_REFCOUNT, IS_PYSTON, IS_WASM
)
from numpy.testing._private.utils import _no_tracing, requires_memory # 导入NumPy测试工具的私有函数
from numpy._utils import asbytes, asunicode # 导入NumPy内部工具函数
class TestRegression:
def test_invalid_round(self):
# Ticket #3
v = 4.7599999999999998 # 定义一个浮点数变量v
assert_array_equal(np.array([v]), np.array(v)) # 断言numpy数组的相等性
def test_mem_empty(self):
# Ticket #7
np.empty((1,), dtype=[('x', np.int64)]) # 创建一个空的结构化数组
def test_pickle_transposed(self):
# Ticket #16
a = np.transpose(np.array([[2, 9], [7, 0], [3, 8]])) # 创建一个转置后的二维数组
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): # 遍历pickle协议的范围
with BytesIO() as f: # 使用BytesIO创建一个内存中的文件对象
pickle.dump(a, f, protocol=proto) # 将数组a序列化到文件对象f中
f.seek(0) # 将文件指针移到文件开头
b = pickle.load(f) # 从文件对象f中反序列化数据到变量b
assert_array_equal(a, b) # 断言数组a和反序列化后的数组b相等
def test_dtype_names(self):
# Ticket #35
# Should succeed
np.dtype([(('name', 'label'), np.int32, 3)]) # 创建一个复杂的dtype对象
def test_reduce(self):
# Ticket #40
assert_almost_equal(np.add.reduce([1., .5], dtype=None), 1.5) # 对数组元素进行累加并进行精度比较
def test_zeros_order(self):
# Ticket #43
np.zeros([3], int, 'C') # 创建一个C顺序的全0数组
np.zeros([3], order='C') # 创建一个C顺序的全0数组
np.zeros([3], int, order='C') # 创建一个C顺序的全0数组
def test_asarray_with_order(self):
# Check that nothing is done when order='F' and array C/F-contiguous
a = np.ones(2) # 创建一个全1数组
assert_(a is np.asarray(a, order='F')) # 断言在指定F顺序时不会复制数组
def test_ravel_with_order(self):
# Check that ravel works when order='F' and array C/F-contiguous
a = np.ones(2) # 创建一个全1数组
assert_(not a.ravel('F').flags.owndata) # 断言ravel后的数组在指定F顺序时不拥有数据
def test_sort_bigendian(self):
# Ticket #47
a = np.linspace(0, 10, 11) # 创建一个均匀分布的数组
c = a.astype(np.dtype('<f8')) # 将数组a转换为小端序的浮点数组
c.sort() # 对数组c进行排序
assert_array_almost_equal(c, a) # 断言数组c和原始数组a的近似相等性
def test_negative_nd_indexing(self):
# Ticket #49
c = np.arange(125).reshape((5, 5, 5)) # 创建一个形状为(5, 5, 5)的三维数组
origidx = np.array([-1, 0, 1]) # 创建一个原始索引数组
idx = np.array(origidx) # 复制一份索引数组
c[idx] # 使用索引数组获取数组c的切片
assert_array_equal(idx, origidx) # 断言索引数组和原始索引数组相等
def test_char_dump(self):
# Ticket #50
ca = np.char.array(np.arange(1000, 1010), itemsize=4) # 创建一个字符数组
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1): # 遍历pickle协议的范围
with BytesIO() as f: # 使用BytesIO创建一个内存中的文件对象
pickle.dump(ca, f, protocol=proto) # 将字符数组ca序列化到文件对象f中
f.seek(0) # 将文件指针移到文件开头
ca = np.load(f, allow_pickle=True) # 从文件对象f中加载数据到ca
def test_noncontiguous_fill(self):
# Ticket #58.
a = np.zeros((5, 3)) # 创建一个全0数组
b = a[:, :2,] # 创建一个非连续的切片数组
def rs():
b.shape = (10,) # 修改数组b的形状为(10,)
assert_raises(AttributeError, rs) # 断言在修改形状时会抛出AttributeError异常
def test_bool(self):
# Ticket #60
# 创建一个布尔值的 numpy 数组,传入整数参数 1
np.bool(1) # Should succeed
def test_indexing1(self):
# Ticket #64
# 定义一个复杂的结构化数据类型描述符
descr = [('x', [('y', [('z', 'c16', (2,)),]),]),]
# 定义一个多层嵌套的 Python 元组作为数据缓冲区
buffer = ((([6j, 4j],),),)
# 使用给定的描述符创建一个 numpy 数组 h
h = np.array(buffer, dtype=descr)
# 对结构化数组 h 进行多层索引操作
h['x']['y']['z']
def test_indexing2(self):
# Ticket #65
# 定义一个简单的结构化数据类型描述符
descr = [('x', 'i4', (2,))]
# 定义一个一维数组作为数据缓冲区
buffer = ([3, 2],)
# 使用给定的描述符创建一个 numpy 数组 h
h = np.array(buffer, dtype=descr)
# 对结构化数组 h 进行索引操作
h['x']
def test_round(self):
# Ticket #67
# 创建一个复数 numpy 数组 x
x = np.array([1+2j])
# 使用 assert_almost_equal 断言比较复数数组 x 的幂运算结果
assert_almost_equal(x**(-1), [1/(1+2j)])
def test_scalar_compare(self):
# Trac Ticket #72
# 创建一个字符串类型的 numpy 数组 a
a = np.array(['test', 'auto'])
# 使用 assert_array_equal 断言比较数组 a 中的元素与字符串 'auto' 的相等性
assert_array_equal(a == 'auto', np.array([False, True]))
# 使用 assert_ 断言比较数组 a 中特定索引处的元素
assert_(a[1] == 'auto')
assert_(a[0] != 'auto')
# 创建一个等间隔的浮点数数组 b
b = np.linspace(0, 10, 11)
# 使用 assert_array_equal 断言比较数组 b 与字符串 'auto' 的不等性
assert_array_equal(b != 'auto', np.ones(11, dtype=bool))
# 使用 assert_ 断言比较数组 b 中特定索引处的元素
assert_(b[0] != 'auto')
def test_unicode_swapping(self):
# Ticket #79
# 定义 Unicode 字符串长度 ulen 和 Unicode 字符串值 ucs_value
ulen = 1
ucs_value = '\U0010FFFF'
# 使用给定的 ulen 创建一个多维数组 ua,元素为 Unicode 字符串 ucs_value
ua = np.array([[[ucs_value*ulen]*2]*3]*4, dtype='U%s' % ulen)
# 调用 view 方法以新的字节顺序查看数组 ua 的数据类型
ua.view(ua.dtype.newbyteorder()) # Should succeed.
def test_object_array_fill(self):
# Ticket #86
# 创建一个对象类型的全零数组 x
x = np.zeros(1, 'O')
# 使用 fill 方法填充数组 x 的所有元素为空列表
x.fill([])
def test_mem_dtype_align(self):
# Ticket #93
# 使用 assert_raises 断言捕获 TypeError 异常,因为传入的 dtype 参数不合法
assert_raises(TypeError, np.dtype,
{'names':['a'], 'formats':['foo']}, align=1)
def test_endian_bool_indexing(self):
# Ticket #105
# 创建两个大端和小端格式的浮点数数组 a 和 b
a = np.arange(10., dtype='>f8')
b = np.arange(10., dtype='<f8')
# 使用 where 函数查找满足条件的索引,并使用 assert_array_almost_equal 断言比较结果
xa = np.where((a > 2) & (a < 6))
xb = np.where((b > 2) & (b < 6))
ya = ((a > 2) & (a < 6))
yb = ((b > 2) & (b < 6))
assert_array_almost_equal(xa, ya.nonzero())
assert_array_almost_equal(xb, yb.nonzero())
# 使用 assert_ 断言比较满足条件的数组元素
assert_(np.all(a[ya] > 0.5))
assert_(np.all(b[yb] > 0.5))
def test_endian_where(self):
# GitHub issue #369
# 创建一个大端格式的浮点数数组 net
net = np.zeros(3, dtype='>f4')
net[1] = 0.00458849
net[2] = 0.605202
# 使用 where 函数处理数组 net,使得不满足条件的元素替换为数组最大值
max_net = net.max()
test = np.where(net <= 0., max_net, net)
correct = np.array([ 0.60520202, 0.00458849, 0.60520202])
# 使用 assert_array_almost_equal 断言比较处理后的数组 test 与正确结果 correct
assert_array_almost_equal(test, correct)
def test_endian_recarray(self):
# Ticket #2185
# 定义一个大端格式的记录数组数据类型 dt
dt = np.dtype([
('head', '>u4'),
('data', '>u4', 2),
])
# 创建一个符合数据类型 dt 的记录数组 buf
buf = np.recarray(1, dtype=dt)
buf[0]['head'] = 1
buf[0]['data'][:] = [1, 1]
h = buf[0]['head']
d = buf[0]['data'][0]
buf[0]['head'] = h
buf[0]['data'][0] = d
# 使用 assert_ 断言检查记录数组 buf 的特定字段是否等于预期值
assert_(buf[0]['head'] == 1)
def test_mem_dot(self):
# Ticket #106
# 生成一个形状为 (0, 1) 的随机数组 x
x = np.random.randn(0, 1)
# 生成一个形状为 (10, 1) 的随机数组 y
y = np.random.randn(10, 1)
# 创建一个用于检测错误内存访问的虚拟数组 _z
_z = np.ones(10)
# 创建一个空数组 _dummy,用于描述步幅
_dummy = np.empty((0, 10))
# 使用 as_strided 方法创建一个视图 z,共享 _z 的数据,但使用 _dummy 的形状和步幅
z = np.lib.stride_tricks.as_strided(_z, _dummy.shape, _dummy.strides)
# 计算 x 和 y 的转置的点积,将结果保存到 z
np.dot(x, np.transpose(y), out=z)
# 断言 _z 等于形状为 (10,) 的全 1 数组
assert_equal(_z, np.ones(10))
# 使用内置的 dot 方法同样计算 x 和 y 的点积,结果保存到 z
np._core.multiarray.dot(x, np.transpose(y), out=z)
# 再次断言 _z 等于全 1 的数组
assert_equal(_z, np.ones(10))
def test_arange_endian(self):
# Ticket #111
# 创建一个从 0 到 9 的整数数组 ref
ref = np.arange(10)
# 创建一个数据类型为小端序的浮点型数组 x,包含 0 到 9 的数
x = np.arange(10, dtype='<f8')
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个数据类型为大端序的浮点型数组 x,包含 0 到 9 的数
x = np.arange(10, dtype='>f8')
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
def test_arange_inf_step(self):
# Ticket #113
# 创建一个从 0 到 1,步长为 10 的数组 ref
ref = np.arange(0, 1, 10)
# 创建一个从 0 到 1,步长为无穷大的数组 x
x = np.arange(0, 1, np.inf)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到 1,步长为 -10 的数组 ref
ref = np.arange(0, 1, -10)
# 创建一个从 0 到 1,步长为负无穷大的数组 x
x = np.arange(0, 1, -np.inf)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到 -1,步长为 -10 的数组 ref
ref = np.arange(0, -1, -10)
# 创建一个从 0 到 -1,步长为负无穷大的数组 x
x = np.arange(0, -1, -np.inf)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到 -1,步长为 10 的数组 ref
ref = np.arange(0, -1, 10)
# 创建一个从 0 到 -1,步长为无穷大的数组 x
x = np.arange(0, -1, np.inf)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
def test_arange_underflow_stop_and_step(self):
# Ticket #114
# 获取 np.float64 类型的机器精度信息
finfo = np.finfo(np.float64)
# 创建一个从 0 到机器精度,步长为 2 * 机器精度 的数组 ref
ref = np.arange(0, finfo.eps, 2 * finfo.eps)
# 创建一个从 0 到机器精度,步长为最大浮点数的数组 x
x = np.arange(0, finfo.eps, finfo.max)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到机器精度,步长为 -2 * 机器精度 的数组 ref
ref = np.arange(0, finfo.eps, -2 * finfo.eps)
# 创建一个从 0 到机器精度,步长为负最大浮点数的数组 x
x = np.arange(0, finfo.eps, -finfo.max)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到 -机器精度,步长为 -2 * 机器精度 的数组 ref
ref = np.arange(0, -finfo.eps, -2 * finfo.eps)
# 创建一个从 0 到 -机器精度,步长为负最大浮点数的数组 x
x = np.arange(0, -finfo.eps, -finfo.max)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
# 创建一个从 0 到 -机器精度,步长为 2 * 机器精度 的数组 ref
ref = np.arange(0, -finfo.eps, 2 * finfo.eps)
# 创建一个从 0 到 -机器精度,步长为最大浮点数的数组 x
x = np.arange(0, -finfo.eps, finfo.max)
# 断言两个数组 ref 和 x 相等
assert_array_equal(ref, x)
def test_argmax(self):
# Ticket #119
# 创建一个形状为 (4, 5, 6, 7, 8) 的正态分布随机数组 a
a = np.random.normal(0, 1, (4, 5, 6, 7, 8))
# 对数组 a 沿着各维度进行 argmax 操作,验证是否成功
for i in range(a.ndim):
a.argmax(i) # 应该成功
def test_mem_divmod(self):
# Ticket #126
# 对 0 到 9 的数组进行循环,每次使用 divmod 函数计算一个数和 10 的商和余数
for i in range(10):
divmod(np.array([i])[0], 10)
def test_hstack_invalid_dims(self):
# Ticket #128
# 创建一个形状为 (3, 3) 的数组 x 和一个形状为 (3,) 的数组 y
x = np.arange(9).reshape((3, 3))
y = np.array([0, 0, 0])
# 断言 np.hstack((x, y)) 会引发 ValueError 异常
assert_raises(ValueError, np.hstack, (x, y))
def test_squeeze_type(self):
# Ticket #133
# 创建一个包含单个元素 3 的数组 a 和一个标量 3 的数组 b
a = np.array([3])
b = np.array(3)
# 断言 a.squeeze() 和 b.squeeze() 的类型都是 np.ndarray
assert_(type(a.squeeze()) is np.ndarray)
assert_(type(b.squeeze()) is np.ndarray)
def test_add_identity(self):
# Ticket #143
# 断言 np.add.identity 的值为 0
assert_equal(0, np.add.identity)
def test_numpy_float_python_long_addition(self):
# 检查 numpy 浮点数和 Python 长整型的加法是否能正确执行
a = np.float64(23.) + 2**135
assert_equal(a, 23. + 2**135)
def test_binary_repr_0(self):
# Ticket #151
# 断言 np.binary_repr(0) 的结果为字符串 '0'
assert_equal('0', np.binary_repr(0))
def test_rec_iterate(self):
# Ticket #160
# 定义一个结构化数据类型描述符,包括整数、浮点数和字符串字段
descr = np.dtype([('i', int), ('f', float), ('s', '|S3')])
# 创建一个结构化数组,并初始化数据
x = np.rec.array([(1, 1.1, '1.0'),
(2, 2.2, '2.0')], dtype=descr)
# 将第一个元素转换为普通 Python 元组
x[0].tolist()
# 使用列表推导式遍历结构化数组的第一个元素
[i for i in x[0]]
def test_unicode_string_comparison(self):
# Ticket #190
# 创建一个包含字符串 'hello' 的数组,数据类型为 np.str_
a = np.array('hello', np.str_)
# 创建一个包含字符串 'world' 的数组,数据类型为默认的字符串
b = np.array('world')
# 比较两个字符串数组的元素是否相等
a == b
def test_tobytes_FORTRANORDER_discontiguous(self):
# Fix in r2836
# 创建一个非连续的 Fortran 顺序数组
x = np.array(np.random.rand(3, 3), order='F')[:, :2]
# 将数组展平并转换为字节序列,与从字节序列解析回来的数组进行近似相等的断言
assert_array_almost_equal(x.ravel(), np.frombuffer(x.tobytes()))
def test_flat_assignment(self):
# Correct behaviour of ticket #194
# 创建一个空的 3x1 数组
x = np.empty((3, 1))
# 将数组的 flat 属性设置为从 0 开始的整数序列
x.flat = np.arange(3)
# 断言数组 x 是否近似等于 [[0], [1], [2]]
assert_array_almost_equal(x, [[0], [1], [2]])
# 再次将数组的 flat 属性设置为从 0 开始的浮点数序列
x.flat = np.arange(3, dtype=float)
# 再次断言数组 x 是否近似等于 [[0], [1], [2]]
assert_array_almost_equal(x, [[0], [1], [2]])
def test_broadcast_flat_assignment(self):
# Ticket #194
# 创建一个空的 3x1 数组
x = np.empty((3, 1))
def bfa():
# 尝试将整个数组 x 广播分配为从 0 开始的整数序列
x[:] = np.arange(3)
def bfb():
# 尝试将整个数组 x 广播分配为从 0 开始的浮点数序列
x[:] = np.arange(3, dtype=float)
# 断言 bfa 和 bfb 函数会引发 ValueError 异常
assert_raises(ValueError, bfa)
assert_raises(ValueError, bfb)
@pytest.mark.xfail(IS_WASM, reason="not sure why")
@pytest.mark.parametrize("index",
[np.ones(10, dtype=bool), np.arange(10)],
ids=["boolean-arr-index", "integer-arr-index"])
def test_nonarray_assignment(self, index):
# See also Issue gh-2870, test for non-array assignment
# and equivalent unsafe casted array assignment
# 创建一个包含 0 到 9 的整数数组
a = np.arange(10)
# 使用 pytest.raises 检查将 NaN 分配给数组的非数组索引时是否会引发 ValueError 异常
with pytest.raises(ValueError):
a[index] = np.nan
# 使用 np.errstate 设置无效值警告,并使用 pytest.warns 检查将 NaN 分配给数组时是否会发出 RuntimeWarning 警告
with np.errstate(invalid="warn"):
with pytest.warns(RuntimeWarning, match="invalid value"):
a[index] = np.array(np.nan) # 只会发出警告
def test_unpickle_dtype_with_object(self):
# Implemented in r2840
# 创建一个包含整数、对象和对象类型的结构化数据类型描述符
dt = np.dtype([('x', int), ('y', np.object_), ('z', 'O')])
# 使用 pickle 序列化和反序列化 dt 对象,并断言它们是否相等
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
with BytesIO() as f:
pickle.dump(dt, f, protocol=proto)
f.seek(0)
dt_ = pickle.load(f)
assert_equal(dt, dt_)
def test_mem_array_creation_invalid_specification(self):
# Ticket #196
# 创建一个包含整数和对象的结构化数据类型描述符
dt = np.dtype([('x', int), ('y', np.object_)])
# 使用 assert_raises 检查错误的 np.array 调用是否会引发 ValueError 异常
assert_raises(ValueError, np.array, [1, 'object'], dt)
# 正确的方式是使用正确的结构化数组初始化方法
np.array([(1, 'object')], dt)
def test_recarray_single_element(self):
# Ticket #202
# 创建一个包含整数的数组,并复制它
a = np.array([1, 2, 3], dtype=np.int32)
b = a.copy()
# 使用给定的形状和格式创建一个记录数组,并断言原始数组与复制的数组相等
r = np.rec.array(a, shape=1, formats=['3i4'], names=['d'])
assert_array_equal(a, b)
assert_equal(a, r[0][0])
def test_zero_sized_array_indexing(self):
# Ticket #205
# 创建一个空的 NumPy 数组
tmp = np.array([])
def index_tmp():
# 尝试对空数组进行索引操作
tmp[np.array(10)]
# 确保在索引空数组时引发 IndexError 异常
assert_raises(IndexError, index_tmp)
def test_chararray_rstrip(self):
# Ticket #222
# 创建一个长度为 1 的字符数组,每个元素为长度为 5 的字节串
x = np.char.chararray((1,), 5)
x[0] = b'a '
# 对字符数组进行右侧空白字符的去除操作
x = x.rstrip()
# 确保去除空白字符后的第一个元素为 b'a'
assert_equal(x[0], b'a')
def test_object_array_shape(self):
# Ticket #239
# 测试对象数组的形状
assert_equal(np.array([[1, 2], 3, 4], dtype=object).shape, (3,))
assert_equal(np.array([[1, 2], [3, 4]], dtype=object).shape, (2, 2))
assert_equal(np.array([(1, 2), (3, 4)], dtype=object).shape, (2, 2))
assert_equal(np.array([], dtype=object).shape, (0,))
assert_equal(np.array([[], [], []], dtype=object).shape, (3, 0))
assert_equal(np.array([[3, 4], [5, 6], None], dtype=object).shape, (3,))
def test_mem_around(self):
# Ticket #243
# 测试 np.around 函数在极小值比较时的精度
x = np.zeros((1,))
y = [0]
decimal = 6
np.around(abs(x-y), decimal) <= 10.0**(-decimal)
def test_character_array_strip(self):
# Ticket #246
# 创建一个字符数组,去除每个元素末尾的空白字符
x = np.char.array(("x", "x ", "x "))
for c in x:
assert_equal(c, "x")
def test_lexsort(self):
# Lexsort memory error
# 创建一个包含整数的 NumPy 数组,测试 np.lexsort 的行为
v = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
assert_equal(np.lexsort(v), 0)
def test_lexsort_invalid_sequence(self):
# Issue gh-4123
# 创建一个具有错误行为的序列类,测试 np.lexsort 是否正确处理异常
class BuggySequence:
def __len__(self):
return 4
def __getitem__(self, key):
raise KeyError
assert_raises(KeyError, np.lexsort, BuggySequence())
def test_lexsort_zerolen_custom_strides(self):
# Ticket #14228
# 创建一个空数组,测试其自定义步幅下 np.lexsort 的行为
xs = np.array([], dtype='i8')
assert np.lexsort((xs,)).shape[0] == 0 # Works
xs.strides = (16,)
assert np.lexsort((xs,)).shape[0] == 0 # Was: MemoryError
def test_lexsort_zerolen_custom_strides_2d(self):
xs = np.array([], dtype='i8')
xs.shape = (0, 2)
xs.strides = (16, 16)
assert np.lexsort((xs,), axis=0).shape[0] == 0
xs.shape = (2, 0)
xs.strides = (16, 16)
assert np.lexsort((xs,), axis=0).shape[0] == 2
def test_lexsort_invalid_axis(self):
# 测试 np.lexsort 在给定非法轴参数时是否正确引发异常
assert_raises(AxisError, np.lexsort, (np.arange(1),), axis=2)
assert_raises(AxisError, np.lexsort, (np.array([]),), axis=1)
assert_raises(AxisError, np.lexsort, (np.array(1),), axis=10)
def test_lexsort_zerolen_element(self):
# 创建一个空的 void 类型数组,测试 np.lexsort 对其的行为
dt = np.dtype([]) # a void dtype with no fields
xs = np.empty(4, dt)
assert np.lexsort((xs,)).shape[0] == xs.shape[0]
def test_pickle_py2_bytes_encoding(self):
# 测试在 Python 2 中使用 encoding='bytes' 序列化的数组和标量,在 Python 3 中可以正确反序列化
test_data = [
# (original, py2_pickle)
(
np.str_('\u6f2c'), # 创建一个包含特定 Unicode 字符串的 NumPy 字符串对象
b"cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\nI0\ntp6\nbS',o\\x00\\x00'\np7\ntp8\nRp9\n." # noqa
),
(
np.array([9e123], dtype=np.float64), # 创建一个包含单个浮点数的 NumPy 数组对象
b"cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n(S'f8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'<'\np11\nNNNI-1\nI-1\nI0\ntp12\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np13\ntp14\nb." # noqa
),
(
np.array([(9e123,)], dtype=[('name', float)]), # 创建一个包含结构化数据的 NumPy 数组对象
b"cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n(S'V8'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nN(S'name'\np12\ntp13\n(dp14\ng12\n(g7\n(S'f8'\np15\nI0\nI1\ntp16\nRp17\n(I3\nS'<'\np18\nNNNI-1\nI-1\nI0\ntp19\nbI0\ntp20\nsI8\nI1\nI0\ntp21\nbI00\nS'O\\x81\\xb7Z\\xaa:\\xabY'\np22\ntp23\nb." # noqa
),
]
for original, data in test_data:
result = pickle.loads(data, encoding='bytes') # 使用 bytes 编码解析 pickle 数据
assert_equal(result, original) # 断言反序列化的结果与原始数据一致
if isinstance(result, np.ndarray) and result.dtype.names is not None:
for name in result.dtype.names:
assert_(isinstance(name, str)) # 断言结构化数组的字段名是字符串类型
def test_pickle_dtype(self):
# 测试 pickle.dumps() 对不同协议版本的浮点数序列化
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
pickle.dumps(float, protocol=proto) # 使用不同的协议版本序列化 float 类型对象
def test_swap_real(self):
# Ticket #265
# 测试不同端序下复数类型数组的实部和虚部的最大值
assert_equal(np.arange(4, dtype='>c8').imag.max(), 0.0) # 断言大端序复数数组的虚部最大值为 0.0
assert_equal(np.arange(4, dtype='<c8').imag.max(), 0.0) # 断言小端序复数数组的虚部最大值为 0.0
assert_equal(np.arange(4, dtype='>c8').real.max(), 3.0) # 断言大端序复数数组的实部最大值为 3.0
assert_equal(np.arange(4, dtype='<c8').real.max(), 3.0) # 断言小端序复数数组的实部最大值为 3.0
def test_object_array_from_list(self):
# Ticket #270 (gh-868)
# 测试从包含不同类型元素的列表创建对象数组时的形状
assert_(np.array([1, None, 'A']).shape == (3,)) # 断言从包含整数、None 和字符串的列表创建的对象数组形状为 (3,)
def test_multiple_assign(self):
# Ticket #273
# 测试多个索引赋值操作
a = np.zeros((3, 1), int) # 创建一个形状为 (3, 1) 的全零数组
a[[1, 2]] = 1 # 对索引为 1 和 2 的位置赋值为 1
def test_empty_array_type(self):
# 测试空数组的 dtype
assert_equal(np.array([]).dtype, np.zeros(0).dtype) # 断言空数组的 dtype 与创建长度为 0 的全零数组的 dtype 相同
def test_void_copyswap(self):
# 测试结构化数据类型的字节交换操作
dt = np.dtype([('one', '<i4'), ('two', '<i4')]) # 定义一个结构化数据类型
x = np.array((1, 2), dtype=dt) # 创建一个符合上述数据类型的数组对象
x = x.byteswap() # 执行字节交换操作
assert_(x['one'] > 1 and x['two'] > 2) # 断言交换后的数据满足条件
def test_method_args(self):
# 确保方法和函数具有相同的默认轴关键字和参数
funcs1 = ['argmax', 'argmin', 'sum', 'any', 'all', 'cumsum',
'cumprod', 'prod', 'std', 'var', 'mean',
'round', 'min', 'max', 'argsort', 'sort']
funcs2 = ['compress', 'take', 'repeat']
# 针对 funcs1 中的每个函数进行测试
for func in funcs1:
# 创建一个随机的 8x7 的数组
arr = np.random.rand(8, 7)
arr2 = arr.copy()
# 调用数组对象自身的函数 func
res1 = getattr(arr, func)()
# 调用 numpy 的函数 func,并传入数组 arr2
res2 = getattr(np, func)(arr2)
# 如果 res1 是 None,则将其设置为 arr
if res1 is None:
res1 = arr
# 如果 res1 的 dtype 的种类在 'uib' 中
if res1.dtype.kind in 'uib':
# 使用 assert_ 进行断言,确保所有元素相等
assert_((res1 == res2).all(), func)
else:
# 否则,使用 assert_ 确保元素之间的最大差值小于 1e-8
assert_(abs(res1-res2).max() < 1e-8, func)
# 针对 funcs2 中的每个函数进行测试
for func in funcs2:
arr1 = np.random.rand(8, 7)
arr2 = np.random.rand(8, 7)
res1 = None
if func == 'compress':
# 如果 func 是 'compress',则将 arr1 展平,并调用 compress 函数
arr1 = arr1.ravel()
res1 = getattr(arr2, func)(arr1)
else:
# 否则,将 arr2 转换为整数类型并展平
arr2 = (15*arr2).astype(int).ravel()
# 如果 res1 是 None,则调用 arr1 对象的 func 函数
if res1 is None:
res1 = getattr(arr1, func)(arr2)
# 调用 numpy 的 func 函数,传入 arr1 和 arr2
res2 = getattr(np, func)(arr1, arr2)
# 使用 assert_ 确保元素之间的最大差值小于 1e-8
assert_(abs(res1-res2).max() < 1e-8, func)
def test_mem_lexsort_strings(self):
# Ticket #298
# 创建一个字符串列表 lst
lst = ['abc', 'cde', 'fgh']
# 使用 lexsort 函数对 lst 进行排序
np.lexsort((lst,))
def test_fancy_index(self):
# Ticket #302
# 创建一个 numpy 数组 x,使用 fancy indexing
x = np.array([1, 2])[np.array([0])]
# 使用 assert_equal 确保数组 x 的形状为 (1,)
assert_equal(x.shape, (1,))
def test_recarray_copy(self):
# Ticket #312
# 定义一个结构化数据类型 dt
dt = [('x', np.int16), ('y', np.float64)]
# 创建一个 ndarray ra,并指定数据类型为 dt
ra = np.array([(1, 2.3)], dtype=dt)
# 使用 rec.array 函数将 ra 转换为 recarray 类型 rb
rb = np.rec.array(ra, dtype=dt)
# 修改 rb 中 'x' 字段的值为 2.0
rb['x'] = 2.
# 使用 assert_ 确保 ra 中 'x' 字段的值不等于 rb 中对应字段的值
assert_(ra['x'] != rb['x'])
def test_rec_fromarray(self):
# Ticket #322
# 创建三个数组 x1, x2, x3
x1 = np.array([[1, 2], [3, 4], [5, 6]])
x2 = np.array(['a', 'dd', 'xyz'])
x3 = np.array([1.1, 2, 3])
# 使用 rec.fromarrays 函数将 x1, x2, x3 转换为结构化数组
np.rec.fromarrays([x1, x2, x3], formats="(2,)i4,S3,f8")
def test_object_array_assign(self):
# 创建一个形状为 (2, 2) 的对象数组 x
x = np.empty((2, 2), object)
# 将 x 中第 2 个元素赋值为元组 (1, 2, 3)
x.flat[2] = (1, 2, 3)
# 使用 assert_equal 确保 x 中第 2 个元素的值为 (1, 2, 3)
assert_equal(x.flat[2], (1, 2, 3))
def test_ndmin_float64(self):
# Ticket #324
# 创建一个 dtype 为 np.float64 的数组 x
x = np.array([1, 2, 3], dtype=np.float64)
# 使用 assert_equal 确保将 x 转换为 dtype 为 np.float32 且 ndmin 为 2 后的数组维度为 2
assert_equal(np.array(x, dtype=np.float32, ndmin=2).ndim, 2)
# 使用 assert_equal 确保将 x 转换为 dtype 为 np.float64 且 ndmin 为 2 后的数组维度为 2
assert_equal(np.array(x, dtype=np.float64, ndmin=2).ndim, 2)
def test_ndmin_order(self):
# Issue #465 and related checks
# 使用 assert_ 确保将数组 [1, 2] 转换为 ndmin 为 3 且 order 为 'C' 后是 C 连续的
assert_(np.array([1, 2], order='C', ndmin=3).flags.c_contiguous)
# 使用 assert_ 确保将数组 [1, 2] 转换为 ndmin 为 3 且 order 为 'F' 后是 Fortran 连续的
assert_(np.array([1, 2], order='F', ndmin=3).flags.f_contiguous)
# 使用 assert_ 确保将 order 为 'F' 的全 1 数组转换为 ndmin 为 3 后仍然是 Fortran 连续的
assert_(np.array(np.ones((2, 2), order='F'), ndmin=3).flags.f_contiguous)
# 使用 assert_ 确保将 order 为 'C' 的全 1 数组转换为 ndmin 为 3 后仍然是 C 连续的
assert_(np.array(np.ones((2, 2), order='C'), ndmin=3).flags.c_contiguous)
def test_mem_axis_minimization(self):
# Ticket #327
# 创建一个包含 0 到 4 的数组 data
data = np.arange(5)
# 计算 data 与其外积的和
data = np.add.outer(data, data)
# 定义一个测试函数,测试 np.float64 类型的虚部属性
def test_mem_float_imag(self):
# 根据 Ticket #330 的要求,访问 np.float64 对象的 imag 属性
np.float64(1.0).imag
# 定义一个测试函数,测试 np.dtype('i4') 和 np.dtype(('i4', ())) 的相等性
def test_dtype_tuple(self):
# 根据 Ticket #334 的要求,断言两个数据类型对象是否相等
assert_(np.dtype('i4') == np.dtype(('i4', ())))
# 定义一个测试函数,测试使用带有空元组的 dtype 的创建
def test_dtype_posttuple(self):
# 根据 Ticket #335 的要求,创建一个包含空元组的数据类型对象
np.dtype([('col1', '()i4')])
# 定义一个测试函数,测试字符数组和字节串之间的比较
def test_numeric_carray_compare(self):
# 根据 Ticket #341 的要求,断言字符数组和字节串的相等性
assert_equal(np.array(['X'], 'c'), b'X')
# 定义一个测试函数,测试字符串数组的大小设置是否会引发 ValueError
def test_string_array_size(self):
# 根据 Ticket #342 的要求,断言创建特定大小的字符串数组时是否会引发异常
assert_raises(ValueError,
np.array, [['X'], ['X', 'X', 'X']], '|S1')
# 定义一个测试函数,测试 dtype 对象的字符串表示是否一致
def test_dtype_repr(self):
# 根据 Ticket #344 的要求,创建两个相似的 dtype 对象并断言它们的字符串表示相同
dt1 = np.dtype(('uint32', 2))
dt2 = np.dtype(('uint32', (2,)))
assert_equal(dt1.__repr__(), dt2.__repr__())
# 定义一个测试函数,测试reshape函数中的顺序参数(order)是否正常工作
def test_reshape_order(self):
# 确保 reshape 函数的 order 参数正常工作
a = np.arange(6).reshape(2, 3, order='F')
assert_equal(a, [[0, 2, 4], [1, 3, 5]])
a = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
b = a[:, 1]
assert_equal(b.reshape(2, 2, order='F'), [[2, 6], [4, 8]])
# 定义一个测试函数,测试零步幅数组的重塑
def test_reshape_zero_strides(self):
# 根据 Issue #380 的要求,测试零步幅数组的 reshape 操作
a = np.ones(1)
a = np.lib.stride_tricks.as_strided(a, shape=(5,), strides=(0,))
assert_(a.reshape(5, 1).strides[0] == 0)
# 定义一个测试函数,测试零大小数组的重塑
def test_reshape_zero_size(self):
# 根据 GitHub Issue #2700 的要求,测试零大小数组的形状设置
a = np.ones((0, 2))
a.shape = (-1, 2)
# 定义一个测试函数,测试新形状中末尾为1的步幅
def test_reshape_trailing_ones_strides(self):
# 根据 GitHub issue gh-2949 的要求,测试新形状中末尾为1的步幅设置是否正确
a = np.zeros(12, dtype=np.int32)[::2] # not contiguous
strides_c = (16, 8, 8, 8)
strides_f = (8, 24, 48, 48)
assert_equal(a.reshape(3, 2, 1, 1).strides, strides_c)
assert_equal(a.reshape(3, 2, 1, 1, order='F').strides, strides_f)
assert_equal(np.array(0, dtype=np.int32).reshape(1, 1).strides, (4, 4))
# 定义一个测试函数,测试重复数组操作
def test_repeat_discont(self):
# 根据 Ticket #352 的要求,测试数组的重复操作
a = np.arange(12).reshape(4, 3)[:, 2]
assert_equal(a.repeat(3), [2, 2, 2, 5, 5, 5, 8, 8, 8, 11, 11, 11])
# 定义一个测试函数,测试数组索引操作
def test_array_index(self):
# 确保在这种情况下不调用优化
a = np.array([1, 2, 3])
a2 = np.array([[1, 2, 3]])
assert_equal(a[np.where(a == 3)], a2[np.where(a2 == 3)])
# 定义一个测试函数,测试对象数组的 argmax 方法
def test_object_argmax(self):
# 根据 Ticket #369 的要求,测试对象数组的 argmax 方法
a = np.array([1, 2, 3], dtype=object)
assert_(a.argmax() == 2)
# 定义一个测试函数,测试记录数组的字段处理
def test_recarray_fields(self):
# 根据 Ticket #372 的要求,测试记录数组的字段是否正确处理
dt0 = np.dtype([('f0', 'i4'), ('f1', 'i4')])
dt1 = np.dtype([('f0', 'i8'), ('f1', 'i8')])
for a in [np.array([(1, 2), (3, 4)], "i4,i4"),
np.rec.array([(1, 2), (3, 4)], "i4,i4"),
np.rec.array([(1, 2), (3, 4)]),
np.rec.fromarrays([(1, 2), (3, 4)], "i4,i4"),
np.rec.fromarrays([(1, 2), (3, 4)])]:
assert_(a.dtype in [dt0, dt1])
def test_random_shuffle(self):
# Ticket #374
# 创建一个包含0到4的数组,并重塑成5行1列的形状
a = np.arange(5).reshape((5, 1))
# 复制数组a到数组b
b = a.copy()
# 对数组b进行随机重排列
np.random.shuffle(b)
# 断言对数组b按列排序后与数组a相等
assert_equal(np.sort(b, axis=0), a)
def test_refcount_vdot(self):
# Changeset #3443
# 调用_assert_valid_refcount函数,验证np.vdot的引用计数
_assert_valid_refcount(np.vdot)
def test_startswith(self):
# 使用char.array创建一个包含字符串'Hi'和'There'的数组ca
ca = np.char.array(['Hi', 'There'])
# 断言数组ca中的每个字符串是否以'H'开头
assert_equal(ca.startswith('H'), [True, False])
def test_noncommutative_reduce_accumulate(self):
# Ticket #413
# 创建一个0到4的数组tosubtract
tosubtract = np.arange(5)
# 创建一个包含[2.0, 0.5, 0.25]的数组todivide
todivide = np.array([2.0, 0.5, 0.25])
# 断言对tosubtract数组使用subtract.reduce操作后的结果为-10
assert_equal(np.subtract.reduce(tosubtract), -10)
# 断言对todivide数组使用divide.reduce操作后的结果为16.0
assert_equal(np.divide.reduce(todivide), 16.0)
# 断言对tosubtract数组使用subtract.accumulate操作后的结果与给定数组相等
assert_array_equal(np.subtract.accumulate(tosubtract),
np.array([0, -1, -3, -6, -10]))
# 断言对todivide数组使用divide.accumulate操作后的结果与给定数组相等
assert_array_equal(np.divide.accumulate(todivide),
np.array([2., 4., 16.]))
def test_convolve_empty(self):
# Convolve操作对空输入数组应该引发ValueError异常
assert_raises(ValueError, np.convolve, [], [1])
assert_raises(ValueError, np.convolve, [1], [])
def test_multidim_byteswap(self):
# Ticket #449
# 创建一个复合dtype数组r,包含元组和数组,使用i2表示短整型
r = np.array([(1, (0, 1, 2))], dtype="i2,3i2")
# 断言对数组r调用byteswap()后的结果与给定数组相等
assert_array_equal(r.byteswap(),
np.array([(256, (0, 256, 512))], r.dtype))
def test_string_NULL(self):
# Changeset 3557
# 创建一个包含特殊字符的字符串数组,并获取其item
assert_equal(np.array("a\x00\x0b\x0c\x00").item(),
'a\x00\x0b\x0c')
def test_junk_in_string_fields_of_recarray(self):
# Ticket #483
# 创建一个结构化数组r,包含一个字符串字段var1,初始化为'abc'
r = np.array([[b'abc']], dtype=[('var1', '|S20')])
# 断言数组r中var1字段的第一个元素的第一个字符与b'abc'相等
assert_(asbytes(r['var1'][0][0]) == b'abc')
def test_take_output(self):
# 确保'take'函数遵循输出参数的规范
# 创建一个3行4列的数组x
x = np.arange(12).reshape((3, 4))
# 从数组x中取出第0列和第2列,存入数组a
a = np.take(x, [0, 2], axis=1)
# 创建一个与a形状相同的全零数组b
b = np.zeros_like(a)
# 将数组x的第0列和第2列取出,并存入数组b
np.take(x, [0, 2], axis=1, out=b)
# 断言数组a与数组b相等
assert_array_equal(a, b)
def test_take_object_fail(self):
# Issue gh-3001
# 创建一个包含浮点数对象的数组a
d = 123.
a = np.array([d, 1], dtype=object)
# 如果支持引用计数,则获取对象d的引用计数
if HAS_REFCOUNT:
ref_d = sys.getrefcount(d)
try:
# 尝试对数组a进行索引操作,应该引发IndexError异常
a.take([0, 100])
except IndexError:
pass
# 如果支持引用计数,则断言引用计数是否与预期相等
if HAS_REFCOUNT:
assert_(ref_d == sys.getrefcount(d))
def test_array_str_64bit(self):
# Ticket #501
# 创建一个包含NaN的双精度浮点数组s
s = np.array([1, np.nan], dtype=np.float64)
# 使用with语句设置错误状态为全部抛出,并调用np.array_str(s)
# 应该成功执行
with np.errstate(all='raise'):
np.array_str(s)
def test_frompyfunc_endian(self):
# Ticket #503
# 导入radians函数,从度到弧度转换
from math import radians
# 使用frompyfunc创建uradians函数,处理单个输入参数和输出结果
uradians = np.frompyfunc(radians, 1, 1)
# 创建一个大端序和小端序的双精度浮点数组
big_endian = np.array([83.4, 83.5], dtype='>f8')
little_endian = np.array([83.4, 83.5], dtype='<f8')
# 断言经uradians函数处理后的大端序和小端序数组结果几乎相等
assert_almost_equal(uradians(big_endian).astype(float),
uradians(little_endian).astype(float))
def test_mem_string_arr(self):
# Ticket #514
# 创建一个长为 40 的字符串 s
s = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
# 创建一个空列表 t,然后将字符串 s 横向堆叠到 t 中
t = []
np.hstack((t, s))
def test_arr_transpose(self):
# Ticket #516
# 创建一个形状为 (2, 2, ..., 2) 的随机数组 x(总共 16 个维度为 2)
x = np.random.rand(*(2,)*16)
# 将数组 x 沿着给定轴序列(0 到 15)进行转置操作,应该成功
x.transpose(list(range(16))) # 应该成功
def test_string_mergesort(self):
# Ticket #540
# 创建一个长度为 32 的字符串数组 x,所有元素为 'a'
x = np.array(['a']*32)
# 对数组 x 进行 mergesort 排序,检查排序后的索引是否与 np.arange(32) 相等
assert_array_equal(x.argsort(kind='m'), np.arange(32))
def test_argmax_byteorder(self):
# Ticket #546
# 创建一个包含 0.0, 1.0, 2.0 的大端序浮点数数组 a
a = np.arange(3, dtype='>f')
# 断言数组 a 中最大元素的索引等于最大元素的值
assert_(a[a.argmax()] == a.max())
def test_rand_seed(self):
# Ticket #555
# 循环设置 np.random 的种子为 0 到 3
for l in np.arange(4):
np.random.seed(l)
def test_mem_deallocation_leak(self):
# Ticket #562
# 创建一个长度为 5 的浮点数零数组 a
a = np.zeros(5, dtype=float)
# 创建一个与数组 a 具有相同数据类型的数组 b
b = np.array(a, dtype=float)
# 删除数组 a 和 b
del a, b
def test_mem_on_invalid_dtype(self):
"Ticket #583"
# 断言使用 np.fromiter 从给定的嵌套列表中创建数组时会引发 ValueError
assert_raises(ValueError, np.fromiter, [['12', ''], ['13', '']], str)
def test_dot_negative_stride(self):
# Ticket #588
# 创建两个数组 x 和 y
x = np.array([[1, 5, 25, 125., 625]])
y = np.array([[20.], [160.], [640.], [1280.], [1024.]])
# 创建 y 的反向切片拷贝 z,以及 y 的反向切片视图 y2
z = y[::-1].copy()
y2 = y[::-1]
# 断言使用反向切片拷贝 z 或反向切片视图 y2 计算的点积与 x 的点积结果相等
assert_equal(np.dot(x, z), np.dot(x, y2))
def test_object_casting(self):
# This used to trigger the object-type version of
# the bitwise_or operation, because float64 -> object
# casting succeeds
# 定义一个函数 rs,用于进行 object 类型转换的测试
def rs():
# 创建全为 1 的数组 x 和全为 0 的数组 y,然后对 x 进行按位或运算
x = np.ones([484, 286])
y = np.zeros([484, 286])
x |= y
# 断言调用 rs 函数会触发 TypeError
assert_raises(TypeError, rs)
def test_unicode_scalar(self):
# Ticket #600
# 创建一个 Unicode 字符串数组 x,指定数据类型为 'U6'
x = np.array(["DROND", "DROND1"], dtype="U6")
# 获取数组 x 的第二个元素 el,并使用 pickle 序列化和反序列化 el,验证是否相等
el = x[1]
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
new = pickle.loads(pickle.dumps(el, protocol=proto))
assert_equal(new, el)
def test_arange_non_native_dtype(self):
# Ticket #616
# 针对两种数据类型 '>f4' 和 '<f4' 进行测试
for T in ('>f4', '<f4'):
# 创建指定数据类型的 arange 数组,并断言其数据类型与预期相等
dt = np.dtype(T)
assert_equal(np.arange(0, dtype=dt).dtype, dt)
assert_equal(np.arange(0.5, dtype=dt).dtype, dt)
assert_equal(np.arange(5, dtype=dt).dtype, dt)
def test_bool_flat_indexing_invalid_nr_elements(self):
# 创建一个全为 1 的浮点数数组 s 和一个长度为 1 的浮点数数组 x
s = np.ones(10, dtype=float)
x = np.array((15,), dtype=float)
def ia(x, s, v):
# 定义一个函数 ia,用于测试对 x 进行索引操作时是否会引发 IndexError
x[(s > 0)] = v
# 断言调用 ia 函数时会引发 IndexError,因为索引数组的长度不匹配
assert_raises(IndexError, ia, x, s, np.zeros(9, dtype=float))
assert_raises(IndexError, ia, x, s, np.zeros(11, dtype=float))
# 老的特殊情况(不同的代码路径):同样断言调用 ia 函数时会引发 ValueError
assert_raises(ValueError, ia, x.flat, s, np.zeros(9, dtype=float))
assert_raises(ValueError, ia, x.flat, s, np.zeros(11, dtype=float))
def test_mem_scalar_indexing(self):
# Ticket #603
# 创建一个包含单个浮点数 0 的数组 x
x = np.array([0], dtype=float)
# 创建一个包含单个整数 0 的索引数组 index
index = np.array(0, dtype=np.int32)
# 对数组 x 进行标量索引操作
x[index]
def test_binary_repr_0_width(self):
# 断言调用 np.binary_repr(0, width=3) 的结果为 '000'
assert_equal(np.binary_repr(0, width=3), '000')
def test_fromstring(self):
# 使用 np.fromstring() 函数将字符串 "12:09:09" 按照指定的 int 类型和分隔符 ":" 解析为数组
assert_equal(np.fromstring("12:09:09", dtype=int, sep=":"),
[12, 9, 9])
def test_searchsorted_variable_length(self):
# 创建包含字符串数组的 numpy 数组 x 和 y
x = np.array(['a', 'aa', 'b'])
y = np.array(['d', 'e'])
# 使用 x.searchsorted(y) 查找 y 中每个元素在 x 中的插入位置,并断言结果为 [3, 3]
assert_equal(x.searchsorted(y), [3, 3])
def test_string_argsort_with_zeros(self):
# 检查包含零的字符串数组的 argsort 行为
x = np.frombuffer(b"\x00\x02\x00\x01", dtype="|S2")
# 使用 'm' 类型的 argsort 对 x 进行排序,并断言结果为 [1, 0]
assert_array_equal(x.argsort(kind='m'), np.array([1, 0]))
# 使用 'q' 类型的 argsort 对 x 进行排序,并断言结果为 [1, 0]
assert_array_equal(x.argsort(kind='q'), np.array([1, 0]))
def test_string_sort_with_zeros(self):
# 检查包含零的字符串数组的排序行为
x = np.frombuffer(b"\x00\x02\x00\x01", dtype="|S2")
y = np.frombuffer(b"\x00\x01\x00\x02", dtype="|S2")
# 使用 'q' 类型的排序对 x 进行排序,并断言结果等于数组 y
assert_array_equal(np.sort(x, kind="q"), y)
def test_copy_detection_zero_dim(self):
# Ticket #658
# 创建一个零维度的索引数组,并将其转置并重塑为三列形式
np.indices((0, 3, 4)).T.reshape(-1, 3)
def test_flat_byteorder(self):
# Ticket #657
# 创建一个包含 0 到 9 的 numpy 数组 x
x = np.arange(10)
# 断言将 x 转换为大端字节顺序后与 x 转换为小端字节顺序后的扁平化数组内容相等
assert_array_equal(x.astype('>i4'), x.astype('<i4').flat[:])
# 断言 x 转换为大端字节顺序后的扁平化数组内容与 x 转换为小端字节顺序相等
assert_array_equal(x.astype('>i4').flat[:], x.astype('<i4'))
def test_sign_bit(self):
# 创建一个包含 0, -0.0, 0 的 numpy 数组 x
x = np.array([0, -0.0, 0])
# 断言 np.abs(x) 的结果为字符串 '[0. 0. 0.]'
assert_equal(str(np.abs(x)), '[0. 0. 0.]')
def test_flat_index_byteswap(self):
# Ticket #658
# 对每种数据类型(小端和大端)的数组进行测试
for dt in (np.dtype('<i4'), np.dtype('>i4')):
# 创建一个包含 [-1, 0, 1] 的指定数据类型的 numpy 数组 x
x = np.array([-1, 0, 1], dtype=dt)
# 断言 x 的扁平化数组的第一个元素的数据类型与 x 的第一个元素的数据类型相等
assert_equal(x.flat[0].dtype, x[0].dtype)
def test_copy_detection_corner_case(self):
# Ticket #658
# 创建一个零维度的索引数组,并将其转置并重塑为三列形式
np.indices((0, 3, 4)).T.reshape(-1, 3)
# 测试对象数组的引用计数
def test_object_array_refcounting(self):
# 检查系统是否支持获取引用计数的方法,若不支持则退出
if not hasattr(sys, 'getrefcount'):
return
# 注意:以下内容可能仅适用于 CPython
# 获取引用计数函数的引用
cnt = sys.getrefcount
# 创建三个独立的对象
a = object()
b = object()
c = object()
# 记录各对象的初始引用计数
cnt0_a = cnt(a)
cnt0_b = cnt(b)
cnt0_c = cnt(c)
# -- 0d -> 1-d 广播切片赋值
# 创建一个元素类型为对象的长度为 5 的全零数组
arr = np.zeros(5, dtype=np.object_)
# 对整个数组赋值为对象 a
arr[:] = a
# 断言对象 a 的引用计数增加了 5
assert_equal(cnt(a), cnt0_a + 5)
# 对整个数组赋值为对象 b
arr[:] = b
# 断言对象 a 的引用计数恢复到原始值
assert_equal(cnt(a), cnt0_a)
# 断言对象 b 的引用计数增加了 5
assert_equal(cnt(b), cnt0_b + 5)
# 对数组的前两个元素赋值为对象 c
arr[:2] = c
# 断言对象 b 的引用计数减少了 2
assert_equal(cnt(b), cnt0_b + 3)
# 断言对象 c 的引用计数增加了 2
assert_equal(cnt(c), cnt0_c + 2)
# 删除数组对象
del arr
# -- 1-d -> 2-d 广播切片赋值
# 创建一个形状为 (5, 2) 的对象类型全零数组
arr = np.zeros((5, 2), dtype=np.object_)
# 创建一个形状为 (2,) 的对象类型全零数组
arr0 = np.zeros(2, dtype=np.object_)
# 将 arr0 的第一个元素赋值为对象 a
arr0[0] = a
# 断言对象 a 的引用计数增加了 1
assert_(cnt(a) == cnt0_a + 1)
# 将 arr0 的第二个元素赋值为对象 b
arr0[1] = b
# 断言对象 b 的引用计数增加了 1
assert_(cnt(b) == cnt0_b + 1)
# 对整个二维数组 arr 赋值为二维数组 arr0
arr[:, :] = arr0
# 断言对象 a 和对象 b 的引用计数都增加了 5
assert_(cnt(a) == cnt0_a + 6)
assert_(cnt(b) == cnt0_b + 6)
# 对二维数组 arr 的第一列赋值为 None
arr[:, 0] = None
# 断言对象 a 的引用计数减少了 5
assert_(cnt(a) == cnt0_a + 1)
# 删除数组对象及其引用的对象
del arr, arr0
# -- 2-d 复制和展平
# 创建一个形状为 (5, 2) 的对象类型全零数组
arr = np.zeros((5, 2), dtype=np.object_)
# 对数组 arr 的第一列赋值为对象 a
arr[:, 0] = a
# 对数组 arr 的第二列赋值为对象 b
arr[:, 1] = b
# 断言对象 a 和对象 b 的引用计数都增加了 5
assert_(cnt(a) == cnt0_a + 5)
assert_(cnt(b) == cnt0_b + 5)
# 对数组 arr 进行深拷贝,arr2 指向新的对象
arr2 = arr.copy()
# 断言对象 a 和对象 b 的引用计数都增加了 5
assert_(cnt(a) == cnt0_a + 10)
assert_(cnt(b) == cnt0_b + 10)
# 将 arr 的第一列展平赋值给 arr2,arr2 指向新的对象
arr2 = arr[:, 0].copy()
# 断言对象 a 的引用计数增加了 5,对象 b 的引用计数不变
assert_(cnt(a) == cnt0_a + 10)
assert_(cnt(b) == cnt0_b + 5)
# 将数组 arr 展平赋值给 arr2,arr2 指向新的对象
arr2 = arr.flatten()
# 断言对象 a 和对象 b 的引用计数都增加了 5
assert_(cnt(a) == cnt0_a + 10)
assert_(cnt(b) == cnt0_b + 10)
# 删除数组对象及其引用的对象
del arr, arr2
# -- concatenate, repeat, take, choose
# 创建一个形状为 (5, 1) 的对象类型全零数组 arr1
arr1 = np.zeros((5, 1), dtype=np.object_)
# 创建一个形状为 (5, 1) 的对象类型全零数组 arr2
arr2 = np.zeros((5, 1), dtype=np.object_)
# 将 arr1 整体赋值为对象 a
arr1[...] = a
# 将 arr2 整体赋值为对象 b
arr2[...] = b
# 断言对象 a 和对象 b 的引用计数都增加了 5
assert_(cnt(a) == cnt0_a + 5)
assert_(cnt(b) == cnt0_b + 5)
# 将 arr1 和 arr2 拼接为 tmp
tmp = np.concatenate((arr1, arr2))
# 断言对象 a 和对象 b 的引用计数都增加了 5 + 5
assert_(cnt(a) == cnt0_a + 5 + 5)
assert_(cnt(b) == cnt0_b + 5 + 5)
# 将 arr1 沿指定轴重复 3 次赋值给 tmp
tmp = arr1.repeat(3, axis=0)
# 断言对象 a 的引用计数增加了 5 + 3*5
assert_(cnt(a) == cnt0_a + 5 + 3*5)
# 从 arr1 中按给定索引数组取元素赋值给 tmp
tmp = arr1.take([1, 2, 3], axis=0)
# 断言对象 a 的引用计数增加了 5 + 3
assert_(cnt(a) == cnt0_a + 5 + 3)
# 创建一个 (5, 1) 的整数数组 x
x = np.array([[0], [1], [0], [1], [1]], int)
# 根据 x 的值从 arr1 或 arr2 中选择赋值给 tmp
tmp = x.choose(arr1, arr2)
# 断言对象 a 的引用计数增加了 5 + 2,对象 b 的引用计数增加了 5 + 3
assert_(cnt(a) == cnt0_a + 5 + 2)
assert_(cnt(b) == cnt0_b + 5 + 3)
# 删除临时变量 tmp 避免 pyflakes 的未使用变量警告
# 测试自定义浮点数转换为数组的方法
def test_mem_custom_float_to_array(self):
# Ticket 702
# 定义一个自定义类 MyFloat
class MyFloat:
# 类方法 __float__ 返回浮点数 1.0
def __float__(self):
return 1.0
# 创建一个至少为 1 维的数组,包含 MyFloat() 的实例
tmp = np.atleast_1d([MyFloat()])
# 将数组元素类型转换为浮点数,应该成功
tmp.astype(float) # 应该成功
def test_object_array_refcount_self_assign(self):
# 测试用例:测试对象数组的自我赋值和引用计数
# Ticket #711
# 定义一个受影响的对象类 VictimObject
class VictimObject:
deleted = False
# 析构函数,设置 deleted 标志为 True
def __del__(self):
self.deleted = True
# 创建一个 VictimObject 的实例 d
d = VictimObject()
# 创建一个包含 5 个元素的对象数组,初始化为零
arr = np.zeros(5, dtype=np.object_)
# 将数组 arr 的所有元素赋值为对象 d
arr[:] = d
# 删除对象 d
del d
# 再次将数组 arr 的所有元素赋值为数组本身,这可能导致 'd' 的引用计数在此处变为零
arr[:] = arr
# 断言:数组 arr 的第一个元素的 deleted 标志为 False
assert_(not arr[0].deleted)
# 再次将数组 arr 的所有元素赋值为数组本身,尝试再次诱导段错误...
arr[:] = arr
# 断言:数组 arr 的第一个元素的 deleted 标志为 False
assert_(not arr[0].deleted)
def test_mem_fromiter_invalid_dtype_string(self):
# 测试用例:测试从迭代器创建数组时使用非法字符串类型的 dtype
# Ticket #712
# 列表 x
x = [1, 2, 3]
# 断言:使用 dtype='S' 从迭代器 [xi for xi in x] 创建数组将引发 ValueError
assert_raises(ValueError, np.fromiter, [xi for xi in x], dtype='S')
def test_reduce_big_object_array(self):
# 测试用例:测试减少大型对象数组
# Ticket #713
# 设置新的缓冲区大小
oldsize = np.setbufsize(10*16)
# 创建一个包含 161 个 None 元素的对象数组 a
a = np.array([None]*161, object)
# 断言:数组 a 中没有任何元素为真值
assert_(not np.any(a))
# 恢复旧的缓冲区大小
np.setbufsize(oldsize)
def test_mem_0d_array_index(self):
# 测试用例:测试 0 维数组的索引
# Ticket #714
# 创建一个包含 10 个零元素的 1 维数组,然后使用索引为 0 的 0 维数组
np.zeros(10)[np.array(0)]
def test_nonnative_endian_fill(self):
# 测试用例:测试非本机字节序数组的填充
# 如果 sys.byteorder 为 'little',则创建大端序 int32 类型的 dtype
# 否则,创建小端序 int32 类型的 dtype
if sys.byteorder == 'little':
dtype = np.dtype('>i4')
else:
dtype = np.dtype('<i4')
# 创建一个形状为 [1],类型为 dtype 的空数组 x
x = np.empty([1], dtype=dtype)
# 使用值 1 填充数组 x
x.fill(1)
# 断言:数组 x 应该等于 [1],类型为 dtype
assert_equal(x, np.array([1], dtype=dtype))
def test_dot_alignment_sse2(self):
# 测试用例:测试 dot 函数在 SSE2 下的对齐性
# Ticket #551, changeset r5140
# 创建一个形状为 (30, 40) 的零数组 x
x = np.zeros((30, 40))
# 对于范围在 2 到 pickle.HIGHEST_PROTOCOL+1 的协议号 proto
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
# 通过 pickle 序列化和反序列化数组 x 得到数组 y
y = pickle.loads(pickle.dumps(x, protocol=proto))
# 现在,数组 y 通常不会在 8 字节边界上对齐
# 创建一个形状为 (1, y.shape[0]),值全为 1 的数组 z
z = np.ones((1, y.shape[0]))
# 这应该不会引发段错误:计算 z 和 y 的点积
np.dot(z, y)
def test_astype_copy(self):
# 测试用例:测试 astype 方法的复制
# Ticket #788, changeset r5155
# 从文件 'data/astype_copy.pkl' 加载数据
data_dir = path.join(path.dirname(__file__), 'data')
filename = path.join(data_dir, "astype_copy.pkl")
with open(filename, 'rb') as f:
xp = pickle.load(f, encoding='latin1')
# 将 xp 转换为 float64 类型的数组 xpd
xpd = xp.astype(np.float64)
# 断言:xp 和 xpd 的数据地址不相同
assert_((xp.__array_interface__['data'][0] !=
xpd.__array_interface__['data'][0]))
def test_compress_small_type(self):
# 测试用例:测试 compress 方法对小型类型的处理
# Ticket #789, changeset 5217.
# 如果无法安全转换,使用 out 参数调用 compress 方法会导致段错误
import numpy as np
a = np.array([[1, 2], [3, 4]])
b = np.zeros((2, 1), dtype=np.single)
try:
# 尝试使用 out 参数调用 compress 方法,如果无法安全转换应该引发 TypeError
a.compress([True, False], axis=1, out=b)
raise AssertionError("compress with an out which cannot be "
"safely casted should not return "
"successfully")
except TypeError:
pass
def test_attributes(self):
# 定义一个继承自 np.ndarray 的子类 TestArray
class TestArray(np.ndarray):
# __new__ 方法用于创建新的实例
def __new__(cls, data, info):
# 创建一个普通的 numpy 数组
result = np.array(data)
# 将其视图转换为 TestArray 类型的实例
result = result.view(cls)
# 设置额外的属性 info
result.info = info
return result
# __array_finalize__ 方法用于在实例化之后进行属性设置
def __array_finalize__(self, obj):
# 如果传入的 obj 有 info 属性,则继承该属性
self.info = getattr(obj, 'info', '')
# 创建一个 TestArray 类型的实例 dat
dat = TestArray([[1, 2, 3, 4], [5, 6, 7, 8]], 'jubba')
# 断言 dat 的 info 属性为 'jubba'
assert_(dat.info == 'jubba')
# 调整 dat 的尺寸为 (4, 2),并保持 info 属性不变
dat.resize((4, 2))
assert_(dat.info == 'jubba')
# 对 dat 进行排序,并保持 info 属性不变
dat.sort()
assert_(dat.info == 'jubba')
# 使用数值 2 填充 dat,并保持 info 属性不变
dat.fill(2)
assert_(dat.info == 'jubba')
# 在索引位置 [2, 3, 4] 处放置数值 [6, 3, 4],并保持 info 属性不变
dat.put([2, 3, 4], [6, 3, 4])
assert_(dat.info == 'jubba')
# 设置第一个元素为 4,类型为 np.int32,并保持 info 属性不变
dat.setfield(4, np.int32, 0)
assert_(dat.info == 'jubba')
# 设置 dat 的标志,并保持 info 属性不变
dat.setflags()
assert_(dat.info == 'jubba')
# 对 dat 的每行应用 all 方法,并保持 info 属性不变
assert_(dat.all(1).info == 'jubba')
# 对 dat 的每行应用 any 方法,并保持 info 属性不变
assert_(dat.any(1).info == 'jubba')
# 对 dat 的每行应用 argmax 方法,并保持 info 属性不变
assert_(dat.argmax(1).info == 'jubba')
# 对 dat 的每行应用 argmin 方法,并保持 info 属性不变
assert_(dat.argmin(1).info == 'jubba')
# 对 dat 的每行应用 argsort 方法,并保持 info 属性不变
assert_(dat.argsort(1).info == 'jubba')
# 将 dat 转换为 TestArray 类型,并保持 info 属性不变
assert_(dat.astype(TestArray).info == 'jubba')
# 对 dat 的每个元素进行字节交换,并保持 info 属性不变
assert_(dat.byteswap().info == 'jubba')
# 将 dat 的值限制在 [2, 7] 的范围内,并保持 info 属性不变
assert_(dat.clip(2, 7).info == 'jubba')
# 对 dat 进行压缩,保留索引 [0, 1, 1] 对应的值,并保持 info 属性不变
assert_(dat.compress([0, 1, 1]).info == 'jubba')
# 对 dat 进行共轭操作,并保持 info 属性不变
assert_(dat.conj().info == 'jubba')
# 对 dat 进行共轭操作,并保持 info 属性不变
assert_(dat.conjugate().info == 'jubba')
# 复制 dat,并保持 info 属性不变
assert_(dat.copy().info == 'jubba')
# 创建一个新的 TestArray 实例 dat2,并保持 info 属性不变
dat2 = TestArray([2, 3, 1, 0], 'jubba')
# 从 choices 中根据 dat2 的值进行选择,并保持 info 属性不变
choices = [[0, 1, 2, 3], [10, 11, 12, 13],
[20, 21, 22, 23], [30, 31, 32, 33]]
assert_(dat2.choose(choices).info == 'jubba')
# 对 dat 的每行应用 cumprod 方法,并保持 info 属性不变
assert_(dat.cumprod(1).info == 'jubba')
# 对 dat 的每行应用 cumsum 方法,并保持 info 属性不变
assert_(dat.cumsum(1).info == 'jubba')
# 返回 dat 的对角线元素,并保持 info 属性不变
assert_(dat.diagonal().info == 'jubba')
# 将 dat 展平为一维数组,并保持 info 属性不变
assert_(dat.flatten().info == 'jubba')
# 返回 dat 中指定字段的值,并保持 info 属性不变
assert_(dat.getfield(np.int32, 0).info == 'jubba')
# 返回 dat 的虚部,并保持 info 属性不变
assert_(dat.imag.info == 'jubba')
# 返回 dat 的每行的最大值,并保持 info 属性不变
assert_(dat.max(1).info == 'jubba')
# 返回 dat 的每行的均值,并保持 info 属性不变
assert_(dat.mean(1).info == 'jubba')
# 返回 dat 的每行的最小值,并保持 info 属性不变
assert_(dat.min(1).info == 'jubba')
# 返回 dat 的每行的乘积,并保持 info 属性不变
assert_(dat.prod(1).info == 'jubba')
# 返回 dat 展平为一维数组,并保持 info 属性不变
assert_(dat.ravel().info == 'jubba')
# 返回 dat 的实部,并保持 info 属性不变
assert_(dat.real.info == 'jubba')
# 返回 dat 的重复值,并保持 info 属性不变
assert_(dat.repeat(2).info == 'jubba')
# 返回 dat 的形状为 (2, 4) 的数组,并保持 info 属性不变
assert_(dat.reshape((2, 4)).info == 'jubba')
# 返回 dat 的每个元素进行四舍五入后的值,并保持 info 属性不变
assert_(dat.round().info == 'jubba')
# 去除 dat 中的单维度条目,并保持 info 属性不变
assert_(dat.squeeze().info == 'jubba')
# 返回 dat 的每行的标准差,并保持 info 属性不变
assert_(dat.std(1).info == 'jubba')
# 返回 dat 的每行的总和,并保持 info 属性不变
assert_(dat.sum(1).info == 'jubba')
# 交换 dat 的轴,并保持 info 属性不变
assert_(dat.swapaxes(0, 1).info == 'jubba')
# 返回 dat 中指定索引的元素,并保持 info 属性不变
assert_(dat.take([2, 3, 5]).info == 'jubba')
# 返回 dat 的转置,并保持 info 属性不变
assert_(dat.transpose().info == 'jubba')
# 返回 dat 的转置,并保持 info 属性不变
assert_(dat.T.info == 'jubba')
# 返回 dat 的每行的方差,并保持 info 属性不变
assert_(dat.var(1).info == 'jubba')
# 返回 dat 的视图,并保持 info 属性不变
assert_(dat.view(TestArray).info == 'jubba')
# 这些方法不保留子类,所以要确保类型是 np.ndarray
# 检查 dat.nonzero() 的结果类型是否为
def test_recarray_tolist(self):
# Ticket #793, changeset r5215
# Comparisons fail for NaN, so we can't use random memory
# for the test.
# 创建一个长度为40的全零数组,数据类型为int8
buf = np.zeros(40, dtype=np.int8)
# 使用给定的缓冲区创建一个记录数组,包含2行,字段格式为'i4,f8,f8',字段名为'id', 'x', 'y'
a = np.recarray(2, formats="i4,f8,f8", names="id,x,y", buf=buf)
# 将记录数组转换为普通的Python列表
b = a.tolist()
# 断言:检查记录数组中第一个元素转换为列表后的值与原始记录数组第一个元素的值相等
assert_(a[0].tolist() == b[0])
# 断言:检查记录数组中第二个元素转换为列表后的值与原始记录数组第二个元素的值相等
assert_(a[1].tolist() == b[1])
def test_nonscalar_item_method(self):
# Make sure that .item() fails graciously when it should
# 创建一个包含0到4的数组
a = np.arange(5)
# 断言:验证当尝试在非标量上调用.item()时,会引发ValueError异常
assert_raises(ValueError, a.item)
def test_char_array_creation(self):
# 创建一个包含字符'123'的数组,数据类型为'c'(字符)
a = np.array('123', dtype='c')
# 创建一个包含字节串b'1', b'2', b'3'的数组
b = np.array([b'1', b'2', b'3'])
# 断言:验证字符数组a与字节串数组b相等
assert_equal(a, b)
def test_unaligned_unicode_access(self):
# Ticket #825
# 遍历1到8之间的数值
for i in range(1, 9):
# 生成格式为'Si, U2'的自定义dtype
msg = 'unicode offset: %d chars' % i
t = np.dtype([('a', 'S%d' % i), ('b', 'U2')])
# 创建一个包含一个元素的数组,元素是一个元组,第一个元素为字节串b'a',第二个元素为字符串'b'
x = np.array([(b'a', 'b')], dtype=t)
# 断言:验证数组x的字符串表示与预期的字符串相等,错误消息为msg
assert_equal(str(x), "[(b'a', 'b')]", err_msg=msg)
def test_sign_for_complex_nan(self):
# Ticket 794.
# 忽略无效操作的警告
with np.errstate(invalid='ignore'):
# 创建一个包含特定复数值的数组
C = np.array([-np.inf, -3+4j, 0, 4-3j, np.inf, np.nan])
# 计算数组C中各元素的符号值
have = np.sign(C)
# 创建一个预期结果的数组
want = np.array([-1+0j, -0.6+0.8j, 0+0j, 0.8-0.6j, 1+0j,
complex(np.nan, np.nan)])
# 断言:验证计算得到的符号数组与预期的数组相等
assert_equal(have, want)
def test_for_equal_names(self):
# Ticket #674
# 创建一个自定义dtype,包含两个字段,分别命名为'foo'和'bar',数据类型为float
dt = np.dtype([('foo', float), ('bar', float)])
# 创建一个包含10个元素的零数组,数据类型为自定义dtype dt
a = np.zeros(10, dt)
# 将dtype的字段名转换为列表
b = list(a.dtype.names)
# 修改列表中的第一个元素
b[0] = "notfoo"
# 将修改后的字段名列表重新赋值给数组的dtype字段名
a.dtype.names = b
# 断言:验证修改后的第一个字段名为"notfoo"
assert_(a.dtype.names[0] == "notfoo")
# 断言:验证第二个字段名为"bar"
assert_(a.dtype.names[1] == "bar")
def test_for_object_scalar_creation(self):
# Ticket #816
# 创建一个对象类型的标量
a = np.object_()
# 创建一个整数类型的对象标量
b = np.object_(3)
# 创建一个浮点数类型的对象标量
b2 = np.object_(3.0)
# 创建一个包含整数的数组对象
c = np.object_([4, 5])
# 创建一个包含None、空字典和空列表的数组对象
d = np.object_([None, {}, []])
# 断言:验证a是None
assert_(a is None)
# 断言:验证b的类型是int
assert_(type(b) is int)
# 断言:验证b2的类型是float
assert_(type(b2) is float)
# 断言:验证c的类型是np.ndarray
assert_(type(c) is np.ndarray)
# 断言:验证c的dtype是object
assert_(c.dtype == object)
# 断言:验证d的dtype是object
assert_(d.dtype == object)
def test_array_resize_method_system_error(self):
# Ticket #840 - order should be an invalid keyword.
# 创建一个2x2的二维数组
x = np.array([[0, 1], [2, 3]])
# 断言:验证尝试使用无效关键字'order'调整数组大小时会引发TypeError异常
assert_raises(TypeError, x.resize, (2, 2), order='C')
def test_for_zero_length_in_choose(self):
"Ticket #882"
# 创建一个包含一个整数的数组
a = np.array(1)
# 断言:验证在调用choose函数时,如果传递空列表,会引发ValueError异常
assert_raises(ValueError, lambda x: x.choose([]), a)
def test_array_ndmin_overflow(self):
"Ticket #947."
# 断言:验证在创建数组时,如果指定的ndmin值大于系统能处理的最大值,会引发ValueError异常
assert_raises(ValueError, lambda: np.array([1], ndmin=65))
def test_void_scalar_with_titles(self):
# No ticket
# 创建一个包含数据和标题的数据列表
data = [('john', 4), ('mary', 5)]
# 创建一个复合dtype,包含两个字段,第一个字段名为('source:yy', 'name'),数据类型为'O',第二个字段名为('source:xx', 'id'),数据类型为int
dtype1 = [(('source:yy', 'name'), 'O'), (('source:xx', 'id'), int)]
# 创建一个数组,使用自定义dtype1
arr = np.array(data, dtype=dtype1)
# 断言:验证数组arr的第一个元素的第一个字段值为'john'
assert_(arr[0][0] == 'john')
# 断言:验证数组arr的第一个元素的第二个字段值为4
assert_(arr[0][1] == 4)
def test_void_scalar_constructor(self):
#Issue #1550
# 对 void 标量构造函数的测试
#Create test string data, construct void scalar from data and assert
#that void scalar contains original data.
# 创建测试字符串数据,从数据构造 void 标量并断言
# void 标量包含原始数据。
test_string = np.array("test")
test_string_void_scalar = np._core.multiarray.scalar(
np.dtype(("V", test_string.dtype.itemsize)), test_string.tobytes())
assert_(test_string_void_scalar.view(test_string.dtype) == test_string)
#Create record scalar, construct from data and assert that
#reconstructed scalar is correct.
# 创建记录标量,从数据构造并断言重建的标量是正确的。
test_record = np.ones((), "i,i")
test_record_void_scalar = np._core.multiarray.scalar(
test_record.dtype, test_record.tobytes())
assert_(test_record_void_scalar == test_record)
# Test pickle and unpickle of void and record scalars
# 测试 void 和记录标量的序列化和反序列化
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
assert_(pickle.loads(
pickle.dumps(test_string, protocol=proto)) == test_string)
assert_(pickle.loads(
pickle.dumps(test_record, protocol=proto)) == test_record)
@_no_tracing
def test_blasdot_uninitialized_memory(self):
# Ticket #950
# 对未初始化内存的 BLAS dot 进行测试
for m in [0, 1, 2]:
for n in [0, 1, 2]:
for k in range(3):
# Try to ensure that x->data contains non-zero floats
# 尝试确保 x->data 包含非零浮点数
x = np.array([123456789e199], dtype=np.float64)
if IS_PYPY:
x.resize((m, 0), refcheck=False)
else:
x.resize((m, 0))
y = np.array([123456789e199], dtype=np.float64)
if IS_PYPY:
y.resize((0, n), refcheck=False)
else:
y.resize((0, n))
# `dot` should just return zero (m, n) matrix
# `dot` 应该返回一个全零的 (m, n) 矩阵
z = np.dot(x, y)
assert_(np.all(z == 0))
assert_(z.shape == (m, n))
def test_zeros(self):
# Regression test for #1061.
# Set a size which cannot fit into a 64 bits signed integer
# #1061 的回归测试
# 设置一个超出 64 位有符号整数范围的大小
sz = 2 ** 64
with assert_raises_regex(ValueError,
'Maximum allowed dimension exceeded'):
np.empty(sz)
def test_huge_arange(self):
# Regression test for #1062.
# Set a size which cannot fit into a 64 bits signed integer
# #1062 的回归测试
# 设置一个超出 64 位有符号整数范围的大小
sz = 2 ** 64
with assert_raises_regex(ValueError,
'Maximum allowed size exceeded'):
np.arange(sz)
assert_(np.size == sz)
def test_fromiter_bytes(self):
# Ticket #1058
# 对 fromiter 函数处理字节的测试
a = np.fromiter(list(range(10)), dtype='b')
b = np.fromiter(list(range(10)), dtype='B')
assert_(np.all(a == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))
assert_(np.all(b == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))
def test_array_from_sequence_scalar_array(self):
# Ticket #1078: segfaults when creating an array with a sequence of
# 0d arrays.
# 创建包含两个元素的对象数组,每个元素是一个 0 维数组和一个标量
a = np.array((np.ones(2), np.array(2)), dtype=object)
assert_equal(a.shape, (2,))
assert_equal(a.dtype, np.dtype(object))
assert_equal(a[0], np.ones(2)) # 断言第一个元素是包含两个 1 的数组
assert_equal(a[1], np.array(2)) # 断言第二个元素是标量 2
# 创建包含两个元素的对象数组,每个元素是一个包含一个元素的元组和一个标量
a = np.array(((1,), np.array(1)), dtype=object)
assert_equal(a.shape, (2,))
assert_equal(a.dtype, np.dtype(object))
assert_equal(a[0], (1,)) # 断言第一个元素是元组 (1,)
assert_equal(a[1], np.array(1)) # 断言第二个元素是标量 1
def test_array_from_sequence_scalar_array2(self):
# Ticket #1081: weird array with strange input...
# 创建包含两个元素的对象数组,每个元素是一个空数组和一个标量 0
t = np.array([np.array([]), np.array(0, object)], dtype=object)
assert_equal(t.shape, (2,))
assert_equal(t.dtype, np.dtype(object))
def test_array_too_big(self):
# Ticket #1080.
# 断言创建指定大小的零数组会引发 ValueError 异常
assert_raises(ValueError, np.zeros, [975]*7, np.int8)
assert_raises(ValueError, np.zeros, [26244]*5, np.int8)
def test_dtype_keyerrors_(self):
# Ticket #1106.
# 创建具有单个字段 'f1' 的结构化数据类型,并断言访问不存在的字段名、索引和非整数索引时会引发对应异常
dt = np.dtype([('f1', np.uint)])
assert_raises(KeyError, dt.__getitem__, "f2")
assert_raises(IndexError, dt.__getitem__, 1)
assert_raises(TypeError, dt.__getitem__, 0.0)
def test_lexsort_buffer_length(self):
# Ticket #1217, don't segfault.
# 创建两个不同类型的数组 a 和 b,然后进行 lexsort 操作,断言不会发生段错误
a = np.ones(100, dtype=np.int8)
b = np.ones(100, dtype=np.int32)
i = np.lexsort((a[::-1], b))
assert_equal(i, np.arange(100, dtype=int))
def test_object_array_to_fixed_string(self):
# Ticket #1235.
# 创建包含两个字符串元素的对象数组 a,并将其转换为固定长度为 8 的字符串数组 b、长度为 5 的字符串数组 c 和长度为 12 的字符串数组 d
a = np.array(['abcdefgh', 'ijklmnop'], dtype=np.object_)
b = np.array(a, dtype=(np.str_, 8))
assert_equal(a, b)
c = np.array(a, dtype=(np.str_, 5))
assert_equal(c, np.array(['abcde', 'ijklm']))
d = np.array(a, dtype=(np.str_, 12))
assert_equal(a, d)
e = np.empty((2, ), dtype=(np.str_, 8))
e[:] = a[:]
assert_equal(a, e)
def test_unicode_to_string_cast(self):
# Ticket #1240.
# 创建包含 Unicode 字符串的数组 a,并尝试将其转换为字节字符串数组,预期会引发 UnicodeEncodeError 异常
a = np.array([['abc', '\u03a3'],
['asdf', 'erw']],
dtype='U')
assert_raises(UnicodeEncodeError, np.array, a, 'S4')
def test_unicode_to_string_cast_error(self):
# gh-15790
# 创建包含特定 Unicode 字符的数组 a,并尝试将其重塑为二维数组 b,预期在转换为字节字符串数组时会引发 UnicodeEncodeError 异常
a = np.array(['\x80'] * 129, dtype='U3')
assert_raises(UnicodeEncodeError, np.array, a, 'S')
b = a.reshape(3, 43)[:-1, :-1]
assert_raises(UnicodeEncodeError, np.array, b, 'S')
def test_mixed_string_byte_array_creation(self):
# 测试混合字符串和字节数组的创建
a = np.array(['1234', b'123'])
# 断言数组的每个元素所占空间大小为16字节
assert_(a.itemsize == 16)
a = np.array([b'123', '1234'])
assert_(a.itemsize == 16)
a = np.array(['1234', b'123', '12345'])
assert_(a.itemsize == 20)
a = np.array([b'123', '1234', b'12345'])
assert_(a.itemsize == 20)
a = np.array([b'123', '1234', b'1234'])
assert_(a.itemsize == 16)
def test_misaligned_objects_segfault(self):
# Ticket #1198 and #1267
# 创建一个dtype为对象和字符的全零数组
a1 = np.zeros((10,), dtype='O,c')
# 创建一个长度为10的字节数组
a2 = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'], 'S10')
# 将a2赋值给a1的第一列
a1['f0'] = a2
repr(a1)
# 求a1['f0']中最大值的索引
np.argmax(a1['f0'])
# 修改a1['f0']的第二个元素为"FOO"
a1['f0'][1] = "FOO"
# 将a1['f0']所有元素赋值为"FOO"
a1['f0'] = "FOO"
# 将a1['f0']转换为dtype为'S'的字节数组
np.array(a1['f0'], dtype='S')
# 返回a1['f0']中非零元素的索引
np.nonzero(a1['f0'])
# 对a1进行排序
a1.sort()
# 深拷贝a1
copy.deepcopy(a1)
def test_misaligned_scalars_segfault(self):
# Ticket #1267
# 创建一个dtype为字符和对象的数组s1
s1 = np.array(('a', 'Foo'), dtype='c,O')
# 创建一个dtype为字符和对象的数组s2
s2 = np.array(('b', 'Bar'), dtype='c,O')
# 将s2['f1']赋值给s1['f1']
s1['f1'] = s2['f1']
# 将s1['f1']的所有元素赋值为"Baz"
s1['f1'] = 'Baz'
def test_misaligned_dot_product_objects(self):
# Ticket #1267
# 测试未对齐的点积对象
# 创建一个dtype为对象和字符的二维数组a
a = np.array([[(1, 'a'), (0, 'a')], [(0, 'a'), (1, 'a')]], dtype='O,c')
# 创建一个dtype为对象和字符的二维数组b
b = np.array([[(4, 'a'), (1, 'a')], [(2, 'a'), (2, 'a')]], dtype='O,c')
# 计算a['f0']和b['f0']的点积
np.dot(a['f0'], b['f0'])
def test_byteswap_complex_scalar(self):
# Ticket #1259 and gh-441
# 测试复数标量的字节交换
for dtype in [np.dtype('<'+t) for t in np.typecodes['Complex']]:
z = np.array([2.2-1.1j], dtype)
x = z[0] # 总是本机字节序
y = x.byteswap()
if x.dtype.byteorder == z.dtype.byteorder:
# 小端机器
assert_equal(x, np.frombuffer(y.tobytes(), dtype=dtype.newbyteorder()))
else:
# 大端机器
assert_equal(x, np.frombuffer(y.tobytes(), dtype=dtype))
# 再次检查实部和虚部:
assert_equal(x.real, y.real.byteswap())
assert_equal(x.imag, y.imag.byteswap())
def test_structured_arrays_with_objects1(self):
# Ticket #1299
# 测试包含对象的结构化数组
stra = 'aaaa'
strb = 'bbbb'
# 创建一个包含元组的二维数组x,元组中包含一个整数和一个字符串
x = np.array([[(0, stra), (1, strb)]], 'i8,O')
# 将x中非零元素赋值为x的扁平化版本的第一个元素
x[x.nonzero()] = x.ravel()[:1]
assert_(x[0, 1] == x[0, 0])
@pytest.mark.skipif(
sys.version_info >= (3, 12),
reason="Python 3.12 has immortal refcounts, this test no longer works."
)
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_structured_arrays_with_objects2(self):
# Ticket #1299 second test
# 定义两个字符串变量
stra = 'aaaa'
strb = 'bbbb'
# 获取字符串引用计数
numb = sys.getrefcount(strb)
numa = sys.getrefcount(stra)
# 创建一个结构化数组,包含一个整数和一个对象类型的元组
x = np.array([[(0, stra), (1, strb)]], 'i8,O')
# 将非零元素替换为数组展平后的第一个元素
x[x.nonzero()] = x.ravel()[:1]
# 断言字符串引用计数未改变
assert_(sys.getrefcount(strb) == numb)
# 断言字符串引用计数增加了2
assert_(sys.getrefcount(stra) == numa + 2)
def test_duplicate_title_and_name(self):
# Ticket #1254
# 定义一个数据类型规范列表,其中包含重复的字段名
dtspec = [(('a', 'a'), 'i'), ('b', 'i')]
# 断言创建该数据类型会引发 ValueError 异常
assert_raises(ValueError, np.dtype, dtspec)
def test_signed_integer_division_overflow(self):
# Ticket #1317.
# 定义一个测试函数,用于测试不同整数类型的负数除法溢出
def test_type(t):
# 创建包含最小值的数组,并进行负数除法
min = np.array([np.iinfo(t).min])
min //= -1
# 忽略溢出警告进行测试
with np.errstate(over="ignore"):
# 对每种整数类型调用测试函数
for t in (np.int8, np.int16, np.int32, np.int64, int):
test_type(t)
def test_buffer_hashlib(self):
# 导入 hashlib 库中的 sha256 函数
from hashlib import sha256
# 创建一个整数数组,并计算其 sha256 哈希值的十六进制表示
x = np.array([1, 2, 3], dtype=np.dtype('<i4'))
assert_equal(
sha256(x).hexdigest(),
'4636993d3e1da4e9d6b8f87b79e8f7c6d018580d52661950eabc3845c5897a4d'
)
def test_0d_string_scalar(self):
# Bug #1436; the following should succeed
# 将字符串 'x' 转换为零维字符串标量数组
np.asarray('x', '>c')
def test_log1p_compiler_shenanigans(self):
# 检查在32位Intel系统上log1p函数的行为是否正常
assert_(np.isfinite(np.log1p(np.exp2(-53))))
def test_fromiter_comparison(self):
# 创建一个有符号字节类型的数组a和无符号字节类型的数组b,并进行相等性比较
a = np.fromiter(list(range(10)), dtype='b')
b = np.fromiter(list(range(10)), dtype='B')
assert_(np.all(a == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))
assert_(np.all(b == np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])))
def test_fromstring_crash(self):
# Ticket #1345: the following should not cause a crash
# 使用逗号作为分隔符解析字节字符串,不应导致崩溃,但会警告过时
with assert_warns(DeprecationWarning):
np.fromstring(b'aa, aa, 1.0', sep=',')
def test_ticket_1539(self):
# 获取所有数字类型并排除时间间隔类型,创建一个空的布尔类型数组a
dtypes = [x for x in np._core.sctypeDict.values()
if (issubclass(x, np.number)
and not issubclass(x, np.timedelta64))]
a = np.array([], np.bool) # not x[0] because it is unordered
failures = []
for x in dtypes:
b = a.astype(x)
for y in dtypes:
c = a.astype(y)
try:
# 尝试使用dot函数计算b和c的点积
d = np.dot(b, c)
except TypeError:
failures.append((x, y))
else:
# 如果点积不为0,将(x, y)添加到失败列表中
if d != 0:
failures.append((x, y))
# 如果存在失败,抛出断言错误
if failures:
raise AssertionError("Failures: %r" % failures)
def test_ticket_1538(self):
# 获取np.float32类型的数值信息对象x
x = np.finfo(np.float32)
# 遍历'eps epsneg max min resolution tiny'属性名称列表
for name in 'eps epsneg max min resolution tiny'.split():
# 断言获取的属性值的类型为np.float32
assert_equal(type(getattr(x, name)), np.float32,
err_msg=name)
def test_ticket_1434(self):
# 检查 var 和 std 方法中 out 参数的影响
data = np.array(((1, 2, 3), (4, 5, 6), (7, 8, 9)))
out = np.zeros((3,))
# 计算沿着 axis=1 方向的方差,并将结果存入 out 数组
ret = data.var(axis=1, out=out)
assert_(ret is out)
assert_array_equal(ret, data.var(axis=1))
# 计算沿着 axis=1 方向的标准差,并将结果存入 out 数组
ret = data.std(axis=1, out=out)
assert_(ret is out)
assert_array_equal(ret, data.std(axis=1))
def test_complex_nan_maximum(self):
cnan = complex(0, np.nan)
# 检查 np.maximum 函数处理复数和 NaN 的情况
assert_equal(np.maximum(1, cnan), cnan)
def test_subclass_int_tuple_assignment(self):
# ticket #1563
# 定义一个继承自 np.ndarray 的子类
class Subclass(np.ndarray):
def __new__(cls, i):
return np.ones((i,)).view(cls)
x = Subclass(5)
# 对子类实例进行索引赋值,验证是否会引发异常
x[(0,)] = 2 # 不应该引发异常
assert_equal(x[0], 2)
def test_ufunc_no_unnecessary_views(self):
# ticket #1548
# 定义一个继承自 np.ndarray 的子类
class Subclass(np.ndarray):
pass
x = np.array([1, 2, 3]).view(Subclass)
# 测试 ufunc 是否会返回不必要的视图
y = np.add(x, x, x)
assert_equal(id(x), id(y))
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_take_refcount(self):
# ticket #939
a = np.arange(16, dtype=float)
a.shape = (4, 4)
lut = np.ones((5 + 3, 4), float)
rgba = np.empty(shape=a.shape + (4,), dtype=lut.dtype)
c1 = sys.getrefcount(rgba)
try:
# 使用 take 方法填充 rgba 数组,检查引用计数是否变化
lut.take(a, axis=0, mode='clip', out=rgba)
except TypeError:
pass
c2 = sys.getrefcount(rgba)
assert_equal(c1, c2)
def test_fromfile_tofile_seeks(self):
# 在 Python 3 中,tofile/fromfile 操作会导致文件句柄不同步 (#1610)
f0 = tempfile.NamedTemporaryFile()
f = f0.file
f.write(np.arange(255, dtype='u1').tobytes())
f.seek(20)
# 从文件中读取指定数量的数据,验证读取是否正确
ret = np.fromfile(f, count=4, dtype='u1')
assert_equal(ret, np.array([20, 21, 22, 23], dtype='u1'))
assert_equal(f.tell(), 24)
f.seek(40)
# 将指定数据写入文件,验证写入位置是否正确
np.array([1, 2, 3], dtype='u1').tofile(f)
assert_equal(f.tell(), 43)
f.seek(40)
# 从文件中读取指定字节数据,验证读取结果是否正确
data = f.read(3)
assert_equal(data, b"\x01\x02\x03")
f.seek(80)
# 再次从文件中读取数据,验证读取结果是否正确
data = np.fromfile(f, dtype='u1', count=4)
assert_equal(data, np.array([84, 85, 86, 87], dtype='u1'))
f.close()
def test_complex_scalar_warning(self):
for tp in [np.csingle, np.cdouble, np.clongdouble]:
x = tp(1+2j)
# 检查在处理复数标量时是否会发出警告
assert_warns(ComplexWarning, float, x)
with suppress_warnings() as sup:
sup.filter(ComplexWarning)
assert_equal(float(x), float(x.real))
def test_complex_scalar_complex_cast(self):
for tp in [np.csingle, np.cdouble, np.clongdouble]:
x = tp(1+2j)
# 检查复数标量到复数类型的转换是否正确
assert_equal(complex(x), 1+2j)
def test_complex_boolean_cast(self):
# 测试复数类型转换为布尔类型的情况
# Ticket #2218
for tp in [np.csingle, np.cdouble, np.clongdouble]:
x = np.array([0, 0+0.5j, 0.5+0j], dtype=tp)
# 断言转换后的数组与期望的布尔数组相等
assert_equal(x.astype(bool), np.array([0, 1, 1], dtype=bool))
# 断言数组中至少存在一个非零元素
assert_(np.any(x))
# 断言数组中除第一个元素外,其余元素都为真
assert_(np.all(x[1:]))
def test_uint_int_conversion(self):
# 测试无符号整数与有符号整数之间的转换
x = 2**64 - 1
assert_equal(int(np.uint64(x)), x)
def test_duplicate_field_names_assign(self):
# 测试重复的字段名分配
ra = np.fromiter(((i*3, i*2) for i in range(10)), dtype='i8,f8')
ra.dtype.names = ('f1', 'f2')
repr(ra) # 不应该导致分段错误
# 断言尝试设置字段名为重复时会抛出 ValueError 异常
assert_raises(ValueError, setattr, ra.dtype, 'names', ('f1', 'f1'))
def test_eq_string_and_object_array(self):
# 测试字符串数组与对象数组的相等比较
# 来自电子邮件线程 "__eq__ with str and object" (Keith Goodman)
a1 = np.array(['a', 'b'], dtype=object)
a2 = np.array(['a', 'c'])
assert_array_equal(a1 == a2, [True, False])
assert_array_equal(a2 == a1, [True, False])
def test_nonzero_byteswap(self):
# 测试非零元素查找与字节交换
a = np.array([0x80000000, 0x00000080, 0], dtype=np.uint32)
a.dtype = np.float32
assert_equal(a.nonzero()[0], [1])
a = a.byteswap()
a = a.view(a.dtype.newbyteorder())
assert_equal(a.nonzero()[0], [1]) # 如果 nonzero() 忽略交换则为 [0]
def test_empty_mul(self):
# 测试空数组乘法操作
a = np.array([1.])
a[1:1] *= 2
assert_equal(a, [1.])
def test_array_side_effect(self):
# 测试数组副作用
# 在 ctors.c 中,discover_itemsize 调用 PyObject_Length 时没有检查返回码,
# 导致无法获取数字 2 的长度,异常一直存在,直到某处检查到 PyErr_Occurred() 并返回错误。
assert_equal(np.dtype('S10').itemsize, 10)
np.array([['abc', 2], ['long ', '0123456789']], dtype=np.bytes_)
assert_equal(np.dtype('S10').itemsize, 10)
def test_any_float(self):
# 测试浮点数的 all 和 any 操作
# 对于浮点数,all 和 any 的行为
a = np.array([0.1, 0.9])
assert_(np.any(a))
assert_(np.all(a))
def test_large_float_sum(self):
# 测试大浮点数数组求和
a = np.arange(10000, dtype='f')
assert_equal(a.sum(dtype='d'), a.astype('d').sum())
def test_ufunc_casting_out(self):
# 测试 ufunc 中的输出类型转换
a = np.array(1.0, dtype=np.float32)
b = np.array(1.0, dtype=np.float64)
c = np.array(1.0, dtype=np.float32)
np.add(a, b, out=c)
assert_equal(c, 2.0)
def test_array_scalar_contiguous(self):
# 测试数组标量的连续性
# 数组标量既是 C 连续的也是 Fortran 连续的
assert_(np.array(1.0).flags.c_contiguous)
assert_(np.array(1.0).flags.f_contiguous)
assert_(np.array(np.float32(1.0)).flags.c_contiguous)
assert_(np.array(np.float32(1.0)).flags.f_contiguous)
def test_squeeze_contiguous(self):
# 测试squeeze方法在处理连续数组时的情况
a = np.zeros((1, 2)).squeeze() # 创建一个2列的零数组,并使用squeeze方法压缩维度
b = np.zeros((2, 2, 2), order='F')[:, :, ::2].squeeze() # 创建一个Fortran顺序的3维零数组,并在某些轴上进行切片和squeeze操作
assert_(a.flags.c_contiguous) # 断言a数组是C连续的
assert_(a.flags.f_contiguous) # 断言a数组是Fortran连续的
assert_(b.flags.f_contiguous) # 断言b数组是Fortran连续的
def test_squeeze_axis_handling(self):
# 测试squeeze方法在处理轴参数时的行为
# 确保在squeeze时正确处理不支持轴参数的对象
class OldSqueeze(np.ndarray):
# 自定义的继承自ndarray的类
def __new__(cls, input_array):
obj = np.asarray(input_array).view(cls)
return obj
# 对于旧版本的API,可能没有期望squeeze方法有轴参数的预期行为
# 注意:这个例子有些人为,旨在模拟旧API的预期行为以防止回归
def squeeze(self):
return super().squeeze() # 调用父类的squeeze方法
oldsqueeze = OldSqueeze(np.array([[1],[2],[3]]))
# 如果没有指定轴参数,旧API的预期行为应该得到正确的结果
assert_equal(np.squeeze(oldsqueeze),
np.array([1,2,3]))
# 同样地,axis=None在旧API的预期行为下应该正常工作
assert_equal(np.squeeze(oldsqueeze, axis=None),
np.array([1,2,3]))
# 然而,指定任何具体的轴参数应该在旧API的规范下引发TypeError异常
with assert_raises(TypeError):
np.squeeze(oldsqueeze, axis=1)
# 当使用无效的轴参数时应该有相同的行为检查
with assert_raises(TypeError):
np.squeeze(oldsqueeze, axis=0)
# 新API知道如何处理轴参数,如果试图squeeze一个长度不为1的轴,会引发ValueError异常
with assert_raises(ValueError):
np.squeeze(np.array([[1],[2],[3]]), axis=0)
def test_reduce_contiguous(self):
# 测试reduce操作在处理连续数组时的情况
a = np.add.reduce(np.zeros((2, 1, 2)), (0, 1)) # 对一个3维零数组进行reduce操作,指定轴参数为(0, 1)
b = np.add.reduce(np.zeros((2, 1, 2)), 1) # 对一个3维零数组进行reduce操作,指定轴参数为1
assert_(a.flags.c_contiguous) # 断言a数组是C连续的
assert_(a.flags.f_contiguous) # 断言a数组是Fortran连续的
assert_(b.flags.c_contiguous) # 断言b数组是C连续的
# 使用 pytest 的装饰器标记此测试,在 Pyston 环境下跳过测试,因为 Pyston 禁用了递归检查
@pytest.mark.skipif(IS_PYSTON, reason="Pyston disables recursion checking")
def test_object_array_self_reference(self):
# 创建一个对象数组 a,元素为整数 0,dtype 为 object
a = np.array(0, dtype=object)
# 将数组 a 中的空元组位置赋值为数组 a 自身,创建自引用
a[()] = a
# 断言应触发递归错误异常,因为存在自引用
assert_raises(RecursionError, int, a)
assert_raises(RecursionError, float, a)
# 将数组 a 中的空元组位置重新赋值为 None,解除自引用
a[()] = None
@pytest.mark.skipif(IS_PYSTON, reason="Pyston disables recursion checking")
def test_object_array_circular_reference(self):
# 同样测试循环引用的情况
a = np.array(0, dtype=object)
b = np.array(0, dtype=object)
a[()] = b
b[()] = a
# 断言应触发递归错误异常,由于 NumPy 当前不支持 tp_traverse,无法检测循环引用,因此解除循环引用
assert_raises(RecursionError, int, a)
# 解除循环引用
a[()] = None
# 下面的代码导致 a 变成上述自引用的形式
a = np.array(0, dtype=object)
a[...] += 1
# 断言 a 的值应为 1
assert_equal(a, 1)
def test_object_array_nested(self):
# 但引用到不同数组的情况是可以的
a = np.array(0, dtype=object)
b = np.array(0, dtype=object)
a[()] = b
# 断言 int(a) 和 float(a) 均等于 0
assert_equal(int(a), int(0))
assert_equal(float(a), float(0))
def test_object_array_self_copy(self):
# 对象数组在复制到自身之前,先 DECREF 再 INCREF 可能会导致段错误 (gh-3787)
a = np.array(object(), dtype=object)
np.copyto(a, a)
if HAS_REFCOUNT:
# 断言 a[()] 的引用计数为 2
assert_(sys.getrefcount(a[()]) == 2)
# 访问 a[()] 的 __class__ 属性,如果对象已被删除,将导致段错误
a[()].__class__
def test_zerosize_accumulate(self):
# "Ticket #1733"
x = np.array([[42, 0]], dtype=np.uint32)
# 断言对 x[:-1, 0] 进行累积加法操作应得到空列表
assert_equal(np.add.accumulate(x[:-1, 0]), [])
def test_objectarray_setfield(self):
# Setfield 不应该用非对象数据覆盖对象字段
x = np.array([1, 2, 3], dtype=object)
# 断言设置 x 的第一个元素为 np.int32 类型的值 4 时应触发 TypeError
assert_raises(TypeError, x.setfield, 4, np.int32, 0)
def test_setting_rank0_string(self):
# "Ticket #1736"
s1 = b"hello1"
s2 = b"hello2"
# 创建一个 dtype 为 'S10' 的零维数组 a,设置其值为 s1
a = np.zeros((), dtype="S10")
a[()] = s1
# 断言 a 应等于 np.array(s1)
assert_equal(a, np.array(s1))
# 将 a 的值设置为 np.array(s2),断言 a 应等于 np.array(s2)
a[()] = np.array(s2)
assert_equal(a, np.array(s2))
# 创建一个 dtype 为 'f4' 的零维数组 a,将其值设置为 3
a = np.zeros((), dtype='f4')
a[()] = 3
# 断言 a 应等于 np.array(3)
assert_equal(a, np.array(3))
# 将 a 的值设置为 np.array(4),断言 a 应等于 np.array(4)
a[()] = np.array(4)
assert_equal(a, np.array(4))
def test_string_astype(self):
# "Ticket #1748"
s1 = b'black'
s2 = b'white'
s3 = b'other'
# 创建一个二维数组 a,包含字符串 s1、s2 和 s3
a = np.array([[s1], [s2], [s3]])
# 断言 a 的 dtype 应为 'S5'
assert_equal(a.dtype, np.dtype('S5'))
# 将 a 转换为 dtype 为 'S0' 的数组 b,断言 b 的 dtype 为 'S5'
b = a.astype(np.dtype('S0'))
assert_equal(b.dtype, np.dtype('S5'))
def test_ticket_1756(self):
# Ticket #1756
# 定义字节串,长度为16
s = b'0123456789abcdef'
# 创建一个包含5个相同字节串的数组
a = np.array([s]*5)
# 循环从1到16
for i in range(1, 17):
# 使用不同的字符串长度创建一个新的数组a1
a1 = np.array(a, "|S%d" % i)
# 创建一个包含5个相同长度子串的数组a2
a2 = np.array([s[:i]]*5)
# 断言两个数组相等
assert_equal(a1, a2)
def test_fields_strides(self):
# "gh-2355"
# 从字节缓冲区创建一个结构化数组r,其中包含两个字段:'f0'为i4类型,'f1'为(2,3)u2类型
r = np.frombuffer(b'abcdefghijklmnop'*4*3, dtype='i4,(2,3)u2')
# 断言切片的字段 'f1' 等于字段 'f1' 的切片
assert_equal(r[0:3:2]['f1'], r['f1'][0:3:2])
# 断言切片后的第一个元素的字段 'f1' 等于整体切片后的第一个元素的字段 'f1'
assert_equal(r[0:3:2]['f1'][0], r[0:3:2][0]['f1'])
# 断言切片后的第一个元素的字段 'f1' 与它的空切片结果相等
assert_equal(r[0:3:2]['f1'][0][()], r[0:3:2][0]['f1'][()])
# 断言切片后的第一个元素的字段 'f1' 的步长等于整体切片后的第一个元素的字段 'f1' 的步长
assert_equal(r[0:3:2]['f1'][0].strides, r[0:3:2][0]['f1'].strides)
def test_alignment_update(self):
# Check that alignment flag is updated on stride setting
# 创建一个包含10个元素的数组a
a = np.arange(10)
# 断言数组a的对齐标志为真
assert_(a.flags.aligned)
# 将数组a的步长设置为3
a.strides = 3
# 断言数组a的对齐标志为假
assert_(not a.flags.aligned)
def test_ticket_1770(self):
# Should not segfault on python 3k
import numpy as np
try:
# 创建一个包含一个字段 'f1' 的零数组a,类型为浮点数
a = np.zeros((1,), dtype=[('f1', 'f')])
# 设置字段 'f1' 的值为1
a['f1'] = 1
# 设置字段 'f2' 的值为1,这里会引发异常
a['f2'] = 1
except ValueError:
# 如果捕获到 ValueError 异常则继续
pass
except Exception:
# 如果捕获到其它异常则抛出 AssertionError
raise AssertionError
def test_ticket_1608(self):
# "x.flat shouldn't modify data"
# 创建一个2x2的数组x,并转置
x = np.array([[1, 2], [3, 4]]).T
# 将数组x展平后生成一个新的数组
np.array(x.flat)
# 断言展平后的数组与原始数组x相等
assert_equal(x, [[1, 3], [2, 4]])
def test_pickle_string_overwrite(self):
import re
# 创建一个包含单个整数1的字节类型数组data
data = np.array([1], dtype='b')
# 对data进行pickle序列化,协议版本为1
blob = pickle.dumps(data, protocol=1)
# 使用pickle反序列化blob,重新赋值给data
data = pickle.loads(blob)
# 检查loads操作不会覆盖内部字符串
# 使用正则表达式将字符串"a_"中的"a_"替换成"\x01_"
s = re.sub("a(.)", "\x01\\1", "a_")
# 断言替换后的第一个字符为"\x01"
assert_equal(s[0], "\x01")
# 将data的第一个元素设置为0x6a
data[0] = 0x6a
# 再次使用正则表达式将字符串"a_"中的"a_"替换成"\x01_"
s = re.sub("a(.)", "\x01\\1", "a_")
# 断言替换后的第一个字符为"\x01"
assert_equal(s[0], "\x01")
def test_pickle_bytes_overwrite(self):
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
# 创建一个包含单个整数1的字节类型数组data
data = np.array([1], dtype='b')
# 使用pickle对data进行序列化,协议版本从2到最高协议版本
data = pickle.loads(pickle.dumps(data, protocol=proto))
# 将data的第一个元素设置为0x7d
data[0] = 0x7d
# 创建一个ASCII编码的字节字符串bytestring,内容为"\x01 "
bytestring = "\x01 ".encode('ascii')
# 断言bytestring的前两个字节等于"\x01"
assert_equal(bytestring[0:1], '\x01'.encode('ascii'))
def test_pickle_py2_array_latin1_hack(self):
# 检查在Py3中支持encoding='latin1'的反序列化是否正常工作
# Python2中pickle.dumps(numpy.array([129], dtype='b'))的输出结果
data = b"cnumpy.core.multiarray\n_reconstruct\np0\n(cnumpy\nndarray\np1\n(I0\ntp2\nS'b'\np3\ntp4\nRp5\n(I1\n(I1\ntp6\ncnumpy\ndtype\np7\n(S'i1'\np8\nI0\nI1\ntp9\nRp10\n(I3\nS'|'\np11\nNNNI-1\nI-1\nI0\ntp12\nbI00\nS'\\x81'\np13\ntp14\nb."
# 使用'latin1'编码反序列化data,结果应该是一个包含单个元素129的数组
result = pickle.loads(data, encoding='latin1')
# 断言result与预期的数组相等
assert_array_equal(result, np.array([129]).astype('b'))
# 应该不会导致段错误
assert_raises(Exception, pickle.loads, data, encoding='koi8-r')
def test_pickle_py2_scalar_latin1_hack(self):
# 测试在 Python 3 中支持 encoding='latin1' 的标量反序列化修补是否正确工作。
# 定义测试数据
datas = [
# (original, python2_pickle, koi8r_validity)
(np.str_('\u6bd2'),
b"cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n(S'U1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI4\nI4\nI0\ntp6\nbS'\\xd2k\\x00\\x00'\np7\ntp8\nRp9\n.", # noqa
'invalid'),
(np.float64(9e123),
b"cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n(S'f8'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'<'\np5\nNNNI-1\nI-1\nI0\ntp6\nbS'O\\x81\\xb7Z\\xaa:\\xabY'\np7\ntp8\nRp9\n.", # noqa
'invalid'),
# KOI8-R 编码和 latin1 编码中不同的 8 位码点
(np.bytes_(b'\x9c'),
b"cnumpy.core.multiarray\nscalar\np0\n(cnumpy\ndtype\np1\n(S'S1'\np2\nI0\nI1\ntp3\nRp4\n(I3\nS'|'\np5\nNNNI1\nI1\nI0\ntp6\nbS'\\x9c'\np7\ntp8\nRp9\n.", # noqa
'different'),
]
# 对每组数据进行测试
for original, data, koi8r_validity in datas:
# 使用 Latin1 解码反序列化数据
result = pickle.loads(data, encoding='latin1')
assert_equal(result, original)
# 在非 Latin1 编码(例如 KOI8-R)下解码可能产生错误的结果,但不应导致段错误
if koi8r_validity == 'different':
# Unicode 码点在 Latin1 中,但在 KOI8-R 中不同,导致静默的错误结果
result = pickle.loads(data, encoding='koi8-r')
assert_(result != original)
elif koi8r_validity == 'invalid':
# Unicode 码点超出 Latin1 范围,因此会导致编码异常
assert_raises(
ValueError, pickle.loads, data, encoding='koi8-r'
)
else:
raise ValueError(koi8r_validity)
def test_structured_type_to_object(self):
# 创建一个结构化数组并将其转换为对象数组的测试
a_rec = np.array([(0, 1), (3, 2)], dtype='i4,i8')
a_obj = np.empty((2,), dtype=object)
a_obj[0] = (0, 1)
a_obj[1] = (3, 2)
# 使用 astype 将记录数组转换为对象数组
assert_equal(a_rec.astype(object), a_obj)
# 使用 '=' 将记录数组复制到对象数组
b = np.empty_like(a_obj)
b[...] = a_rec
assert_equal(b, a_obj)
# 使用 '=' 将对象数组转换回记录数组
b = np.empty_like(a_rec)
b[...] = a_obj
assert_equal(b, a_rec)
# 定义一个测试函数,用于测试对象列表的赋值行为
def test_assign_obj_listoflists(self):
# Ticket # 1870
# 内部列表应该被赋给对象的元素
# 创建一个长度为4的全零数组,数据类型为object
a = np.zeros(4, dtype=object)
# 复制数组a,得到数组b
b = a.copy()
# 将不同的列表赋给数组a的不同元素
a[0] = [1]
a[1] = [2]
a[2] = [3]
a[3] = [4]
# 通过广播赋值将列表 [[1], [2], [3], [4]] 赋给数组b
b[...] = [[1], [2], [3], [4]]
# 断言数组a和数组b相等
assert_equal(a, b)
# 第一维度应该被广播
# 创建一个2x2的全零数组,数据类型为object,通过广播赋值将列表 [1, 2] 赋给数组a
a = np.zeros((2, 2), dtype=object)
a[...] = [[1, 2]]
# 断言数组a和预期的结果相等
assert_equal(a, [[1, 2], [1, 2]])
@pytest.mark.slow_pypy
# 标记为慢速测试,适用于PyPy
def test_memoryleak(self):
# Ticket #1917 - ensure that array data doesn't leak
# 循环1000次,每次创建一个大小为100000000的字节类型的空数组a,防止内存泄漏
for i in range(1000):
a = np.empty((100000000,), dtype='i1')
del a
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
# 如果没有引用计数的支持,则跳过该测试
def test_ufunc_reduce_memoryleak(self):
# 创建一个长度为6的数组a
a = np.arange(6)
# 获取数组a的引用计数
acnt = sys.getrefcount(a)
# 对数组a执行reduce操作
np.add.reduce(a)
# 断言数组a的引用计数没有变化
assert_equal(sys.getrefcount(a), acnt)
def test_search_sorted_invalid_arguments(self):
# Ticket #2021, should not segfault.
# 创建一个日期类型的数组x,范围是从0到4(不包括4)
x = np.arange(0, 4, dtype='datetime64[D]')
# 断言搜索整数1会引发TypeError异常
assert_raises(TypeError, x.searchsorted, 1)
def test_string_truncation(self):
# Ticket #1990 - Data can be truncated in creation of an array from a
# mixed sequence of numeric values and strings (gh-2583)
# 遍历不同的值,包括布尔值、整数、浮点数和复数
for val in [True, 1234, 123.4, complex(1, 234)]:
# 遍历转换函数和对应的数据类型
for tostr, dtype in [(asunicode, "U"), (asbytes, "S")]:
# 创建一个包含val和字符串'xx'的数组b,数据类型为dtype
b = np.array([val, tostr('xx')], dtype=dtype)
# 断言数组b的第一个元素与val经过转换函数后相等
assert_equal(tostr(b[0]), tostr(val))
# 创建一个包含字符串'xx'和val的数组b,数据类型为dtype
b = np.array([tostr('xx'), val], dtype=dtype)
# 断言数组b的第二个元素与val经过转换函数后相等
assert_equal(tostr(b[1]), tostr(val))
# 测试较长的字符串情况
# 创建一个包含val和长字符串'xxxxxxxxxx'的数组b,数据类型为dtype
b = np.array([val, tostr('xxxxxxxxxx')], dtype=dtype)
# 断言数组b的第一个元素与val经过转换函数后相等
assert_equal(tostr(b[0]), tostr(val))
# 创建一个包含长字符串'xxxxxxxxxx'和val的数组b,数据类型为dtype
b = np.array([tostr('xxxxxxxxxx'), val], dtype=dtype)
# 断言数组b的第二个元素与val经过转换函数后相等
assert_equal(tostr(b[1]), tostr(val))
def test_string_truncation_ucs2(self):
# Ticket #2081. Python compiled with two byte unicode
# can lead to truncation if itemsize is not properly
# adjusted for NumPy's four byte unicode.
# 创建一个包含字符串'abcd'的数组a
a = np.array(['abcd'])
# 断言数组a的元素大小为16
assert_equal(a.dtype.itemsize, 16)
def test_unique_stable(self):
# Ticket #2063 must always choose stable sort for argsort to
# get consistent results
# 创建一个包含重复值的数组v
v = np.array(([0]*5 + [1]*6 + [2]*6)*4)
# 执行unique操作,返回唯一值和其对应的索引
res = np.unique(v, return_index=True)
# 预期的唯一值和对应的索引
tgt = (np.array([0, 1, 2]), np.array([ 0, 5, 11]))
# 断言结果与预期相等
assert_equal(res, tgt)
def test_unicode_alloc_dealloc_match(self):
# Ticket #1578, the mismatch only showed up when running
# python-debug for python versions >= 2.7, and then as
# a core dump and error message.
# 创建一个包含字符串'abc'的数组a,数据类型为np.str_,然后获取它的第一个元素
a = np.array(['abc'], dtype=np.str_)[0]
# 删除数组a的引用
del a
def test_refcount_error_in_clip(self):
# Ticket #1588
# 创建一个包含两个元素的零数组,数据类型为大端字节序的16位整数,然后调用clip方法将数组中的值限制在最小值为0的范围内。
a = np.zeros((2,), dtype='>i2').clip(min=0)
# 将数组a与自身相加,得到新的数组x。
x = a + a
# 将数组x转换为字符串。
y = str(x)
# 检查转换后的字符串是否为预期的结果 "[0 0]"
assert_(y == "[0 0]")
def test_searchsorted_wrong_dtype(self):
# Ticket #2189, it used to segfault, so we check that it raises the
# proper exception.
# 创建一个包含单个元素 ('a', 1) 的结构化数组,数据类型为一个字节和一个整数。
a = np.array([('a', 1)], dtype='S1, int')
# 断言调用np.searchsorted函数时会引发TypeError异常,因为数组a的数据类型不符合要求。
assert_raises(TypeError, np.searchsorted, a, 1.2)
# Ticket #2066, similar problem:
# 使用给定的数据类型创建一个包含两个记录的结构化数组。
dtype = np.rec.format_parser(['i4', 'i4'], [], [])
a = np.recarray((2,), dtype)
# 使用给定的元组值来填充数组a,这些值不符合其预期的数据类型。
a[...] = [(1, 2), (3, 4)]
# 断言调用np.searchsorted函数时会引发TypeError异常,因为数组a的数据类型不符合要求。
assert_raises(TypeError, np.searchsorted, a, 1)
def test_complex64_alignment(self):
# Issue gh-2668 (trac 2076), segfault on sparc due to misalignment
# 定义一个复数数据类型dtt为np.complex64。
dtt = np.complex64
# 创建一个包含10个元素的一维数组arr,数据类型为dtt。
arr = np.arange(10, dtype=dtt)
# 将一维数组arr重塑为二维数组arr2,形状为(2, 5)。
arr2 = np.reshape(arr, (2, 5))
# 将二维数组arr2按Fortran顺序(列主序)转换为字节序列。
data_str = arr2.tobytes('F')
# 使用给定的字节序列和参数创建一个新的数组data_back,按照Fortran顺序(列主序)读取数据。
data_back = np.ndarray(arr2.shape,
arr2.dtype,
buffer=data_str,
order='F')
# 断言数组arr2与数组data_back相等。
assert_array_equal(arr2, data_back)
def test_structured_count_nonzero(self):
# 创建一个结构化数组arr,包含一个整数和两个二维整数数组的元素,然后选择数组的第一个元素。
arr = np.array([0, 1]).astype('i4, 2i4')[:1]
# 计算数组arr中非零元素的数量。
count = np.count_nonzero(arr)
# 断言计算得到的非零元素数量为0。
assert_equal(count, 0)
def test_copymodule_preserves_f_contiguity(self):
# 创建一个包含两行两列的空数组a,按Fortran顺序(列主序)存储。
a = np.empty((2, 2), order='F')
# 使用浅拷贝创建数组b,使用深拷贝创建数组c。
b = copy.copy(a)
c = copy.deepcopy(a)
# 断言数组b和c都具有Fortran顺序(列主序)的标志。
assert_(b.flags.fortran)
assert_(b.flags.f_contiguous)
assert_(c.flags.fortran)
assert_(c.flags.f_contiguous)
def test_fortran_order_buffer(self):
import numpy as np
# 创建一个包含两个元素的二维数组a,数据类型为Unicode字符串,按Fortran顺序(列主序)存储。
a = np.array([['Hello', 'Foob']], dtype='U5', order='F')
# 使用数组a的数据缓冲区创建一个新的数组arr,数据类型为Unicode字符串,形状为[1, 2, 5]。
arr = np.ndarray(shape=[1, 2, 5], dtype='U1', buffer=a)
# 创建一个预期的二维数组arr2,其元素为Unicode字符。
arr2 = np.array([[['H', 'e', 'l', 'l', 'o'],
['F', 'o', 'o', 'b', '']]])
# 断言数组arr与预期的数组arr2相等。
assert_array_equal(arr, arr2)
def test_assign_from_sequence_error(self):
# Ticket #4024.
# 创建一个包含三个整数元素的数组arr。
arr = np.array([1, 2, 3])
# 断言调用arr.__setitem__方法时会引发ValueError异常,因为赋值序列的长度不匹配。
assert_raises(ValueError, arr.__setitem__, slice(None), [9, 9])
# 使用单个整数值来替换数组arr中的所有元素。
arr.__setitem__(slice(None), [9])
# 断言数组arr的所有元素都等于9。
assert_equal(arr, [9, 9, 9])
def test_format_on_flex_array_element(self):
# Ticket #4369.
# 定义一个结构化数据类型dt,包含一个日期和一个浮点数。
dt = np.dtype([('date', '<M8[D]'), ('val', '<f8')])
# 创建一个包含单个元素的结构化数组arr,使用给定的数据类型dt。
arr = np.array([('2000-01-01', 1)], dt)
# 使用字符串格式化操作将数组arr的第一个元素转换为字符串。
formatted = '{0}'.format(arr[0])
# 断言格式化后的字符串与arr的第一个元素的字符串表示相等。
assert_equal(formatted, str(arr[0]))
def test_deepcopy_on_0d_array(self):
# Ticket #3311.
# 创建一个零维数组arr,包含单个整数元素。
arr = np.array(3)
# 使用深拷贝创建数组arr_cp。
arr_cp = copy.deepcopy(arr)
# 断言数组arr和arr_cp在值、形状和类型上都相等。
assert_equal(arr, arr_cp)
assert_equal(arr.shape, arr_cp.shape)
assert_equal(int(arr), int(arr_cp))
# 断言数组arr和arr_cp不是同一个对象。
assert_(arr is not arr_cp)
# 断言arr_cp是arr的一个实例。
assert_(isinstance(arr_cp, type(arr)))
def test_deepcopy_F_order_object_array(self):
# Ticket #6456.
# 创建包含字典对象的 numpy 数组 arr,使用列优先(Fortran 风格)存储顺序
a = {'a': 1}
b = {'b': 2}
arr = np.array([[a, b], [a, b]], order='F')
# 对 arr 进行深拷贝
arr_cp = copy.deepcopy(arr)
# 断言 arr 和 arr_cp 相等
assert_equal(arr, arr_cp)
# 断言 arr 和 arr_cp 不是同一个对象
assert_(arr is not arr_cp)
# 确保我们实际拷贝了对象
assert_(arr[0, 1] is not arr_cp[1, 1])
# 确保允许引用同一个对象
assert_(arr[0, 1] is arr[1, 1])
# 检查拷贝后的对象引用是否正确
assert_(arr_cp[0, 1] is arr_cp[1, 1])
def test_deepcopy_empty_object_array(self):
# Ticket #8536.
# 深拷贝应该成功
a = np.array([], dtype=object)
b = copy.deepcopy(a)
assert_(a.shape == b.shape)
def test_bool_subscript_crash(self):
# gh-4494
# 创建一个记录数组 c
c = np.rec.array([(1, 2, 3), (4, 5, 6)])
# 使用布尔值数组进行子脚本,创建一个掩码数组 masked
masked = c[np.array([True, False])]
# 获取掩码数组的基础对象
base = masked.base
# 删除变量 masked 和 c
del masked, c
# 访问基础对象的数据类型
base.dtype
def test_richcompare_crash(self):
# gh-4613
import operator as op
# 创建一个虚拟类 Foo,其中 __array__ 方法抛出异常
class Foo:
__array_priority__ = 1002
def __array__(self, *args, **kwargs):
raise Exception()
rhs = Foo()
lhs = np.array(1)
# 测试各种比较操作符是否引发 TypeError 异常
for f in [op.lt, op.le, op.gt, op.ge]:
assert_raises(TypeError, f, lhs, rhs)
# 断言 lhs 和 rhs 不相等
assert_(not op.eq(lhs, rhs))
# 断言 lhs 和 rhs 不相等
assert_(op.ne(lhs, rhs))
def test_richcompare_scalar_and_subclass(self):
# gh-4709
# 创建一个继承自 np.ndarray 的子类 Foo
class Foo(np.ndarray):
def __eq__(self, other):
return "OK"
# 创建一个 ndarray 对象 x,并使用 Foo 类视图
x = np.array([1, 2, 3]).view(Foo)
# 断言对比 10 和 x 的结果为 "OK"
assert_equal(10 == x, "OK")
# 断言对比 np.int32(10) 和 x 的结果为 "OK"
assert_equal(np.int32(10) == x, "OK")
# 断言对比 np.array([10]) 和 x 的结果为 "OK"
assert_equal(np.array([10]) == x, "OK")
def test_pickle_empty_string(self):
# gh-3926
# 对空字符串进行序列化和反序列化的测试
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
test_string = np.bytes_('')
assert_equal(pickle.loads(
pickle.dumps(test_string, protocol=proto)), test_string)
def test_frompyfunc_many_args(self):
# gh-5672
# 定义一个接受任意参数的函数 passer
def passer(*args):
pass
# 测试使用 passer 函数创建 np.frompyfunc 时是否引发 ValueError 异常
assert_raises(ValueError, np.frompyfunc, passer, 64, 1)
def test_repeat_broadcasting(self):
# gh-5743
# 创建一个三维数组 a
a = np.arange(60).reshape(3, 4, 5)
# 对 a 进行 repeat 操作,测试不同轴上的广播行为
for axis in chain(range(-a.ndim, a.ndim), [None]):
assert_equal(a.repeat(2, axis=axis), a.repeat([2], axis=axis))
def test_frompyfunc_nout_0(self):
# gh-2014
# 定义一个修改输入数组的函数 f
def f(x):
x[0], x[-1] = x[-1], x[0]
# 创建一个包含对象的二维数组 a
a = np.array([[1, 2, 3], [4, 5], [6, 7, 8, 9]], dtype=object)
# 使用 np.frompyfunc 创建一个 ufunc uf,它的输出数量为 0
assert_equal(np.frompyfunc(f, 1, 0)(a), ())
# 预期的结果数组
expected = np.array([[3, 2, 1], [5, 4], [9, 7, 8, 6]], dtype=object)
# 断言数组 a 的内容与预期的结果相同
assert_array_equal(a, expected)
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_leak_in_structured_dtype_comparison(self):
# 根据 GitHub issue 6250 进行测试,检查结构化数据类型比较中的内存泄漏问题
recordtype = np.dtype([('a', np.float64),
('b', np.int32),
('d', (str, 5))])
# 简单情况下的测试
a = np.zeros(2, dtype=recordtype)
for i in range(100):
a == a
assert_(sys.getrefcount(a) < 10)
# 在报告的 bug 案例中进行测试
before = sys.getrefcount(a)
u, v = a[0], a[1]
u == v
del u, v
gc.collect()
after = sys.getrefcount(a)
assert_equal(before, after)
def test_empty_percentile(self):
# 根据 GitHub issue 6530 和 6553 进行测试,验证空百分位数计算
assert_array_equal(np.percentile(np.arange(10), []), np.array([]))
def test_void_compare_segfault(self):
# 根据 GitHub issue 6922 进行测试,确保空类型比较不导致段错误
a = np.ones(3, dtype=[('object', 'O'), ('int', '<i2')])
a.sort()
def test_reshape_size_overflow(self):
# 根据 GitHub issue 7455 进行测试,检查重塑操作中的大小溢出问题
a = np.ones(20)[::2]
if np.dtype(np.intp).itemsize == 8:
# 64 位情况下,以下是 2**63 + 5 的质因数,乘在一起作为 int64 时会溢出到总大小为 10
new_shape = (2, 13, 419, 691, 823, 2977518503)
else:
# 32 位情况下,以下是 2**31 + 5 的质因数,乘在一起作为 int32 时会溢出到总大小为 10
new_shape = (2, 7, 7, 43826197)
assert_raises(ValueError, a.reshape, new_shape)
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
reason="PyPy bug in error formatting")
# 测试无效的结构化数据类型的情况
def test_invalid_structured_dtypes(self):
# gh-2865: 对应GitHub issue编号
# 将Python对象映射到其他数据类型
assert_raises(ValueError, np.dtype, ('O', [('name', 'i8')]))
assert_raises(ValueError, np.dtype, ('i8', [('name', 'O')]))
assert_raises(ValueError, np.dtype,
('i8', [('name', [('name', 'O')])]))
assert_raises(ValueError, np.dtype, ([('a', 'i4'), ('b', 'i4')], 'O'))
assert_raises(ValueError, np.dtype, ('i8', 'O'))
# 字典中元组元素数量或类型错误
assert_raises(ValueError, np.dtype,
('i', {'name': ('i', 0, 'title', 'oops')}))
assert_raises(ValueError, np.dtype,
('i', {'name': ('i', 'wrongtype', 'title')}))
# 从1.13版本开始不再允许的情况
assert_raises(ValueError, np.dtype,
([('a', 'O'), ('b', 'O')], [('c', 'O'), ('d', 'O')]))
# 作为特殊情况允许存在,参见gh-2798
a = np.ones(1, dtype=('O', [('name', 'O')]))
assert_equal(a[0], 1)
# 特别是,上述联合数据类型(以及总体上的联合数据类型)应该主要表现得像主要的(object)数据类型:
assert a[0] is a.item()
assert type(a[0]) is int
# 测试正确的哈希字典
def test_correct_hash_dict(self):
# gh-8887 - 即使设置了tp_hash,__hash__仍然为None的问题
all_types = set(np._core.sctypeDict.values()) - {np.void}
for t in all_types:
val = t()
try:
hash(val)
except TypeError as e:
assert_equal(t.__hash__, None)
else:
assert_(t.__hash__ != None)
# 测试标量复制
def test_scalar_copy(self):
scalar_types = set(np._core.sctypeDict.values())
values = {
np.void: b"a",
np.bytes_: b"a",
np.str_: "a",
np.datetime64: "2017-08-25",
}
for sctype in scalar_types:
item = sctype(values.get(sctype, 1))
item2 = copy.copy(item)
assert_equal(item, item2)
# 测试void类型的元素内存视图
def test_void_item_memview(self):
va = np.zeros(10, 'V4')
x = va[:1].item()
va[0] = b'\xff\xff\xff\xff'
del va
assert_equal(x, b'\x00\x00\x00\x00')
# 测试void类型的getitem方法
def test_void_getitem(self):
# 测试修复gh-11668的问题
assert_(np.array([b'a'], 'V1').astype('O') == b'a')
assert_(np.array([b'ab'], 'V2').astype('O') == b'ab')
assert_(np.array([b'abc'], 'V3').astype('O') == b'abc')
assert_(np.array([b'abcd'], 'V4').astype('O') == b'abcd')
def test_structarray_title(self):
# 测试结构化数组标题功能
# 在 PyPy 上曾经存在的段错误问题,由于 NPY_TITLE_KEY 功能不正常,
# 导致结构化数组字段项的双重减少引用:
# 参考链接:https://bitbucket.org/pypy/pypy/issues/2789
for j in range(5):
# 创建一个结构化数组,包含一个字段 'x',类型为对象数组
structure = np.array([1], dtype=[(('x', 'X'), np.object_)])
# 在结构化数组的第一个元素中设置 'x' 字段的值为一个数组 [2]
structure[0]['x'] = np.array([2])
# 手动触发垃圾回收
gc.collect()
def test_dtype_scalar_squeeze(self):
# 测试数据类型标量的 squeeze 方法
# gh-11384
values = {
'S': b"a",
'M': "2018-06-20",
}
# 遍历所有的数据类型字符
for ch in np.typecodes['All']:
# 跳过 'O' 类型,即 Python 对象类型
if ch in 'O':
continue
# 获取数据类型对象
sctype = np.dtype(ch).type
# 获取相应类型的值,若不存在则默认为 3
scvalue = sctype(values.get(ch, 3))
# 对于每个轴(None 或 ()),测试 squeeze 方法
for axis in [None, ()]:
squeezed = scvalue.squeeze(axis=axis)
# 断言 squeeze 后的结果与原始值相等
assert_equal(squeezed, scvalue)
# 断言 squeeze 后的对象类型与原始值的类型相同
assert_equal(type(squeezed), type(scvalue))
def test_field_access_by_title(self):
# 测试通过标题访问字段
# gh-11507
s = 'Some long field name'
# 如果支持引用计数,获取字符串 s 的基准引用计数
if HAS_REFCOUNT:
base = sys.getrefcount(s)
# 创建一个结构化数据类型,包含一个名为 s 的字段 'f1',类型为 np.float64
t = np.dtype([((s, 'f1'), np.float64)])
# 创建一个包含 10 个元素的全零数组,数据类型为 t
data = np.zeros(10, t)
# 遍历数组中的每个元素
for i in range(10):
# 使用标题 'f1' 访问数据,仅保留引用计数不变
str(data[['f1']])
# 如果支持引用计数,断言字符串 s 的引用计数不增加
if HAS_REFCOUNT:
assert_(base <= sys.getrefcount(s))
@pytest.mark.parametrize('val', [
# 数组和标量
np.ones((10, 10), dtype='int32'),
np.uint64(10),
])
@pytest.mark.parametrize('protocol',
range(2, pickle.HIGHEST_PROTOCOL + 1)
)
def test_pickle_module(self, protocol, val):
# 测试 pickle 模块
# gh-12837
# 序列化值 val 使用指定的协议 protocol
s = pickle.dumps(val, protocol)
# 断言序列化结果中不包含 '_multiarray_umath'
assert b'_multiarray_umath' not in s
# 若协议为 5 且 val 的维度大于 0,则断言序列化结果中包含 'numpy._core.numeric'
if protocol == 5 and len(val.shape) > 0:
assert b'numpy._core.numeric' in s
else:
# 否则断言序列化结果中包含 'numpy._core.multiarray'
assert b'numpy._core.multiarray' in s
def test_object_casting_errors(self):
# 测试对象类型转换错误
# gh-11993 update to ValueError (see gh-16909), since strings can in
# principle be converted to complex, but this string cannot.
# 创建一个包含字符串、浮点数和整数的对象数组
arr = np.array(['AAAAA', 18465886.0, 18465886.0], dtype=object)
# 断言将该数组转换为 'c8' 类型会引发 ValueError 异常
assert_raises(ValueError, arr.astype, 'c8')
def test_eff1d_casting(self):
# 测试 ediff1d 函数的类型转换
# gh-12711
# 创建一个整数类型的数组 x
x = np.array([1, 2, 4, 7, 0], dtype=np.int16)
# 使用 ediff1d 函数对数组 x 进行差分计算,并在开头和结尾添加特定值
res = np.ediff1d(x, to_begin=-99, to_end=np.array([88, 99]))
# 断言计算结果 res 符合预期值
assert_equal(res, [-99, 1, 2, 3, -7, 88, 99])
# 使用安全类型转换时,将 1<<20 不安全地转换,可能更好的做法是引发错误,
# 但目前没有相应的机制处理这种情况。
res = np.ediff1d(x, to_begin=(1<<20), to_end=(1<<20))
# 断言计算结果 res 符合预期值
assert_equal(res, [0, 1, 2, 3, -7, 0])
def test_pickle_datetime64_array(self):
# 测试对 datetime64 数组进行 pickle 序列化,确保在不同协议下都能正确反序列化
# gh-12745 (如果安装了 pickle5,则会失败)
d = np.datetime64('2015-07-04 12:59:59.50', 'ns')
arr = np.array([d])
for proto in range(2, pickle.HIGHEST_PROTOCOL + 1):
dumped = pickle.dumps(arr, protocol=proto)
# 使用 assert_equal 检查反序列化后的结果与原始数组是否相等
assert_equal(pickle.loads(dumped), arr)
def test_bad_array_interface(self):
# 测试当类没有正确定义 __array_interface__ 属性时,是否能捕获 ValueError 异常
class T:
__array_interface__ = {}
with assert_raises(ValueError):
np.array([T()])
def test_2d__array__shape(self):
# 测试自定义类的 __array__ 方法返回空的 2D 数组,并验证 numpy 在创建数组时是否正确调用该方法
class T:
def __array__(self, dtype=None, copy=None):
return np.ndarray(shape=(0,0))
# 确保在数组创建时使用 __array__ 方法而不是 Sequence 方法
def __iter__(self):
return iter([])
def __getitem__(self, idx):
raise AssertionError("__getitem__ was called")
def __len__(self):
return 0
t = T()
# gh-13659, 确保不会在广播操作中出现异常
arr = np.array([t])
assert arr.shape == (1, 0, 0)
@pytest.mark.skipif(sys.maxsize < 2 ** 31 + 1, reason='overflows 32-bit python')
def test_to_ctypes(self):
# 测试将 numpy 数组转换为 ctypes 对象,确保长度信息正确
# gh-14214
arr = np.zeros((2 ** 31 + 1,), 'b')
assert arr.size * arr.itemsize > 2 ** 31
c_arr = np.ctypeslib.as_ctypes(arr)
assert_equal(c_arr._length_, arr.size)
def test_complex_conversion_error(self):
# 测试复数类型转换时的异常处理
# gh-17068
with pytest.raises(TypeError, match=r"Unable to convert dtype.*"):
complex(np.array("now", np.datetime64))
def test__array_interface__descr(self):
# 测试自定义 dtype 的数组的 __array_interface__ 描述符
# gh-17068
dt = np.dtype(dict(names=['a', 'b'],
offsets=[0, 0],
formats=[np.int64, np.int64]))
descr = np.array((1, 1), dtype=dt).__array_interface__['descr']
# 使用 assert 检查描述符的格式
assert descr == [('', '|V8')] # instead of [(b'', '|V8')]
@pytest.mark.skipif(sys.maxsize < 2 ** 31 + 1, reason='overflows 32-bit python')
@requires_memory(free_bytes=9e9)
def test_dot_big_stride(self):
# 测试在大步长情况下的 dot 运算
# gh-17111
# blas stride = stride//itemsize > int32 max
int32_max = np.iinfo(np.int32).max
n = int32_max + 3
a = np.empty([n], dtype=np.float32)
b = a[::n-1]
b[...] = 1
# 使用 assert 检查步长是否符合预期,并验证 dot 运算的结果
assert b.strides[0] > int32_max * b.dtype.itemsize
assert np.dot(b, b) == 2.0
def test_frompyfunc_name(self):
# 测试 frompyfunc 函数是否正确处理函数名称的转换,特别是在 Python 3 字符串上的处理
# 以及使用非 ASCII 名称进行 utf-8 编码的情况
def cassé(x):
return x
f = np.frompyfunc(cassé, 1, 1)
assert str(f) == "<ufunc 'cassé (vectorized)'>"
@pytest.mark.parametrize("operation", [
'add', 'subtract', 'multiply', 'floor_divide',
'conjugate', 'fmod', 'square', 'reciprocal',
'power', 'absolute', 'negative', 'positive',
'greater', 'greater_equal', 'less',
'less_equal', 'equal', 'not_equal', 'logical_and',
'logical_not', 'logical_or', 'bitwise_and', 'bitwise_or',
'bitwise_xor', 'invert', 'left_shift', 'right_shift',
'gcd', 'lcm'
]
)
@pytest.mark.parametrize("order", [
('b->', 'B->'),
('h->', 'H->'),
('i->', 'I->'),
('l->', 'L->'),
('q->', 'Q->'),
]
)
# 定义测试函数,用于测试各种 numpy 的 universal function (ufunc) 的执行顺序
def test_ufunc_order(self, operation, order):
# gh-18075
# 确保有符号类型在无符号类型之前
def get_idx(string, str_lst):
for i, s in enumerate(str_lst):
if string in s:
return i
raise ValueError(f"{string} not in list")
# 获取指定 ufunc 的类型列表
types = getattr(np, operation).types
# 断言检查:有符号类型在无符号类型之前
assert get_idx(order[0], types) < get_idx(order[1], types), (
f"Unexpected types order of ufunc in {operation}"
f"for {order}. Possible fix: Use signed before unsigned"
"in generate_umath.py")
def test_nonbool_logical(self):
# gh-22845
# 创建两个数组,它们的位模式不重叠。
# 数组大小需要足够大,以测试 SIMD 和标量路径
size = 100
# 使用字节缓冲区创建 np.bool 类型数组 a 和 b
a = np.frombuffer(b'\x01' * size, dtype=np.bool)
b = np.frombuffer(b'\x80' * size, dtype=np.bool)
# 创建预期结果数组,全为 True
expected = np.ones(size, dtype=np.bool)
# 断言检查:逻辑与操作的结果是否与预期相同
assert_array_equal(np.logical_and(a, b), expected)
@pytest.mark.skipif(IS_PYPY, reason="PyPy issue 2742")
def test_gh_23737(self):
# gh-23737
# 使用 pytest.raises 确保 TypeError 被正确地抛出
with pytest.raises(TypeError, match="not an acceptable base type"):
# 尝试创建继承自 np.flexible 的类 Y,预期抛出异常
class Y(np.flexible):
pass
with pytest.raises(TypeError, match="not an acceptable base type"):
# 尝试创建继承自 np.flexible 和 np.ma.core.MaskedArray 的类 X,预期抛出异常
class X(np.flexible, np.ma.core.MaskedArray):
pass
def test_load_ufunc_pickle(self):
# ufunc 被使用半私有路径 numpy.core._multiarray_umath 进行 pickle,必须能够无警告地加载
test_data = b'\x80\x04\x95(\x00\x00\x00\x00\x00\x00\x00\x8c\x1cnumpy.core._multiarray_umath\x94\x8c\x03add\x94\x93\x94.' # noqa
# 使用 pickle.loads 加载测试数据
result = pickle.loads(test_data, encoding='bytes')
# 断言检查:加载的结果应该是 np.add
assert result is np.add
def test__array_namespace__(self):
# 创建一个长度为2的 NumPy 数组
arr = np.arange(2)
# 调用数组的 __array_namespace__ 方法,返回 numpy 的命名空间对象 np
xp = arr.__array_namespace__()
# 断言返回值为 numpy 对象 np
assert xp is np
# 使用指定的 API 版本调用 __array_namespace__ 方法,仍返回 numpy 的命名空间对象 np
xp = arr.__array_namespace__(api_version="2021.12")
assert xp is np
# 使用指定的 API 版本调用 __array_namespace__ 方法,仍返回 numpy 的命名空间对象 np
xp = arr.__array_namespace__(api_version="2022.12")
assert xp is np
# 使用 None 作为 API 版本调用 __array_namespace__ 方法,仍返回 numpy 的命名空间对象 np
xp = arr.__array_namespace__(api_version=None)
assert xp is np
# 使用不支持的 API 版本调用 __array_namespace__ 方法,预期抛出 ValueError 异常
with pytest.raises(
ValueError,
match="Version \"2023.12\" of the Array API Standard "
"is not supported."
):
arr.__array_namespace__(api_version="2023.12")
# 使用非字符串类型作为 API 版本调用 __array_namespace__ 方法,预期抛出 ValueError 异常
with pytest.raises(
ValueError,
match="Only None and strings are allowed as the Array API version"
):
arr.__array_namespace__(api_version=2023)
def test_isin_refcnt_bug(self):
# gh-25295
# 循环执行 1000 次,调用 np.isclose 比较两个 np.int64 类型的数值
for _ in range(1000):
np.isclose(np.int64(2), np.int64(2), atol=1e-15, rtol=1e-300)
def test_replace_regression(self):
# gh-25513 segfault
# 创建一个 chararray,并用指定的测试字符串初始化
carr = np.char.chararray((2,), itemsize=25)
test_strings = [b' 4.52173913043478315E+00',
b' 4.95652173913043548E+00']
carr[:] = test_strings
# 调用 replace 方法替换字符串中的 b"E" 为 b"D"
out = carr.replace(b"E", b"D")
# 创建一个期望结果的 chararray,并用预期的替换结果初始化
expected = np.char.chararray((2,), itemsize=25)
expected[:] = [s.replace(b"E", b"D") for s in test_strings]
# 断言替换后的结果与预期结果相等
assert_array_equal(out, expected)
def test_logspace_base_does_not_determine_dtype(self):
# gh-24957 and cupy/cupy/issues/7946
# 创建起始点和终止点的 float16 数组
start = np.array([0, 2], dtype=np.float16)
stop = np.array([2, 0], dtype=np.float16)
# 使用 logspace 函数生成对数空间的数组,指定了 dtype 为 float32
out = np.logspace(start, stop, num=5, axis=1, dtype=np.float32)
# 创建预期的 float32 类型的数组
expected = np.array([[1., 3.1621094, 10., 31.625, 100.],
[100., 31.625, 10., 3.1621094, 1.]],
dtype=np.float32)
# 断言生成的数组与预期结果几乎相等
assert_almost_equal(out, expected)
# 检查如果计算使用 float64,则测试失败,因为之前一个错误的 python float base 影响了 dtype
out2 = np.logspace(start, stop, num=5, axis=1, dtype=np.float32,
base=np.array([10.0]))
with pytest.raises(AssertionError, match="not almost equal"):
assert_almost_equal(out2, expected)
def test_vectorize_fixed_width_string(self):
# 创建一个包含固定宽度字符串的数组,将其类型转换为 np.str_
arr = np.array(["SOme wOrd DŽ ß ᾛ ΣΣ ffi⁵Å Ç Ⅰ"]).astype(np.str_)
# 定义 str.casefold 函数
f = str.casefold
# 对数组中的每个元素应用 casefold 函数,返回结果数组
res = np.vectorize(f, otypes=[arr.dtype])(arr)
# 断言结果数组的 dtype 等于 "U30"
assert res.dtype == "U30"