NumPy Source Code Analysis (14)
.\numpy\numpy\lib\tests\test_histograms.py
import numpy as np
from numpy import histogram, histogramdd, histogram_bin_edges
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_almost_equal,
assert_array_almost_equal, assert_raises, assert_allclose,
assert_array_max_ulp, assert_raises_regex, suppress_warnings,
)
from numpy.testing._private.utils import requires_memory
import pytest
class TestHistogram:
def setup_method(self):
pass
def teardown_method(self):
pass
def test_simple(self):
n = 100
v = np.random.rand(n)
(a, b) = histogram(v)
assert_equal(np.sum(a, axis=0), n)
(a, b) = histogram(np.linspace(0, 10, 100))
assert_array_equal(a, 10)
def test_one_bin(self):
hist, edges = histogram([1, 2, 3, 4], [1, 2])
assert_array_equal(hist, [2, ])
assert_array_equal(edges, [1, 2])
assert_raises(ValueError, histogram, [1, 2], bins=0)
h, e = histogram([1, 2], bins=1)
assert_equal(h, np.array([2]))
assert_allclose(e, np.array([1., 2.]))
def test_density(self):
n = 100
v = np.random.rand(n)
a, b = histogram(v, density=True)
area = np.sum(a * np.diff(b))
assert_almost_equal(area, 1)
v = np.arange(10)
bins = [0, 1, 3, 6, 10]
a, b = histogram(v, bins, density=True)
assert_array_equal(a, .1)
assert_equal(np.sum(a * np.diff(b)), 1)
a, b = histogram(v, bins, density=False)
assert_array_equal(a, [1, 2, 3, 4])
v = np.arange(10)
bins = [0, 1, 3, 6, np.inf]
a, b = histogram(v, bins, density=True)
assert_array_equal(a, [.1, .1, .1, 0.])
counts, dmy = np.histogram(
[1, 2, 3, 4], [0.5, 1.5, np.inf], density=True)
assert_equal(counts, [.25, 0])
def test_outliers(self):
a = np.arange(10) + .5
h, b = histogram(a, range=[0, 9])
assert_equal(h.sum(), 9)
h, b = histogram(a, range=[1, 10])
assert_equal(h.sum(), 9)
h, b = histogram(a, range=[1, 9], density=True)
assert_almost_equal((h * np.diff(b)).sum(), 1, decimal=15)
w = np.arange(10) + .5
h, b = histogram(a, range=[1, 9], weights=w, density=True)
assert_equal((h * np.diff(b)).sum(), 1)
h, b = histogram(a, bins=8, range=[1, 9], weights=w)
assert_equal(h, w[1:-1])
def test_arr_weights_mismatch(self):
a = np.arange(10) + .5
w = np.arange(11) + .5
with assert_raises_regex(ValueError, "same shape as"):
h, b = histogram(a, range=[1, 9], weights=w, density=True)
def test_type(self):
a = np.arange(10) + .5
h, b = histogram(a)
assert_(np.issubdtype(h.dtype, np.integer))
h, b = histogram(a, density=True)
assert_(np.issubdtype(h.dtype, np.floating))
h, b = histogram(a, weights=np.ones(10, int))
assert_(np.issubdtype(h.dtype, np.integer))
h, b = histogram(a, weights=np.ones(10, float))
assert_(np.issubdtype(h.dtype, np.floating))
def test_f32_rounding(self):
x = np.array([276.318359, -69.593948, 21.329449], dtype=np.float32)
y = np.array([5005.689453, 4481.327637, 6010.369629], dtype=np.float32)
counts_hist, xedges, yedges = np.histogram2d(x, y, bins=100)
assert_equal(counts_hist.sum(), 3.)
def test_bool_conversion(self):
a = np.array([1, 1, 0], dtype=np.uint8)
int_hist, int_edges = np.histogram(a)
with suppress_warnings() as sup:
rec = sup.record(RuntimeWarning, 'Converting input from .*')
hist, edges = np.histogram([True, True, False])
assert_equal(len(rec), 1)
assert_array_equal(hist, int_hist)
assert_array_equal(edges, int_edges)
def test_weights(self):
v = np.random.rand(100)
w = np.ones(100) * 5
a, b = histogram(v)
na, nb = histogram(v, density=True)
wa, wb = histogram(v, weights=w)
nwa, nwb = histogram(v, weights=w, density=True)
assert_array_almost_equal(a * 5, wa)
assert_array_almost_equal(na, nwa)
v = np.linspace(0, 10, 10)
w = np.concatenate((np.zeros(5), np.ones(5)))
wa, wb = histogram(v, bins=np.arange(11), weights=w)
assert_array_almost_equal(wa, w)
wa, wb = histogram([1, 2, 2, 4], bins=4, weights=[4, 3, 2, 1])
assert_array_equal(wa, [4, 5, 0, 1])
wa, wb = histogram([1, 2, 2, 4], bins=4, weights=[4, 3, 2, 1], density=True)
assert_array_almost_equal(wa, np.array([4, 5, 0, 1]) / 10. / 3. * 4)
a, b = histogram(np.arange(9), [0, 1, 3, 6, 10],
weights=[2, 1, 1, 1, 1, 1, 1, 1, 1], density=True)
assert_almost_equal(a, [.2, .1, .1, .075])
def test_exotic_weights(self):
values = np.array([1.3, 2.5, 2.3])
weights = np.array([1, -1, 2]) + 1j * np.array([2, 1, 2])
wa, wb = histogram(values, bins=[0, 2, 3], weights=weights)
assert_array_almost_equal(wa, np.array([1, 1]) + 1j * np.array([2, 3]))
wa, wb = histogram(values, bins=2, range=[1, 3], weights=weights)
assert_array_almost_equal(wa, np.array([1, 1]) + 1j * np.array([2, 3]))
from decimal import Decimal
values = np.array([1.3, 2.5, 2.3])
weights = np.array([Decimal(1), Decimal(2), Decimal(3)])
wa, wb = histogram(values, bins=[0, 2, 3], weights=weights)
assert_array_almost_equal(wa, [Decimal(1), Decimal(5)])
wa, wb = histogram(values, bins=2, range=[1, 3], weights=weights)
assert_array_almost_equal(wa, [Decimal(1), Decimal(5)])
def test_no_side_effects(self):
values = np.array([1.3, 2.5, 2.3])
np.histogram(values, range=[-10, 10], bins=100)
assert_array_almost_equal(values, [1.3, 2.5, 2.3])
def test_empty(self):
a, b = histogram([], bins=([0, 1]))
assert_array_equal(a, np.array([0]))
assert_array_equal(b, np.array([0, 1]))
def test_error_binnum_type(self):
vals = np.linspace(0.0, 1.0, num=100)
histogram(vals, 5)
assert_raises(TypeError, histogram, vals, 2.4)
def test_finite_range(self):
vals = np.linspace(0.0, 1.0, num=100)
histogram(vals, range=[0.25,0.75])
assert_raises(ValueError, histogram, vals, range=[np.nan,0.75])
assert_raises(ValueError, histogram, vals, range=[0.25,np.inf])
def test_invalid_range(self):
vals = np.linspace(0.0, 1.0, num=100)
with assert_raises_regex(ValueError, "max must be larger than"):
np.histogram(vals, range=[0.1, 0.01])
def test_bin_edge_cases(self):
arr = np.array([337, 404, 739, 806, 1007, 1811, 2012])
hist, edges = np.histogram(arr, bins=8296, range=(2, 2280))
mask = hist > 0
left_edges = edges[:-1][mask]
right_edges = edges[1:][mask]
for x, left, right in zip(arr, left_edges, right_edges):
assert_(x >= left)
assert_(x < right)
def test_last_bin_inclusive_range(self):
arr = np.array([0., 0., 0., 1., 2., 3., 3., 4., 5.])
hist, edges = np.histogram(arr, bins=30, range=(-0.5, 5))
assert_equal(hist[-1], 1)
def test_bin_array_dims(self):
vals = np.linspace(0.0, 1.0, num=100)
bins = np.array([[0, 0.5], [0.6, 1.0]])
with assert_raises_regex(ValueError, "must be 1d"):
np.histogram(vals, bins=bins)
def test_unsigned_monotonicity_check(self):
arr = np.array([2])
bins = np.array([1, 3, 1], dtype='uint64')
with assert_raises(ValueError):
hist, edges = np.histogram(arr, bins=bins)
def test_object_array_of_0d(self):
assert_raises(ValueError,
histogram, [np.array(0.4) for i in range(10)] + [-np.inf])
assert_raises(ValueError,
histogram, [np.array(0.4) for i in range(10)] + [np.inf])
np.histogram([np.array(0.5) for i in range(10)] + [.500000000000001])
np.histogram([np.array(0.5) for i in range(10)] + [.5])
def test_some_nan_values(self):
one_nan = np.array([0, 1, np.nan])
all_nan = np.array([np.nan, np.nan])
sup = suppress_warnings()
sup.filter(RuntimeWarning)
with sup:
assert_raises(ValueError, histogram, one_nan, bins='auto')
assert_raises(ValueError, histogram, all_nan, bins='auto')
h, b = histogram(one_nan, bins='auto', range=(0, 1))
assert_equal(h.sum(), 2)
h, b = histogram(all_nan, bins='auto', range=(0, 1))
assert_equal(h.sum(), 0)
h, b = histogram(one_nan, bins=[0, 1])
assert_equal(h.sum(), 2)
h, b = histogram(all_nan, bins=[0, 1])
assert_equal(h.sum(), 0)
def test_datetime(self):
begin = np.datetime64('2000-01-01', 'D')
offsets = np.array([0, 0, 1, 1, 2, 3, 5, 10, 20])
bins = np.array([0, 2, 7, 20])
dates = begin + offsets
date_bins = begin + bins
td = np.dtype('timedelta64[D]')
d_count, d_edge = histogram(dates, bins=date_bins)
t_count, t_edge = histogram(offsets.astype(td), bins=bins.astype(td))
i_count, i_edge = histogram(offsets, bins=bins)
assert_equal(d_count, i_count)
assert_equal(t_count, i_count)
assert_equal((d_edge - begin).astype(int), i_edge)
assert_equal(t_edge.astype(int), i_edge)
assert_equal(d_edge.dtype, dates.dtype)
assert_equal(t_edge.dtype, td)
def do_signed_overflow_bounds(self, dtype):
exponent = 8 * np.dtype(dtype).itemsize - 1
arr = np.array([-2**exponent + 4, 2**exponent - 4], dtype=dtype)
hist, e = histogram(arr, bins=2)
assert_equal(e, [-2**exponent + 4, 0, 2**exponent - 4])
assert_equal(hist, [1, 1])
def test_signed_overflow_bounds(self):
self.do_signed_overflow_bounds(np.byte)
self.do_signed_overflow_bounds(np.short)
self.do_signed_overflow_bounds(np.intc)
self.do_signed_overflow_bounds(np.int_)
self.do_signed_overflow_bounds(np.longlong)
def do_precision_lower_bound(self, float_small, float_large):
eps = np.finfo(float_large).eps
arr = np.array([1.0], float_small)
range = np.array([1.0 + eps, 2.0], float_large)
if range.astype(float_small)[0] != 1:
return
count, x_loc = np.histogram(arr, bins=1, range=range)
assert_equal(count, [0])
def do_precision_upper_bound(self, float_small, float_large):
eps = np.finfo(float_large).eps
arr = np.array([1.0], float_small)
range = np.array([0.0, 1.0 - eps], float_large)
if range.astype(float_small)[-1] != 1:
return
count, x_loc = np.histogram(arr, bins=1, range=range)
assert_equal(count, [0])
def do_precision(self, float_small, float_large):
self.do_precision_lower_bound(float_small, float_large)
self.do_precision_upper_bound(float_small, float_large)
def test_precision(self):
self.do_precision(np.half, np.single)
self.do_precision(np.half, np.double)
self.do_precision(np.half, np.longdouble)
self.do_precision(np.single, np.double)
self.do_precision(np.single, np.longdouble)
self.do_precision(np.double, np.longdouble)
def test_histogram_bin_edges(self):
hist, e = histogram([1, 2, 3, 4], [1, 2])
edges = histogram_bin_edges([1, 2, 3, 4], [1, 2])
assert_array_equal(edges, e)
arr = np.array([0., 0., 0., 1., 2., 3., 3., 4., 5.])
hist, e = histogram(arr, bins=30, range=(-0.5, 5))
edges = histogram_bin_edges(arr, bins=30, range=(-0.5, 5))
assert_array_equal(edges, e)
hist, e = histogram(arr, bins='auto', range=(0, 1))
edges = histogram_bin_edges(arr, bins='auto', range=(0, 1))
assert_array_equal(edges, e)
@pytest.mark.skip(reason="Bad memory reports lead to OOM in ci testing")
def test_big_arrays(self):
sample = np.zeros([100000000, 3])
xbins = 400
ybins = 400
zbins = np.arange(16000)
hist = np.histogramdd(sample=sample, bins=(xbins, ybins, zbins))
def test_gh_23110(self):
hist, e = np.histogram(np.array([-0.9e-308], dtype='>f8'),
bins=2,
range=(-1e-308, -2e-313))
expected_hist = np.array([1, 0])
assert_array_equal(hist, expected_hist)
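The assertions in `TestHistogram` revolve around `np.histogram`'s two output conventions: raw counts versus a normalized density. Below is a minimal standalone sketch (not part of the test file above) of the density semantics checked in `test_density`: with `density=True` the histogram integrates to 1 over the bin widths.

# Illustrative sketch, not from numpy's test suite.
import numpy as np

v = np.arange(10)                      # ten samples
bins = [0, 1, 3, 6, 10]                # four bins of unequal width

counts, edges = np.histogram(v, bins)              # raw counts per bin
density, _ = np.histogram(v, bins, density=True)   # counts / (n * bin width)

print(counts)                             # [1 2 3 4]
print(density)                            # [0.1 0.1 0.1 0.1]
print(np.sum(density * np.diff(edges)))   # 1.0 -- the histogram integrates to 1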
class TestHistogramOptimBinNums:
"""
Provide test coverage for the string estimators used to select an optimal
number of bins.
"""
def test_empty(self):
estimator_list = ['fd', 'scott', 'rice', 'sturges',
'doane', 'sqrt', 'auto', 'stone']
for estimator in estimator_list:
a, b = histogram([], bins=estimator)
assert_array_equal(a, np.array([0]))
assert_array_equal(b, np.array([0, 1]))
def test_simple(self):
"""
Straightforward testing with a mixture of linspace data (for
consistency). All test values have been precomputed and the values
shouldn't change
"""
basic_test = {50: {'fd': 4, 'scott': 4, 'rice': 8, 'sturges': 7,
'doane': 8, 'sqrt': 8, 'auto': 7, 'stone': 2},
500: {'fd': 8, 'scott': 8, 'rice': 16, 'sturges': 10,
'doane': 12, 'sqrt': 23, 'auto': 10, 'stone': 9},
5000: {'fd': 17, 'scott': 17, 'rice': 35, 'sturges': 14,
'doane': 17, 'sqrt': 71, 'auto': 17, 'stone': 20}}
for testlen, expectedResults in basic_test.items():
x1 = np.linspace(-10, -1, testlen // 5 * 2)
x2 = np.linspace(1, 10, testlen // 5 * 3)
x = np.concatenate((x1, x2))
for estimator, numbins in expectedResults.items():
a, b = np.histogram(x, estimator)
assert_equal(len(a), numbins, err_msg="For the {0} estimator "
"with datasize of {1}".format(estimator, testlen))
def test_small(self):
"""
Smaller datasets have the potential to cause issues with the data
adaptive methods, especially the FD method. All bin numbers have been
precalculated.
"""
small_dat = {1: {'fd': 1, 'scott': 1, 'rice': 1, 'sturges': 1,
'doane': 1, 'sqrt': 1, 'stone': 1},
2: {'fd': 2, 'scott': 1, 'rice': 3, 'sturges': 2,
'doane': 1, 'sqrt': 2, 'stone': 1},
3: {'fd': 2, 'scott': 2, 'rice': 3, 'sturges': 3,
'doane': 3, 'sqrt': 2, 'stone': 1}}
for testlen, expectedResults in small_dat.items():
testdat = np.arange(testlen).astype(float)
for estimator, expbins in expectedResults.items():
a, b = np.histogram(testdat, estimator)
assert_equal(len(a), expbins, err_msg="For the {0} estimator "
"with datasize of {1}".format(estimator, testlen))
def test_incorrect_methods(self):
"""
Check that a ValueError is raised when an unknown string is passed in.
"""
check_list = ['mad', 'freeman', 'histograms', 'IQR']
for estimator in check_list:
assert_raises(ValueError, histogram, [1, 2, 3], estimator)
def test_novariance(self):
"""
Check that methods handle no variance in data
Primarily for Scott and FD as the SD and IQR are both 0 in this case
"""
novar_dataset = np.ones(100)
novar_resultdict = {'fd': 1, 'scott': 1, 'rice': 1, 'sturges': 1,
'doane': 1, 'sqrt': 1, 'auto': 1, 'stone': 1}
for estimator, numbins in novar_resultdict.items():
a, b = np.histogram(novar_dataset, estimator)
assert_equal(len(a), numbins, err_msg="{0} estimator, "
"No Variance test".format(estimator))
def test_limited_variance(self):
"""
Check that when the IQR is 0 but variance exists, we return the sturges
value and not the fd value.
"""
lim_var_data = np.ones(1000)
lim_var_data[:3] = 0
lim_var_data[-4:] = 100
edges_auto = histogram_bin_edges(lim_var_data, 'auto')
assert_equal(edges_auto, np.linspace(0, 100, 12))
edges_fd = histogram_bin_edges(lim_var_data, 'fd')
assert_equal(edges_fd, np.array([0, 100]))
edges_sturges = histogram_bin_edges(lim_var_data, 'sturges')
assert_equal(edges_sturges, np.linspace(0, 100, 12))
def test_outlier(self):
"""
Check the FD, Scott and Doane estimators with outliers.
The FD estimates a smaller binwidth since it's less affected by
outliers. Since the range is so (artificially) large, this means more
bins, most of which will be empty, but the data of interest usually is
unaffected. The Scott estimator is more affected and returns fewer bins,
despite most of the variance being in one area of the data. The Doane
estimator lies somewhere between the other two.
"""
xcenter = np.linspace(-10, 10, 50)
outlier_dataset = np.hstack((np.linspace(-110, -100, 5), xcenter))
outlier_resultdict = {'fd': 21, 'scott': 5, 'doane': 11, 'stone': 6}
for estimator, numbins in outlier_resultdict.items():
a, b = np.histogram(outlier_dataset, estimator)
assert_equal(len(a), numbins)
def test_scott_vs_stone(self):
"""Verify that Scott's rule and Stone's rule converges for normally distributed data"""
def nbins_ratio(seed, size):
rng = np.random.RandomState(seed)
x = rng.normal(loc=0, scale=2, size=size)
a, b = len(np.histogram(x, 'stone')[0]), len(np.histogram(x, 'scott')[0])
return a / (a + b)
ll = [[nbins_ratio(seed, size) for size in np.geomspace(start=10, stop=100, num=4).round().astype(int)]
for seed in range(10)]
avg = abs(np.mean(ll, axis=0) - 0.5)
assert_almost_equal(avg, [0.15, 0.09, 0.08, 0.03], decimal=2)
def test_simple_range(self):
"""
Straightforward testing with a mixture of linspace data (for
consistency). Adding in a 3rd mixture that will then be
completely ignored. All test values have been precomputed and
they shouldn't change.
"""
basic_test = {
50: {'fd': 8, 'scott': 8, 'rice': 15,
'sturges': 14, 'auto': 14, 'stone': 8},
500: {'fd': 15, 'scott': 16, 'rice': 32,
'sturges': 20, 'auto': 20, 'stone': 80},
5000: {'fd': 33, 'scott': 33, 'rice': 69,
'sturges': 27, 'auto': 33, 'stone': 80}
}
for testlen, expectedResults in basic_test.items():
x1 = np.linspace(-10, -1, testlen // 5 * 2)
x2 = np.linspace(1, 10, testlen // 5 * 3)
x3 = np.linspace(-100, -50, testlen)
x = np.hstack((x1, x2, x3))
for estimator, numbins in expectedResults.items():
a, b = np.histogram(x, estimator, range=(-20, 20))
msg = "For the {0} estimator".format(estimator)
msg += " with datasize of {0}".format(testlen)
assert_equal(len(a), numbins, err_msg=msg)
@pytest.mark.parametrize("bins", ['auto', 'fd', 'doane', 'scott',
'stone', 'rice', 'sturges'])
def test_signed_integer_data(self, bins):
a = np.array([-2, 0, 127], dtype=np.int8)
hist, edges = np.histogram(a, bins=bins)
hist32, edges32 = np.histogram(a.astype(np.int32), bins=bins)
assert_array_equal(hist, hist32)
assert_array_equal(edges, edges32)
@pytest.mark.parametrize("bins", ['auto', 'fd', 'doane', 'scott',
                                  'stone', 'rice', 'sturges'])
def test_integer(self, bins):
"""
Test that bin width for integer data is at least 1.
"""
with suppress_warnings() as sup:
if bins == 'stone':
sup.filter(RuntimeWarning)
assert_equal(
np.histogram_bin_edges(np.tile(np.arange(9), 1000), bins),
np.arange(9))
def test_integer_non_auto(self):
"""
Test that the bin-width>=1 requirement *only* applies to auto binning.
"""
assert_equal(
np.histogram_bin_edges(np.tile(np.arange(9), 1000), 16),
np.arange(17) / 2)
assert_equal(
np.histogram_bin_edges(np.tile(np.arange(9), 1000), [.1, .2]),
[.1, .2])
def test_simple_weighted(self):
"""
Check that weighted data raises a TypeError
"""
estimator_list = ['fd', 'scott', 'rice', 'sturges', 'auto']
for estimator in estimator_list:
assert_raises(TypeError, histogram, [1, 2, 3],
estimator, weights=[1, 2, 3])
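These estimator tests all funnel through `np.histogram_bin_edges`, which accepts the estimator name as the `bins` argument. A minimal sketch (not from the test file; the printed bin counts depend on the random data, so only the mechanism is shown):

# Illustrative sketch, not from numpy's test suite.
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=1000)

# Each string estimator derives a bin count from the data itself;
# 'auto' combines 'fd' and 'sturges', keeping whichever yields more bins.
for estimator in ('fd', 'scott', 'rice', 'sturges', 'sqrt', 'auto'):
    edges = np.histogram_bin_edges(data, bins=estimator)
    print(estimator, len(edges) - 1)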
class TestHistogramdd:
def test_simple(self):
x = np.array([[-.5, .5, 1.5], [-.5, 1.5, 2.5], [-.5, 2.5, .5],
[.5, .5, 1.5], [.5, 1.5, 2.5], [.5, 2.5, 2.5]])
H, edges = histogramdd(x, (2, 3, 3),
range=[[-1, 1], [0, 3], [0, 3]])
answer = np.array([[[0, 1, 0], [0, 0, 1], [1, 0, 0]],
[[0, 1, 0], [0, 0, 1], [0, 0, 1]]])
assert_array_equal(H, answer)
ed = [[-2, 0, 2], [0, 1, 2, 3], [0, 1, 2, 3]]
H, edges = histogramdd(x, bins=ed, density=True)
assert_(np.all(H == answer / 12.))
H, edges = histogramdd(x, (2, 3, 4),
range=[[-1, 1], [0, 3], [0, 4]],
density=True)
answer = np.array([[[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]],
[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 0]]])
assert_array_almost_equal(H, answer / 6., 4)
z = [np.squeeze(y) for y in np.split(x, 3, axis=1)]
H, edges = histogramdd(
z, bins=(4, 3, 2), range=[[-2, 2], [0, 3], [0, 2]])
answer = np.array([[[0, 0], [0, 0], [0, 0]],
[[0, 1], [0, 0], [1, 0]],
[[0, 1], [0, 0], [0, 0]],
[[0, 0], [0, 0], [0, 0]]])
assert_array_equal(H, answer)
Z = np.zeros((5, 5, 5))
Z[list(range(5)), list(range(5)), list(range(5))] = 1.
H, edges = histogramdd([np.arange(5), np.arange(5), np.arange(5)], 5)
assert_array_equal(H, Z)
def test_shape_3d(self):
bins = ((5, 4, 6), (6, 4, 5), (5, 6, 4), (4, 6, 5), (6, 5, 4),
(4, 5, 6))
r = np.random.rand(10, 3)
for b in bins:
H, edges = histogramdd(r, b)
assert_(H.shape == b)
def test_shape_4d(self):
bins = ((7, 4, 5, 6), (4, 5, 7, 6), (5, 6, 4, 7), (7, 6, 5, 4),
(5, 7, 6, 4), (4, 6, 7, 5), (6, 5, 7, 4), (7, 5, 4, 6),
(7, 4, 6, 5), (6, 4, 7, 5), (6, 7, 5, 4), (4, 6, 5, 7),
(4, 7, 5, 6), (5, 4, 6, 7), (5, 7, 4, 6), (6, 7, 4, 5),
(6, 5, 4, 7), (4, 7, 6, 5), (4, 5, 6, 7), (7, 6, 4, 5),
(5, 4, 7, 6), (5, 6, 7, 4), (6, 4, 5, 7), (7, 5, 6, 4))
r = np.random.rand(10, 4)
for b in bins:
H, edges = histogramdd(r, b)
assert_(H.shape == b)
def test_weights(self):
v = np.random.rand(100, 2)
hist, edges = histogramdd(v)
n_hist, edges = histogramdd(v, density=True)
w_hist, edges = histogramdd(v, weights=np.ones(100))
assert_array_equal(w_hist, hist)
w_hist, edges = histogramdd(v, weights=np.ones(100) * 2, density=True)
assert_array_equal(w_hist, n_hist)
w_hist, edges = histogramdd(v, weights=np.ones(100, int) * 2)
assert_array_equal(w_hist, 2 * hist)
def test_identical_samples(self):
x = np.zeros((10, 2), int)
hist, edges = histogramdd(x, bins=2)
assert_array_equal(edges[0], np.array([-0.5, 0., 0.5]))
def test_empty(self):
a, b = histogramdd([[], []], bins=([0, 1], [0, 1]))
assert_array_max_ulp(a, np.array([[0.]]))
a, b = np.histogramdd([[], [], []], bins=2)
assert_array_max_ulp(a, np.zeros((2, 2, 2)))
def test_bins_errors(self):
x = np.arange(8).reshape(2, 4)
assert_raises(ValueError, np.histogramdd, x, bins=[-1, 2, 4, 5])
assert_raises(ValueError, np.histogramdd, x, bins=[1, 0.99, 1, 1])
assert_raises(
ValueError, np.histogramdd, x, bins=[1, 1, 1, [1, 2, 3, -3]])
assert_(np.histogramdd(x, bins=[1, 1, 1, [1, 2, 3, 4]]))
def test_inf_edges(self):
with np.errstate(invalid='ignore'):
x = np.arange(6).reshape(3, 2)
expected = np.array([[1, 0], [0, 1], [0, 1]])
h, e = np.histogramdd(x, bins=[3, [-np.inf, 2, 10]])
assert_allclose(h, expected)
h, e = np.histogramdd(x, bins=[3, np.array([-1, 2, np.inf])])
assert_allclose(h, expected)
h, e = np.histogramdd(x, bins=[3, [-np.inf, 3, np.inf]])
assert_allclose(h, expected)
def test_rightmost_binedge(self):
x = [0.9999999995]
bins = [[0., 0.5, 1.0]]
hist, _ = histogramdd(x, bins=bins)
assert_(hist[0] == 0.0)
assert_(hist[1] == 1.)
x = [1.0]
bins = [[0., 0.5, 1.0]]
hist, _ = histogramdd(x, bins=bins)
assert_(hist[0] == 0.0)
assert_(hist[1] == 1.)
x = [1.0000000001]
bins = [[0., 0.5, 1.0]]
hist, _ = histogramdd(x, bins=bins)
assert_(hist[0] == 0.0)
assert_(hist[1] == 0.0)
x = [1.0001]
bins = [[0., 0.5, 1.0]]
hist, _ = histogramdd(x, bins=bins)
assert_(hist[0] == 0.0)
assert_(hist[1] == 0.0)
def test_finite_range(self):
vals = np.random.random((100, 3))
histogramdd(vals, range=[[0.0, 1.0], [0.25, 0.75], [0.25, 0.5]])
assert_raises(ValueError, histogramdd, vals,
range=[[0.0, 1.0], [0.25, 0.75], [0.25, np.inf]])
assert_raises(ValueError, histogramdd, vals,
range=[[0.0, 1.0], [np.nan, 0.75], [0.25, 0.5]])
def test_equal_edges(self):
x = np.array([0, 1, 2])
y = np.array([0, 1, 2])
x_edges = np.array([0, 2, 2])
y_edges = 1
hist, edges = histogramdd((x, y), bins=(x_edges, y_edges))
hist_expected = np.array([
[2.],
[1.],
])
assert_equal(hist, hist_expected)
def test_edge_dtype(self):
x = np.array([0, 10, 20])
y = x / 10
x_edges = np.array([0, 5, 15, 20])
y_edges = x_edges / 10
hist, edges = histogramdd((x, y), bins=(x_edges, y_edges))
assert_equal(edges[0].dtype, x_edges.dtype)
assert_equal(edges[1].dtype, y_edges.dtype)
def test_large_integers(self):
big = 2**60
x = np.array([0], np.int64)
x_edges = np.array([-1, +1], np.int64)
y = big + x
y_edges = big + x_edges
hist, edges = histogramdd((x, y), bins=(x_edges, y_edges))
assert_equal(hist[0, 0], 1)
def test_density_non_uniform_2d(self):
x_edges = np.array([0, 2, 8])
y_edges = np.array([0, 6, 8])
relative_areas = np.array([
[3, 9],
[1, 3]])
x = np.array([1] + [1]*3 + [7]*3 + [7]*9)
y = np.array([7] + [1]*3 + [7]*3 + [1]*9)
hist, edges = histogramdd((y, x), bins=(y_edges, x_edges))
assert_equal(hist, relative_areas)
hist, edges = histogramdd((y, x), bins=(y_edges, x_edges), density=True)
assert_equal(hist, 1 / (8*8))
def test_density_non_uniform_1d(self):
v = np.arange(10)
bins = np.array([0, 1, 3, 6, 10])
hist, edges = histogram(v, bins, density=True)
hist_dd, edges_dd = histogramdd((v,), (bins,), density=True)
assert_equal(hist, hist_dd)
assert_equal(edges, edges_dd[0])
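`np.histogramdd`, exercised by `TestHistogramdd` above, generalizes the same machinery to N dimensions: it takes an (N_samples, D) array and returns a D-dimensional count array plus one edge array per axis. A small standalone sketch (not part of the test file):

# Illustrative sketch, not from numpy's test suite.
import numpy as np

pts = np.array([[0.1, 0.1],
                [0.1, 0.9],
                [0.9, 0.9]])

H, (xedges, yedges) = np.histogramdd(pts, bins=(2, 2), range=[[0, 1], [0, 1]])
print(H)          # [[1. 1.]
                  #  [0. 1.]]
print(xedges)     # [0.  0.5 1. ]
print(yedges)     # [0.  0.5 1. ]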
.\numpy\numpy\lib\tests\test_index_tricks.py
import pytest
import numpy as np
from numpy.testing import (
assert_, assert_equal, assert_array_equal, assert_almost_equal,
assert_array_almost_equal, assert_raises, assert_raises_regex,
)
from numpy.lib._index_tricks_impl import (
mgrid, ogrid, ndenumerate, fill_diagonal, diag_indices, diag_indices_from,
index_exp, ndindex, c_, r_, s_, ix_
)
class TestRavelUnravelIndex:
def test_basic(self):
assert_equal(np.unravel_index(2, (2, 2)), (1, 0))
assert_equal(np.unravel_index(indices=2,
shape=(2, 2)),
(1, 0))
with assert_raises(TypeError):
np.unravel_index(indices=2, hape=(2, 2))
with assert_raises(TypeError):
np.unravel_index(2, hape=(2, 2))
with assert_raises(TypeError):
np.unravel_index(254, ims=(17, 94))
with assert_raises(TypeError):
np.unravel_index(254, dims=(17, 94))
assert_equal(np.ravel_multi_index((1, 0), (2, 2)), 2)
assert_equal(np.unravel_index(254, (17, 94)), (2, 66))
assert_equal(np.ravel_multi_index((2, 66), (17, 94)), 254)
assert_raises(ValueError, np.unravel_index, -1, (2, 2))
assert_raises(TypeError, np.unravel_index, 0.5, (2, 2))
assert_raises(ValueError, np.unravel_index, 4, (2, 2))
assert_raises(ValueError, np.ravel_multi_index, (-3, 1), (2, 2))
assert_raises(ValueError, np.ravel_multi_index, (2, 1), (2, 2))
assert_raises(ValueError, np.ravel_multi_index, (0, -3), (2, 2))
assert_raises(ValueError, np.ravel_multi_index, (0, 2), (2, 2))
assert_raises(TypeError, np.ravel_multi_index, (0.1, 0.), (2, 2))
assert_equal(np.unravel_index((2*3 + 1)*6 + 4, (4, 3, 6)), [2, 1, 4])
assert_equal(
np.ravel_multi_index([2, 1, 4], (4, 3, 6)), (2*3 + 1)*6 + 4)
arr = np.array([[3, 6, 6], [4, 5, 1]])
assert_equal(np.ravel_multi_index(arr, (7, 6)), [22, 41, 37])
assert_equal(
np.ravel_multi_index(arr, (7, 6), order='F'), [31, 41, 13])
assert_equal(
np.ravel_multi_index(arr, (4, 6), mode='clip'), [22, 23, 19])
assert_equal(np.ravel_multi_index(arr, (4, 4), mode=('clip', 'wrap')),
[12, 13, 13])
assert_equal(np.ravel_multi_index((3, 1, 4, 1), (6, 7, 8, 9)), 1621)
assert_equal(np.unravel_index(np.array([22, 41, 37]), (7, 6)),
[[3, 6, 6], [4, 5, 1]])
assert_equal(
np.unravel_index(np.array([31, 41, 13]), (7, 6), order='F'),
[[3, 6, 6], [4, 5, 1]])
assert_equal(np.unravel_index(1621, (6, 7, 8, 9)), [3, 1, 4, 1])
def test_empty_indices(self):
msg1 = 'indices must be integral: the provided empty sequence was'
msg2 = 'only int indices permitted'
assert_raises_regex(TypeError, msg1, np.unravel_index, [], (10, 3, 5))
assert_raises_regex(TypeError, msg1, np.unravel_index, (), (10, 3, 5))
assert_raises_regex(TypeError, msg2, np.unravel_index, np.array([]),
(10, 3, 5))
assert_equal(np.unravel_index(np.array([],dtype=int), (10, 3, 5)),
[[], [], []])
assert_raises_regex(TypeError, msg1, np.ravel_multi_index, ([], []),
(10, 3))
assert_raises_regex(TypeError, msg1, np.ravel_multi_index, ([], ['abc']),
(10, 3))
assert_raises_regex(TypeError, msg2, np.ravel_multi_index,
(np.array([]), np.array([])), (5, 3))
assert_equal(np.ravel_multi_index(
(np.array([], dtype=int), np.array([], dtype=int)), (5, 3)), [])
assert_equal(np.ravel_multi_index(np.array([[], []], dtype=int),
(5, 3)), [])
def test_big_indices(self):
if np.intp == np.int64:
arr = ([1, 29], [3, 5], [3, 117], [19, 2],
[2379, 1284], [2, 2], [0, 1])
assert_equal(
np.ravel_multi_index(arr, (41, 7, 120, 36, 2706, 8, 6)),
[5627771580, 117259570957])
assert_raises(ValueError, np.unravel_index, 1, (2**32-1, 2**31+1))
dummy_arr = ([0],[0])
half_max = np.iinfo(np.intp).max // 2
assert_equal(
np.ravel_multi_index(dummy_arr, (half_max, 2)), [0])
assert_raises(ValueError,
np.ravel_multi_index, dummy_arr, (half_max+1, 2))
assert_equal(
np.ravel_multi_index(dummy_arr, (half_max, 2), order='F'), [0])
assert_raises(ValueError,
np.ravel_multi_index, dummy_arr, (half_max+1, 2), order='F')
def test_dtypes(self):
for dtype in [np.int16, np.uint16, np.int32,
np.uint32, np.int64, np.uint64]:
coords = np.array(
[[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0]], dtype=dtype)
shape = (5, 8)
uncoords = 8*coords[0]+coords[1]
assert_equal(np.ravel_multi_index(coords, shape), uncoords)
assert_equal(coords, np.unravel_index(uncoords, shape))
uncoords = coords[0]+5*coords[1]
assert_equal(
np.ravel_multi_index(coords, shape, order='F'), uncoords)
assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
coords = np.array(
[[1, 0, 1, 2, 3, 4], [1, 6, 1, 3, 2, 0], [1, 3, 1, 0, 9, 5]],
dtype=dtype)
shape = (5, 8, 10)
uncoords = 10*(8*coords[0]+coords[1])+coords[2]
assert_equal(np.ravel_multi_index(coords, shape), uncoords)
assert_equal(coords, np.unravel_index(uncoords, shape))
uncoords = coords[0]+5*(coords[1]+8*coords[2])
assert_equal(
np.ravel_multi_index(coords, shape, order='F'), uncoords)
assert_equal(coords, np.unravel_index(uncoords, shape, order='F'))
def test_clipmodes(self):
assert_equal(
np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12), mode='wrap'),
np.ravel_multi_index([1, 1, 6, 2], (4, 3, 7, 12)))
assert_equal(np.ravel_multi_index([5, 1, -1, 2], (4, 3, 7, 12),
mode=(
'wrap', 'raise', 'clip', 'raise')),
np.ravel_multi_index([1, 1, 0, 2], (4, 3, 7, 12)))
assert_raises(
ValueError, np.ravel_multi_index, [5, 1, -1, 2], (4, 3, 7, 12))
def test_writeability(self):
x, y = np.unravel_index([1, 2, 3], (4, 5))
assert_(x.flags.writeable)
assert_(y.flags.writeable)
def test_0d(self):
x = np.unravel_index(0, ())
assert_equal(x, ())
assert_raises_regex(ValueError, "0d array", np.unravel_index, [0], ())
assert_raises_regex(
ValueError, "out of bounds", np.unravel_index, [1], ())
@pytest.mark.parametrize("mode", ["clip", "wrap", "raise"])
def test_empty_array_ravel(self, mode):
res = np.ravel_multi_index(
np.zeros((3, 0), dtype=np.intp), (2, 1, 0), mode=mode)
assert(res.shape == (0,))
with assert_raises(ValueError):
np.ravel_multi_index(
np.zeros((3, 1), dtype=np.intp), (2, 1, 0), mode=mode)
def test_empty_array_unravel(self):
res = np.unravel_index(np.zeros(0, dtype=np.intp), (2, 1, 0))
assert(len(res) == 3)
assert(all(a.shape == (0,) for a in res))
with assert_raises(ValueError):
np.unravel_index([1], (2, 1, 0))
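`ravel_multi_index` and `unravel_index`, covered by the class above, are inverses: one maps multi-dimensional coordinates to a flat C-order (or F-order) index, the other maps back. A quick sketch (not from the test file) mirroring values used above:

# Illustrative sketch, not from numpy's test suite.
import numpy as np

shape = (17, 94)
print(np.unravel_index(254, shape))           # (2, 66), since 254 == 2*94 + 66
print(np.ravel_multi_index((2, 66), shape))   # 254

# Arrays of indices are handled element-wise.
rows, cols = np.unravel_index([22, 41, 37], (7, 6))
print(rows, cols)                                             # [3 6 6] [4 5 1]
print(np.ravel_multi_index([[3, 6, 6], [4, 5, 1]], (7, 6)))   # [22 41 37]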
class TestGrid:
def test_basic(self):
a = mgrid[-1:1:10j]
b = mgrid[-1:1:0.1]
assert_(a.shape == (10,))
assert_(b.shape == (20,))
assert_(a[0] == -1)
assert_almost_equal(a[-1], 1)
assert_(b[0] == -1)
assert_almost_equal(b[1]-b[0], 0.1, 11)
assert_almost_equal(b[-1], b[0]+19*0.1, 11)
assert_almost_equal(a[1]-a[0], 2.0/9.0, 11)
def test_linspace_equivalence(self):
y, st = np.linspace(2, 10, retstep=True)
assert_almost_equal(st, 8/49.0)
assert_array_almost_equal(y, mgrid[2:10:50j], 13)
def test_nd(self):
c = mgrid[-1:1:10j, -2:2:10j]
d = mgrid[-1:1:0.1, -2:2:0.2]
assert_(c.shape == (2, 10, 10))
assert_(d.shape == (2, 20, 20))
assert_array_equal(c[0][0, :], -np.ones(10, 'd'))
assert_array_equal(c[1][:, 0], -2*np.ones(10, 'd'))
assert_array_almost_equal(c[0][-1, :], np.ones(10, 'd'), 11)
assert_array_almost_equal(c[1][:, -1], 2*np.ones(10, 'd'), 11)
assert_array_almost_equal(d[0, 1, :] - d[0, 0, :], 0.1*np.ones(20, 'd'), 11)
assert_array_almost_equal(d[1, :, 1] - d[1, :, 0], 0.2*np.ones(20, 'd'), 11)
def test_sparse(self):
grid_full = mgrid[-1:1:10j, -2:2:10j]
grid_sparse = ogrid[-1:1:10j, -2:2:10j]
grid_broadcast = np.broadcast_arrays(*grid_sparse)
for f, b in zip(grid_full, grid_broadcast):
assert_equal(f, b)
@pytest.mark.parametrize("start, stop, step, expected", [
(None, 10, 10j, (200, 10)),
(-10, 20, None, (1800, 30)),
])
def test_mgrid_size_none_handling(self, start, stop, step, expected):
grid = mgrid[start:stop:step, start:stop:step]
grid_small = mgrid[start:stop:step]
assert_equal(grid.size, expected[0])
assert_equal(grid_small.size, expected[1])
def test_accepts_npfloating(self):
grid64 = mgrid[0.1:0.33:0.1, ]
grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1), ]
assert_array_almost_equal(grid64, grid32)
assert grid32.dtype == np.float32
grid64 = mgrid[0.1:0.33:0.1]
grid32 = mgrid[np.float32(0.1):np.float32(0.33):np.float32(0.1)]
assert_(grid32.dtype == np.float64)
assert_array_almost_equal(grid64, grid32)
def test_accepts_longdouble(self):
grid64 = mgrid[0.1:0.33:0.1, ]
grid128 = mgrid[
np.longdouble(0.1):np.longdouble(0.33):np.longdouble(0.1),
]
assert_(grid128.dtype == np.longdouble)
assert_array_almost_equal(grid64, grid128)
grid128c_a = mgrid[0:np.longdouble(1):3.4j]
grid128c_b = mgrid[0:np.longdouble(1):3.4j, ]
assert_(grid128c_a.dtype == grid128c_b.dtype == np.longdouble)
assert_array_equal(grid128c_a, grid128c_b[0])
grid64 = mgrid[0.1:0.33:0.1]
grid128 = mgrid[
np.longdouble(0.1):np.longdouble(0.33):np.longdouble(0.1)
]
assert_(grid128.dtype == np.longdouble)
assert_array_almost_equal(grid64, grid128)
def test_accepts_npcomplexfloating(self):
assert_array_almost_equal(
mgrid[0.1:0.3:3j, ], mgrid[0.1:0.3:np.complex64(3j), ]
)
assert_array_almost_equal(
mgrid[0.1:0.3:3j], mgrid[0.1:0.3:np.complex64(3j)]
)
grid64_a = mgrid[0.1:0.3:3.3j]
grid64_b = mgrid[0.1:0.3:3.3j, ][0]
assert_(grid64_a.dtype == grid64_b.dtype == np.float64)
assert_array_equal(grid64_a, grid64_b)
grid128_a = mgrid[0.1:0.3:np.clongdouble(3.3j)]
grid128_b = mgrid[0.1:0.3:np.clongdouble(3.3j), ][0]
assert_(grid128_a.dtype == grid128_b.dtype == np.longdouble)
assert_array_equal(grid64_a, grid128_b)
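The `TestGrid` cases hinge on the slice-step convention: a complex step means "number of points, endpoint included", while a real step behaves like `arange` and excludes the endpoint; `ogrid` returns open (broadcastable) grids where `mgrid` returns dense ones. A small sketch (not from the test file):

# Illustrative sketch, not from numpy's test suite.
import numpy as np

print(np.mgrid[-1:1:5j])           # [-1.  -0.5  0.   0.5  1. ]

dense = np.mgrid[0:2, 0:3]
open_x, open_y = np.ogrid[0:2, 0:3]
print(dense.shape)                 # (2, 2, 3)
print(open_x.shape, open_y.shape)  # (2, 1) (1, 3)
print(np.array_equal(dense[0], np.broadcast_to(open_x, (2, 3))))  # True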
class TestConcatenator:
def test_1d(self):
assert_array_equal(r_[1, 2, 3, 4, 5, 6], np.array([1, 2, 3, 4, 5, 6]))
b = np.ones(5)
c = r_[b, 0, 0, b]
assert_array_equal(c, [1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1])
def test_mixed_type(self):
g = r_[10.1, 1:10]
assert_(g.dtype == 'f8')
def test_more_mixed_type(self):
g = r_[-10.1, np.array([1]), np.array([2, 3, 4]), 10.0]
assert_(g.dtype == 'f8')
def test_complex_step(self):
g = r_[0:36:100j]
assert_(g.shape == (100,))
g = r_[0:36:np.complex64(100j)]
assert_(g.shape == (100,))
def test_2d(self):
b = np.random.rand(5, 5)
c = np.random.rand(5, 5)
d = r_['1', b, c]
assert_(d.shape == (5, 10))
assert_array_equal(d[:, :5], b)
assert_array_equal(d[:, 5:], c)
d = r_[b, c]
assert_(d.shape == (10, 5))
assert_array_equal(d[:5, :], b)
assert_array_equal(d[5:, :], c)
def test_0d(self):
assert_equal(r_[0, np.array(1), 2], [0, 1, 2])
assert_equal(r_[[0, 1, 2], np.array(3)], [0, 1, 2, 3])
assert_equal(r_[np.array(0), [1, 2, 3]], [0, 1, 2, 3])
class TestNdenumerate:
def test_basic(self):
a = np.array([[1, 2], [3, 4]])
assert_equal(list(ndenumerate(a)),
[((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)])
class TestIndexExpression:
def test_regression_1(self):
a = np.arange(2)
assert_equal(a[:-1], a[s_[:-1]])
assert_equal(a[:-1], a[index_exp[:-1]])
def test_simple_1(self):
a = np.random.rand(4, 5, 6)
assert_equal(a[:, :3, [1, 2]], a[index_exp[:, :3, [1, 2]]])
assert_equal(a[:, :3, [1, 2]], a[s_[:, :3, [1, 2]]])
class TestIx_:
def test_regression_1(self):
a, = np.ix_(range(0))
assert_equal(a.dtype, np.intp)
a, = np.ix_([])
assert_equal(a.dtype, np.intp)
a, = np.ix_(np.array([], dtype=np.float32))
assert_equal(a.dtype, np.float32)
def test_shape_and_dtype(self):
sizes = (4, 5, 3, 2)
for func in (range, np.arange):
arrays = np.ix_(*[func(sz) for sz in sizes])
for k, (a, sz) in enumerate(zip(arrays, sizes)):
assert_equal(a.shape[k], sz)
assert_(all(sh == 1 for j, sh in enumerate(a.shape) if j != k))
assert_(np.issubdtype(a.dtype, np.integer))
def test_bool(self):
bool_a = [True, False, True, True]
int_a, = np.nonzero(bool_a)
assert_equal(np.ix_(bool_a)[0], int_a)
def test_1d_only(self):
idx2d = [[1, 2, 3], [4, 5, 6]]
assert_raises(ValueError, np.ix_, idx2d)
def test_repeated_input(self):
length_of_vector = 5
x = np.arange(length_of_vector)
out = ix_(x, x)
assert_equal(out[0].shape, (length_of_vector, 1))
assert_equal(out[1].shape, (1, length_of_vector))
assert_equal(x.shape, (length_of_vector,))
def test_c_():
a = c_[np.array([[1, 2, 3]]), 0, 0, np.array([[4, 5, 6]])]
assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]])
class TestFillDiagonal:
def test_basic(self):
a = np.zeros((3, 3), int)
fill_diagonal(a, 5)
assert_array_equal(
a, np.array([[5, 0, 0],
[0, 5, 0],
[0, 0, 5]])
)
def test_tall_matrix(self):
a = np.zeros((10, 3), int)
fill_diagonal(a, 5)
assert_array_equal(
a, np.array([[5, 0, 0],
[0, 5, 0],
[0, 0, 5],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0],
[0, 0, 0]])
)
def test_tall_matrix_wrap(self):
a = np.zeros((10, 3), int)
fill_diagonal(a, 5, True)
assert_array_equal(
a, np.array([[5, 0, 0],
[0, 5, 0],
[0, 0, 5],
[0, 0, 0],
[5, 0, 0],
[0, 5, 0],
[0, 0, 5],
[0, 0, 0],
[5, 0, 0],
[0, 5, 0]])
)
def test_wide_matrix(self):
a = np.zeros((3, 10), int)
fill_diagonal(a, 5)
assert_array_equal(
a, np.array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 5, 0, 0, 0, 0, 0, 0, 0]])
)
def test_operate_4d_array(self):
a = np.zeros((3, 3, 3, 3), int)
fill_diagonal(a, 4)
i = np.array([0, 1, 2])
assert_equal(np.where(a != 0), (i, i, i, i))
def test_low_dim_handling(self):
a = np.zeros(3, int)
with assert_raises_regex(ValueError, "at least 2-d"):
fill_diagonal(a, 5)
def test_hetero_shape_handling(self):
a = np.zeros((3,3,7,3), int)
with assert_raises_regex(ValueError, "equal length"):
fill_diagonal(a, 2)
def test_diag_indices():
di = diag_indices(4)
a = np.array([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]])
a[di] = 100
assert_array_equal(
a, np.array([[100, 2, 3, 4],
[5, 100, 7, 8],
[9, 10, 100, 12],
[13, 14, 15, 100]])
)
d3 = diag_indices(2, 3)
a = np.zeros((2, 2, 2), int)
a[d3] = 1
assert_array_equal(
a, np.array([
[[1, 0],
[0, 0]],
[[0, 0],
[0, 1]]
])
)
class TestDiagIndicesFrom:
def test_diag_indices_from(self):
x = np.random.random((4, 4))
r, c = diag_indices_from(x)
assert_array_equal(r, np.arange(4))
assert_array_equal(c, np.arange(4))
def test_error_small_input(self):
x = np.ones(7)
with assert_raises_regex(ValueError, "at least 2-d"):
diag_indices_from(x)
def test_error_shape_mismatch(self):
x = np.zeros((3, 3, 2, 3), int)
with assert_raises_regex(ValueError, "equal length"):
diag_indices_from(x)
def test_ndindex():
x = list(ndindex(1, 2, 3))
expected = [ix for ix, e in ndenumerate(np.zeros((1, 2, 3)))]
assert_array_equal(x, expected)
x = list(ndindex((1, 2, 3)))
assert_array_equal(x, expected)
x = list(ndindex((3,)))
assert_array_equal(x, list(ndindex(3)))
x = list(ndindex())
assert_equal(x, [()])
x = list(ndindex(()))
assert_equal(x, [()])
x = list(ndindex(*[0]))
assert_equal(x, [])
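The remaining helpers exercised in this file (`ix_`, `fill_diagonal`, `diag_indices`) are index-construction utilities. A brief standalone sketch (not part of the test file) of their behavior:

# Illustrative sketch, not from numpy's test suite.
import numpy as np

a = np.arange(16).reshape(4, 4)

# ix_ builds an open mesh, so a[np.ix_(rows, cols)] picks out a sub-block.
print(a[np.ix_([0, 2], [1, 3])])   # [[ 1  3]
                                   #  [ 9 11]]

# fill_diagonal works in place; diag_indices yields the same positions as indices.
b = np.zeros((3, 3), dtype=int)
np.fill_diagonal(b, 5)
print(b[np.diag_indices(3)])       # [5 5 5]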
.\numpy\numpy\lib\tests\test_io.py
import sys
import gc
import gzip
import os
import threading
import time
import warnings
import re
import pytest
from pathlib import Path
from tempfile import NamedTemporaryFile
from io import BytesIO, StringIO
from datetime import datetime
import locale
from multiprocessing import Value, get_context
from ctypes import c_bool
import numpy as np
import numpy.ma as ma
from numpy.exceptions import VisibleDeprecationWarning
from numpy.lib._iotools import ConverterError, ConversionWarning
from numpy.lib import _npyio_impl
from numpy.lib._npyio_impl import recfromcsv, recfromtxt
from numpy.ma.testutils import assert_equal
from numpy.testing import (
assert_warns, assert_, assert_raises_regex, assert_raises,
assert_allclose, assert_array_equal, temppath, tempdir, IS_PYPY,
HAS_REFCOUNT, suppress_warnings, assert_no_gc_cycles, assert_no_warnings,
break_cycles, IS_WASM
)
from numpy.testing._private.utils import requires_memory
from numpy._utils import asbytes
IS_64BIT = sys.maxsize > 2**32
try:
    import bz2
    HAS_BZ2 = True
except ImportError:
    HAS_BZ2 = False
try:
    import lzma
    HAS_LZMA = True
except ImportError:
    HAS_LZMA = False


class TextIO(BytesIO):
    """Helper IO class that encodes str to bytes on write, so tests can
    treat it like a file opened in binary mode."""

    def __init__(self, s=""):
        BytesIO.__init__(self, asbytes(s))

    def write(self, s):
        BytesIO.write(self, asbytes(s))

    def writelines(self, lines):
        BytesIO.writelines(self, [asbytes(s) for s in lines])


def strptime(s, fmt=None):
    """Parse a date string (or bytes) into a datetime, as used by the
    converter tests below."""
    if type(s) == bytes:
        s = s.decode("latin1")
    return datetime(*time.strptime(s, fmt)[:3])


class RoundtripTest:
    def roundtrip(self, save_func, *args, **kwargs):
"""
save_func : callable
    Function used to save an array to a file.
file_on_disk : bool
    If True, save the file on disk instead of in a string buffer.
save_kwds : dict
    Parameters passed to `save_func`.
load_kwds : dict
    Parameters passed to `numpy.load`.
args : tuple of arrays
    Arrays to be saved to the file.
"""
save_kwds = kwargs.get('save_kwds', {})
load_kwds = kwargs.get('load_kwds', {"allow_pickle": True})
file_on_disk = kwargs.get('file_on_disk', False)
if file_on_disk:
target_file = NamedTemporaryFile(delete=False)
load_file = target_file.name
else:
target_file = BytesIO()
load_file = target_file
try:
arr = args
save_func(target_file, *arr, **save_kwds)
target_file.flush()
target_file.seek(0)
if sys.platform == 'win32' and not isinstance(target_file, BytesIO):
target_file.close()
arr_reloaded = np.load(load_file, **load_kwds)
self.arr = arr
self.arr_reloaded = arr_reloaded
finally:
if not isinstance(target_file, BytesIO):
target_file.close()
if 'arr_reloaded' in locals():
if not isinstance(arr_reloaded, np.lib.npyio.NpzFile):
os.remove(target_file.name)
def check_roundtrips(self, a):
self.roundtrip(a)
self.roundtrip(a, file_on_disk=True)
self.roundtrip(np.asfortranarray(a))
self.roundtrip(np.asfortranarray(a), file_on_disk=True)
if a.shape[0] > 1:
self.roundtrip(np.asfortranarray(a)[1:])
self.roundtrip(np.asfortranarray(a)[1:], file_on_disk=True)
def test_array(self):
a = np.array([], float)
self.check_roundtrips(a)
a = np.array([[1, 2], [3, 4]], float)
self.check_roundtrips(a)
a = np.array([[1, 2], [3, 4]], int)
self.check_roundtrips(a)
a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.csingle)
self.check_roundtrips(a)
a = np.array([[1 + 5j, 2 + 6j], [3 + 7j, 4 + 8j]], dtype=np.cdouble)
self.check_roundtrips(a)
def test_array_object(self):
a = np.array([], object)
self.check_roundtrips(a)
a = np.array([[1, 2], [3, 4]], object)
self.check_roundtrips(a)
def test_1D(self):
a = np.array([1, 2, 3, 4], int)
self.roundtrip(a)
def test_mmap(self):
a = np.array([[1, 2.5], [4, 7.3]])
self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'})
a = np.asfortranarray([[1, 2.5], [4, 7.3]])
self.roundtrip(a, file_on_disk=True, load_kwds={'mmap_mode': 'r'})
def test_record(self):
a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
self.check_roundtrips(a)
@pytest.mark.slow
def test_format_2_0(self):
dt = [(("%d" % i) * 100, float) for i in range(500)]
a = np.ones(1000, dtype=dt)
with warnings.catch_warnings(record=True):
warnings.filterwarnings('always', '', UserWarning)
self.check_roundtrips(a)
class TestSaveLoad(RoundtripTest):
def roundtrip(self, *args, **kwargs):
RoundtripTest.roundtrip(self, np.save, *args, **kwargs)
assert_equal(self.arr[0], self.arr_reloaded)
assert_equal(self.arr[0].dtype, self.arr_reloaded.dtype)
assert_equal(self.arr[0].flags.fnc, self.arr_reloaded.flags.fnc)
class TestSavezLoad(RoundtripTest):
def roundtrip(self, *args, **kwargs):
RoundtripTest.roundtrip(self, np.savez, *args, **kwargs)
try:
for n, arr in enumerate(self.arr):
reloaded = self.arr_reloaded['arr_%d' % n]
assert_equal(arr, reloaded)
assert_equal(arr.dtype, reloaded.dtype)
assert_equal(arr.flags.fnc, reloaded.flags.fnc)
finally:
if self.arr_reloaded.fid:
self.arr_reloaded.fid.close()
os.remove(self.arr_reloaded.fid.name)
@pytest.mark.skipif(IS_PYPY, reason="Hangs on PyPy")
@pytest.mark.skipif(not IS_64BIT, reason="Needs 64bit platform")
@pytest.mark.slow
def test_big_arrays(self):
L = (1 << 31) + 100000
a = np.empty(L, dtype=np.uint8)
with temppath(prefix="numpy_test_big_arrays_", suffix=".npz") as tmp:
np.savez(tmp, a=a)
del a
npfile = np.load(tmp)
a = npfile['a']
npfile.close()
del a
def test_multiple_arrays(self):
a = np.array([[1, 2], [3, 4]], float)
b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
self.roundtrip(a, b)
def test_named_arrays(self):
a = np.array([[1, 2], [3, 4]], float)
b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
c = BytesIO()
np.savez(c, file_a=a, file_b=b)
c.seek(0)
l = np.load(c)
assert_equal(a, l['file_a'])
assert_equal(b, l['file_b'])
def test_tuple_getitem_raises(self):
a = np.array([1, 2, 3])
f = BytesIO()
np.savez(f, a=a)
f.seek(0)
l = np.load(f)
with pytest.raises(KeyError, match="(1, 2)"):
l[1, 2]
def test_BagObj(self):
a = np.array([[1, 2], [3, 4]], float)
b = np.array([[1 + 2j, 2 + 7j], [3 - 6j, 4 + 12j]], complex)
c = BytesIO()
np.savez(c, file_a=a, file_b=b)
c.seek(0)
l = np.load(c)
assert_equal(sorted(dir(l.f)), ['file_a', 'file_b'])
assert_equal(a, l.f.file_a)
assert_equal(b, l.f.file_b)
@pytest.mark.skipif(IS_WASM, reason="Cannot start thread")
def test_savez_filename_clashes(self):
def writer(error_list):
with temppath(suffix='.npz') as tmp:
arr = np.random.randn(500, 500)
try:
np.savez(tmp, arr=arr)
except OSError as err:
error_list.append(err)
errors = []
threads = [threading.Thread(target=writer, args=(errors,))
for j in range(3)]
for t in threads:
t.start()
for t in threads:
t.join()
if errors:
raise AssertionError(errors)
def test_not_closing_opened_fid(self):
with temppath(suffix='.npz') as tmp:
with open(tmp, 'wb') as fp:
np.savez(fp, data='LOVELY LOAD')
with open(tmp, 'rb', 10000) as fp:
fp.seek(0)
assert_(not fp.closed)
np.load(fp)['data']
assert_(not fp.closed)
fp.seek(0)
assert_(not fp.closed)
@pytest.mark.slow_pypy
def test_closing_fid(self):
with temppath(suffix='.npz') as tmp:
np.savez(tmp, data='LOVELY LOAD')
with suppress_warnings() as sup:
sup.filter(ResourceWarning)
for i in range(1, 1025):
try:
np.load(tmp)["data"]
except Exception as e:
msg = "Failed to load data from a file: %s" % e
raise AssertionError(msg)
finally:
if IS_PYPY:
gc.collect()
def test_closing_zipfile_after_load(self):
prefix = 'numpy_test_closing_zipfile_after_load_'
with temppath(suffix='.npz', prefix=prefix) as tmp:
np.savez(tmp, lab='place holder')
data = np.load(tmp)
fp = data.zip.fp
data.close()
assert_(fp.closed)
@pytest.mark.parametrize("count, expected_repr", [
(1, "NpzFile {fname!r} with keys: arr_0"),
(5, "NpzFile {fname!r} with keys: arr_0, arr_1, arr_2, arr_3, arr_4"),
(6, "NpzFile {fname!r} with keys: arr_0, arr_1, arr_2, arr_3, arr_4..."),
])
def test_repr_lists_keys(self, count, expected_repr):
a = np.array([[1, 2], [3, 4]], float)
with temppath(suffix='.npz') as tmp:
np.savez(tmp, *[a]*count)
l = np.load(tmp)
assert repr(l) == expected_repr.format(fname=tmp)
l.close()
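The savez/load round trips above boil down to this: `np.savez` writes an `.npz` archive whose keys come from the keyword names, and `np.load` returns a lazy `NpzFile` mapping that should be closed. A standalone sketch (not part of the test file):

# Illustrative sketch, not from numpy's test suite.
import numpy as np
from io import BytesIO

a = np.arange(4)
b = np.linspace(0.0, 1.0, 3)

buf = BytesIO()
np.savez(buf, first=a, second=b)   # keyword names become the archive keys
buf.seek(0)

with np.load(buf) as npz:          # NpzFile supports the context-manager protocol
    print(sorted(npz.files))       # ['first', 'second']
    print(npz['second'])           # [0.  0.5 1. ]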
class TestSaveTxt:
def test_array(self):
a = np.array([[1, 2], [3, 4]], float)
fmt = "%.18e"
c = BytesIO()
np.savetxt(c, a, fmt=fmt)
c.seek(0)
assert_equal(c.readlines(),
[asbytes((fmt + ' ' + fmt + '\n') % (1, 2)),
asbytes((fmt + ' ' + fmt + '\n') % (3, 4))])
a = np.array([[1, 2], [3, 4]], int)
c = BytesIO()
np.savetxt(c, a, fmt='%d')
c.seek(0)
assert_equal(c.readlines(), [b'1 2\n', b'3 4\n'])
def test_1D(self):
a = np.array([1, 2, 3, 4], int)
c = BytesIO()
np.savetxt(c, a, fmt='%d')
c.seek(0)
lines = c.readlines()
assert_equal(lines, [b'1\n', b'2\n', b'3\n', b'4\n'])
def test_0D_3D(self):
c = BytesIO()
assert_raises(ValueError, np.savetxt, c, np.array(1))
assert_raises(ValueError, np.savetxt, c, np.array([[[1], [2]]]))
def test_structured(self):
a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
c = BytesIO()
np.savetxt(c, a, fmt='%d')
c.seek(0)
assert_equal(c.readlines(), [b'1 2\n', b'3 4\n'])
def test_structured_padded(self):
a = np.array([(1, 2, 3),(4, 5, 6)], dtype=[
('foo', 'i4'), ('bar', 'i4'), ('baz', 'i4')
])
c = BytesIO()
np.savetxt(c, a[['foo', 'baz']], fmt='%d')
c.seek(0)
assert_equal(c.readlines(), [b'1 3\n', b'4 6\n'])
def test_multifield_view(self):
a = np.ones(1, dtype=[('x', 'i4'), ('y', 'i4'), ('z', 'f4')])
v = a[['x', 'z']]
with temppath(suffix='.npy') as path:
path = Path(path)
np.save(path, v)
data = np.load(path)
assert_array_equal(data, v)
def test_delimiter(self):
a = np.array([[1., 2.], [3., 4.]])
c = BytesIO()
np.savetxt(c, a, delimiter=',', fmt='%d')
c.seek(0)
assert_equal(c.readlines(), [b'1,2\n', b'3,4\n'])
def test_format(self):
a = np.array([(1, 2), (3, 4)])
c = BytesIO()
np.savetxt(c, a, fmt=['%02d', '%3.1f'])
c.seek(0)
assert_equal(c.readlines(), [b'01 2.0\n', b'03 4.0\n'])
c = BytesIO()
np.savetxt(c, a, fmt='%02d : %3.1f')
c.seek(0)
lines = c.readlines()
assert_equal(lines, [b'01 : 2.0\n', b'03 : 4.0\n'])
c = BytesIO()
np.savetxt(c, a, fmt='%02d : %3.1f', delimiter=',')
c.seek(0)
lines = c.readlines()
assert_equal(lines, [b'01 : 2.0\n', b'03 : 4.0\n'])
c = BytesIO()
assert_raises(ValueError, np.savetxt, c, a, fmt=99)
def test_header_footer(self):
c = BytesIO()
a = np.array([(1, 2), (3, 4)], dtype=int)
test_header_footer = 'Test header / footer'
np.savetxt(c, a, fmt='%1d', header=test_header_footer)
c.seek(0)
assert_equal(c.read(),
asbytes('# ' + test_header_footer + '\n1 2\n3 4\n'))
c = BytesIO()
np.savetxt(c, a, fmt='%1d', footer=test_header_footer)
c.seek(0)
assert_equal(c.read(),
asbytes('1 2\n3 4\n# ' + test_header_footer + '\n'))
c = BytesIO()
commentstr = '% '
np.savetxt(c, a, fmt='%1d',
header=test_header_footer, comments=commentstr)
c.seek(0)
assert_equal(c.read(),
asbytes(commentstr + test_header_footer + '\n' + '1 2\n3 4\n'))
c = BytesIO()
commentstr = '% '
np.savetxt(c, a, fmt='%1d',
footer=test_header_footer, comments=commentstr)
c.seek(0)
assert_equal(c.read(),
asbytes('1 2\n3 4\n' + commentstr + test_header_footer + '\n'))
@pytest.mark.parametrize("filename_type", [Path, str])
def test_file_roundtrip(self, filename_type):
with temppath() as name:
a = np.array([(1, 2), (3, 4)])
np.savetxt(filename_type(name), a)
b = np.loadtxt(filename_type(name))
assert_array_equal(a, b)
def test_complex_arrays(self):
ncols = 2
nrows = 2
a = np.zeros((ncols, nrows), dtype=np.complex128)
re = np.pi
im = np.e
a[:] = re + 1.0j * im
c = BytesIO()
np.savetxt(c, a, fmt=' %+.3e')
c.seek(0)
lines = c.readlines()
assert_equal(
lines,
[b' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n',
b' ( +3.142e+00+ +2.718e+00j) ( +3.142e+00+ +2.718e+00j)\n'])
c = BytesIO()
np.savetxt(c, a, fmt=' %+.3e' * 2 * ncols)
c.seek(0)
lines = c.readlines()
assert_equal(
lines,
[b' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n',
b' +3.142e+00 +2.718e+00 +3.142e+00 +2.718e+00\n'])
c = BytesIO()
np.savetxt(c, a, fmt=['(%.3e%+.3ej)'] * ncols)
c.seek(0)
lines = c.readlines()
assert_equal(
lines,
[b'(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n',
b'(3.142e+00+2.718e+00j) (3.142e+00+2.718e+00j)\n'])
def test_complex_negative_exponent(self):
ncols = 2
nrows = 2
a = np.zeros((ncols, nrows), dtype=np.complex128)
re = np.pi
im = np.e
a[:] = re - 1.0j * im
c = BytesIO()
np.savetxt(c, a, fmt='%.3e')
c.seek(0)
lines = c.readlines()
assert_equal(
lines,
[b' (3.142e+00-2.718e+00j) (3.142e+00-2.718e+00j)\n',
b' (3.142e+00-2.718e+00j) (3.142e+00-2.718e+00j)\n'])
def test_custom_writer(self):
class CustomWriter(list):
def write(self, text):
self.extend(text.split(b'\n'))
w = CustomWriter()
a = np.array([(1, 2), (3, 4)])
np.savetxt(w, a)
b = np.loadtxt(w)
assert_array_equal(a, b)
def test_unicode(self):
utf8 = b'\xcf\x96'.decode('UTF-8')
a = np.array([utf8], dtype=np.str_)
with tempdir() as tmpdir:
np.savetxt(os.path.join(tmpdir, 'test.csv'), a, fmt=['%s'],
encoding='UTF-8')
def test_unicode_roundtrip(self):
utf8 = b'\xcf\x96'.decode('UTF-8')
a = np.array([utf8], dtype=np.str_)
suffixes = ['', '.gz']
if HAS_BZ2:
suffixes.append('.bz2')
if HAS_LZMA:
suffixes.extend(['.xz', '.lzma'])
with tempdir() as tmpdir:
for suffix in suffixes:
np.savetxt(os.path.join(tmpdir, 'test.csv' + suffix), a,
fmt=['%s'], encoding='UTF-16-LE')
b = np.loadtxt(os.path.join(tmpdir, 'test.csv' + suffix),
encoding='UTF-16-LE', dtype=np.str_)
assert_array_equal(a, b)
def test_unicode_bytestream(self):
utf8 = b'\xcf\x96'.decode('UTF-8')
a = np.array([utf8], dtype=np.str_)
s = BytesIO()
np.savetxt(s, a, fmt=['%s'], encoding='UTF-8')
s.seek(0)
assert_equal(s.read().decode('UTF-8'), utf8 + '\n')
def test_unicode_stringstream(self):
utf8 = b'\xcf\x96'.decode('UTF-8')
a = np.array([utf8], dtype=np.str_)
s = StringIO()
np.savetxt(s, a, fmt=['%s'], encoding='UTF-8')
s.seek(0)
assert_equal(s.read(), utf8 + '\n')
@pytest.mark.parametrize("iotype", [StringIO, BytesIO])
def test_unicode_and_bytes_fmt(self, iotype):
a = np.array([1.])
s = iotype()
np.savetxt(s, a, fmt="%f")
s.seek(0)
if iotype is StringIO:
assert_equal(s.read(), "%f\n" % 1.)
else:
assert_equal(s.read(), b"%f\n" % 1.)
@pytest.mark.skipif(sys.platform=='win32', reason="files>4GB may not work")
@pytest.mark.slow
@requires_memory(free_bytes=7e9)
def test_large_zip(self):
def check_large_zip(memoryerror_raised):
memoryerror_raised.value = False
try:
test_data = np.asarray([np.random.rand(
np.random.randint(50,100),4)
for i in range(800000)], dtype=object)
with tempdir() as tmpdir:
np.savez(os.path.join(tmpdir, 'test.npz'),
test_data=test_data)
except MemoryError:
memoryerror_raised.value = True
raise
memoryerror_raised = Value(c_bool)
ctx = get_context('fork')
p = ctx.Process(target=check_large_zip, args=(memoryerror_raised,))
p.start()
p.join()
if memoryerror_raised.value:
raise MemoryError("Child process raised a MemoryError exception")
if p.exitcode == -9:
pytest.xfail("subprocess got a SIGKILL, apparently free memory was not sufficient")
assert p.exitcode == 0
class LoadTxtBase:
def check_compressed(self, fopen, suffixes):
wanted = np.arange(6).reshape((2, 3))
linesep = ('\n', '\r\n', '\r')
for sep in linesep:
data = '0 1 2' + sep + '3 4 5'
for suffix in suffixes:
with temppath(suffix=suffix) as name:
with fopen(name, mode='wt', encoding='UTF-32-LE') as f:
f.write(data)
res = self.loadfunc(name, encoding='UTF-32-LE')
assert_array_equal(res, wanted)
with fopen(name, "rt", encoding='UTF-32-LE') as f:
res = self.loadfunc(f)
assert_array_equal(res, wanted)
def test_compressed_gzip(self):
self.check_compressed(gzip.open, ('.gz',))
@pytest.mark.skipif(not HAS_BZ2, reason="Needs bz2")
def test_compressed_bz2(self):
self.check_compressed(bz2.open, ('.bz2',))
@pytest.mark.skipif(not HAS_LZMA, reason="Needs lzma")
def test_compressed_lzma(self):
self.check_compressed(lzma.open, ('.xz', '.lzma'))
def test_encoding(self):
with temppath() as path:
with open(path, "wb") as f:
f.write('0.\n1.\n2.'.encode("UTF-16"))
x = self.loadfunc(path, encoding="UTF-16")
assert_array_equal(x, [0., 1., 2.])
def test_stringload(self):
nonascii = b'\xc3\xb6\xc3\xbc\xc3\xb6'.decode("UTF-8")
with temppath() as path:
with open(path, "wb") as f:
f.write(nonascii.encode("UTF-16"))
x = self.loadfunc(path, encoding="UTF-16", dtype=np.str_)
assert_array_equal(x, nonascii)
def test_binary_decode(self):
utf16 = b'\xff\xfeh\x04 \x00i\x04 \x00j\x04'
v = self.loadfunc(BytesIO(utf16), dtype=np.str_, encoding='UTF-16')
assert_array_equal(v, np.array(utf16.decode('UTF-16').split()))
def test_converters_decode(self):
c = TextIO()
c.write(b'\xcf\x96')
c.seek(0)
x = self.loadfunc(c, dtype=np.str_, encoding="bytes",
converters={0: lambda x: x.decode('UTF-8')})
a = np.array([b'\xcf\x96'.decode('UTF-8')])
assert_array_equal(x, a)
def test_converters_nodecode(self):
utf8 = b'\xcf\x96'.decode('UTF-8')
with temppath() as path:
with open(path, 'wt', encoding='UTF-8') as f:
f.write(utf8)
x = self.loadfunc(path, dtype=np.str_,
converters={0: lambda x: x + 't'},
encoding='UTF-8')
a = np.array([utf8 + 't'])
assert_array_equal(x, a)
class TestLoadTxt(LoadTxtBase):
loadfunc = staticmethod(np.loadtxt)
def setup_method(self):
self.orig_chunk = _npyio_impl._loadtxt_chunksize
_npyio_impl._loadtxt_chunksize = 1
def teardown_method(self):
_npyio_impl._loadtxt_chunksize = self.orig_chunk
def test_record(self):
c = TextIO()
c.write('1 2\n3 4')
c.seek(0)
x = np.loadtxt(c, dtype=[('x', np.int32), ('y', np.int32)])
a = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
assert_array_equal(x, a)
d = TextIO()
d.write('M 64 75.0\nF 25 60.0')
d.seek(0)
mydescriptor = {'names': ('gender', 'age', 'weight'),
'formats': ('S1', 'i4', 'f4')}
b = np.array([('M', 64.0, 75.0),
('F', 25.0, 60.0)], dtype=mydescriptor)
y = np.loadtxt(d, dtype=mydescriptor)
assert_array_equal(y, b)
def test_array(self):
c = TextIO()
c.write('1 2\n3 4')
c.seek(0)
x = np.loadtxt(c, dtype=int)
a = np.array([[1, 2], [3, 4]], int)
assert_array_equal(x, a)
c.seek(0)
x = np.loadtxt(c, dtype=float)
a = np.array([[1, 2], [3, 4]], float)
assert_array_equal(x, a)
def test_1D(self):
c = TextIO()
c.write('1\n2\n3\n4\n')
c.seek(0)
x = np.loadtxt(c, dtype=int)
a = np.array([1, 2, 3, 4], int)
assert_array_equal(x, a)
c = TextIO()
c.write('1,2,3,4\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',')
a = np.array([1, 2, 3, 4], int)
assert_array_equal(x, a)
def test_missing(self):
c = TextIO()
c.write('1,2,3,,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
converters={3: lambda s: int(s or -999)})
a = np.array([1, 2, 3, -999, 5], int)
assert_array_equal(x, a)
def test_converters_with_usecols(self):
c = TextIO()
c.write('1,2,3,,5\n6,7,8,9,10\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
converters={3: lambda s: int(s or -999)},
usecols=(1, 3,))
a = np.array([[2, -999], [7, 9]], int)
assert_array_equal(x, a)
def test_comments_unicode(self):
c = TextIO()
c.write('# comment\n1,2,3,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
comments='#')
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
def test_comments_byte(self):
c = TextIO()
c.write('# comment\n1,2,3,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
comments=b'#')
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
def test_comments_multiple(self):
c = TextIO()
c.write('# comment\n1,2,3\n@ comment2\n4,5,6 // comment3')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
comments=['#', '@', '//'])
a = np.array([[1, 2, 3], [4, 5, 6]], int)
assert_array_equal(x, a)
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
reason="PyPy bug in error formatting")
def test_comments_multi_chars(self):
c = TextIO()
c.write('/* comment\n1,2,3,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
comments='/*')
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
c = TextIO()
c.write('*/ comment\n1,2,3,5\n')
c.seek(0)
assert_raises(ValueError, np.loadtxt, c, dtype=int, delimiter=',',
comments='/*')
def test_skiprows(self):
c = TextIO()
c.write('comment\n1,2,3,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
skiprows=1)
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
c = TextIO()
c.write('# comment\n1,2,3,5\n')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
skiprows=1)
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
def test_usecols(self):
a = np.array([[1, 2], [3, 4]], float)
c = BytesIO()
np.savetxt(c, a)
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=(1,))
assert_array_equal(x, a[:, 1])
a = np.array([[1, 2, 3], [3, 4, 5]], float)
c = BytesIO()
np.savetxt(c, a)
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=(1, 2))
assert_array_equal(x, a[:, 1:])
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=np.array([1, 2]))
assert_array_equal(x, a[:, 1:])
for int_type in [int, np.int8, np.int16,
np.int32, np.int64, np.uint8, np.uint16,
np.uint32, np.uint64]:
to_read = int_type(1)
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=to_read)
assert_array_equal(x, a[:, 1])
class CrazyInt:
def __index__(self):
return 1
crazy_int = CrazyInt()
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=crazy_int)
assert_array_equal(x, a[:, 1])
c.seek(0)
x = np.loadtxt(c, dtype=float, usecols=(crazy_int,))
assert_array_equal(x, a[:, 1])
data = '''JOE 70.1 25.3
BOB 60.5 27.9
'''
c = TextIO(data)
names = ['stid', 'temp']
dtypes = ['S4', 'f8']
arr = np.loadtxt(c, usecols=(0, 2), dtype=list(zip(names, dtypes)))
assert_equal(arr['stid'], [b"JOE", b"BOB"])
assert_equal(arr['temp'], [25.3, 27.9])
c.seek(0)
bogus_idx = 1.5
assert_raises_regex(
TypeError,
'^usecols must be.*%s' % type(bogus_idx).__name__,
np.loadtxt, c, usecols=bogus_idx
)
assert_raises_regex(
TypeError,
'^usecols must be.*%s' % type(bogus_idx).__name__,
np.loadtxt, c, usecols=[0, bogus_idx, 0]
)
def test_bad_usecols(self):
with pytest.raises(OverflowError):
np.loadtxt(["1\n"], usecols=[2**64], delimiter=",")
with pytest.raises((ValueError, OverflowError)):
np.loadtxt(["1\n"], usecols=[2**62], delimiter=",")
with pytest.raises(TypeError,
match="If a structured dtype .*. But 1 usecols were given and "
"the number of fields is 3."):
np.loadtxt(["1,1\n"], dtype="i,2i", usecols=[0], delimiter=",")
def test_fancy_dtype(self):
c = TextIO()
c.write('1,2,3.0\n4,5,6.0\n')
c.seek(0)
dt = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
x = np.loadtxt(c, dtype=dt, delimiter=',')
a = np.array([(1, (2, 3.0)), (4, (5, 6.0))], dt)
assert_array_equal(x, a)
def test_shaped_dtype(self):
c = TextIO("aaaa 1.0 8.0 1 2 3 4 5 6")
dt = np.dtype([('name', 'S4'), ('x', float), ('y', float),
('block', int, (2, 3))])
x = np.loadtxt(c, dtype=dt)
a = np.array([('aaaa', 1.0, 8.0, [[1, 2, 3], [4, 5, 6]])],
dtype=dt)
assert_array_equal(x, a)
def test_3d_shaped_dtype(self):
c = TextIO("aaaa 1.0 8.0 1 2 3 4 5 6 7 8 9 10 11 12")
dt = np.dtype([('name', 'S4'), ('x', float), ('y', float),
('block', int, (2, 2, 3))])
x = np.loadtxt(c, dtype=dt)
a = np.array([('aaaa', 1.0, 8.0,
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]])],
dtype=dt)
assert_array_equal(x, a)
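# Illustrative aside (not from the test file): a dtype field with its own shape
# consumes that many whitespace-separated columns and is reshaped on read.
import io
import numpy as np
dt = np.dtype([('name', 'S4'), ('block', int, (2, 3))])
row = np.loadtxt(io.StringIO("abcd 1 2 3 4 5 6"), dtype=dt)
print(row['block'])  # [[1 2 3]
                     #  [4 5 6]]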
def test_str_dtype(self):
c = ["str1", "str2"]
for dt in (str, np.bytes_):
a = np.array(["str1", "str2"], dtype=dt)
x = np.loadtxt(c, dtype=dt)
assert_array_equal(x, a)
def test_empty_file(self):
with pytest.warns(UserWarning, match="input contained no data"):
c = TextIO()
x = np.loadtxt(c)
assert_equal(x.shape, (0,))
x = np.loadtxt(c, dtype=np.int64)
assert_equal(x.shape, (0,))
assert_(x.dtype == np.int64)
def test_unused_converter(self):
c = TextIO()
c.writelines(['1 21\n', '3 42\n'])
c.seek(0)
data = np.loadtxt(c, usecols=(1,),
converters={0: lambda s: int(s, 16)})
assert_array_equal(data, [21, 42])
c.seek(0)
data = np.loadtxt(c, usecols=(1,),
converters={1: lambda s: int(s, 16)})
assert_array_equal(data, [33, 66])
def test_dtype_with_object(self):
data = """ 1; 2001-01-01
2; 2002-01-31 """
ndtype = [('idx', int), ('code', object)]
func = lambda s: strptime(s.strip(), "%Y-%m-%d")
converters = {1: func}
test = np.loadtxt(TextIO(data), delimiter=";", dtype=ndtype,
converters=converters)
control = np.array(
[(1, datetime(2001, 1, 1)), (2, datetime(2002, 1, 31))],
dtype=ndtype)
assert_equal(test, control)
def test_uint64_type(self):
tgt = (9223372043271415339, 9223372043271415853)
c = TextIO()
c.write("%s %s" % tgt)
c.seek(0)
res = np.loadtxt(c, dtype=np.uint64)
assert_equal(res, tgt)
def test_int64_type(self):
tgt = (-9223372036854775807, 9223372036854775807)
c = TextIO()
c.write("%s %s" % tgt)
c.seek(0)
res = np.loadtxt(c, dtype=np.int64)
assert_equal(res, tgt)
def test_from_float_hex(self):
tgt = np.logspace(-10, 10, 5).astype(np.float32)
tgt = np.hstack((tgt, -tgt)).astype(float)
inp = '\n'.join(map(float.hex, tgt))
c = TextIO()
c.write(inp)
for dt in [float, np.float32]:
c.seek(0)
res = np.loadtxt(
c, dtype=dt, converters=float.fromhex, encoding="latin1")
assert_equal(res, tgt, err_msg="%s" % dt)
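# Illustrative aside (not from the test file): in the rewritten loadtxt of newer
# NumPy (1.23+), `converters` may also be a single callable applied to every
# column, as in the test above; here it parses C99 hexadecimal float literals.
import io
import numpy as np
vals = np.loadtxt(io.StringIO("0x1.8p+1\n-0x1.0p-2"),
                  dtype=float, converters=float.fromhex)
print(vals)  # [ 3.   -0.25]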
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
reason="PyPy bug in error formatting")
def test_default_float_converter_no_default_hex_conversion(self):
"""
Make sure fromhex is only used for values with the proper prefix and is
not called by default. Regression test related to gh-19598.
"""
c = TextIO("a b c")
with pytest.raises(ValueError,
match=".*convert string 'a' to float64 at row 0, column 1"):
np.loadtxt(c)
@pytest.mark.skipif(IS_PYPY and sys.implementation.version <= (7, 3, 8),
reason="PyPy bug in error formatting")
def test_default_float_converter_exception(self):
"""
Make sure that the exception message raised when a float conversion fails
is correct. Regression test related to gh-19598.
"""
c = TextIO("qrs tuv")
with pytest.raises(ValueError,
match="could not convert string 'qrs' to float64"):
np.loadtxt(c)
def test_from_complex(self):
tgt = (complex(1, 1), complex(1, -1))
c = TextIO()
c.write("%s %s" % tgt)
c.seek(0)
res = np.loadtxt(c, dtype=complex)
assert_equal(res, tgt)
def test_complex_misformatted(self):
a = np.zeros((2, 2), dtype=np.complex128)
re = np.pi
im = np.e
a[:] = re - 1.0j * im
c = BytesIO()
np.savetxt(c, a, fmt='%.16e')
c.seek(0)
txt = c.read()
c.seek(0)
txt_bad = txt.replace(b'e+00-', b'e00+-')
assert_(txt_bad != txt)
c.write(txt_bad)
c.seek(0)
res = np.loadtxt(c, dtype=complex)
assert_equal(res, a)
def test_universal_newline(self):
with temppath() as name:
with open(name, 'w') as f:
f.write('1 21\r3 42\r')
data = np.loadtxt(name)
assert_array_equal(data, [[1, 21], [3, 42]])
def test_empty_field_after_tab(self):
c = TextIO()
c.write('1 \t2 \t3\tstart \n4\t5\t6\t \n7\t8\t9.5\t')
c.seek(0)
dt = {'names': ('x', 'y', 'z', 'comment'),
'formats': ('<i4', '<i4', '<f4', '|S8')}
x = np.loadtxt(c, dtype=dt, delimiter='\t')
a = np.array([b'start ', b' ', b''])
assert_array_equal(x['comment'], a)
def test_unpack_structured(self):
txt = TextIO("M 21 72\nF 35 58")
dt = {'names': ('a', 'b', 'c'), 'formats': ('|S1', '<i4', '<f4')}
a, b, c = np.loadtxt(txt, dtype=dt, unpack=True)
assert_(a.dtype.str == '|S1')
assert_(b.dtype.str == '<i4')
assert_(c.dtype.str == '<f4')
assert_array_equal(a, np.array([b'M', b'F']))
assert_array_equal(b, np.array([21, 35]))
assert_array_equal(c, np.array([72., 58.]))
def test_ndmin_keyword(self):
c = TextIO()
c.write('1,2,3\n4,5,6')
c.seek(0)
assert_raises(ValueError, np.loadtxt, c, ndmin=3)
c.seek(0)
assert_raises(ValueError, np.loadtxt, c, ndmin=1.5)
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',', ndmin=1)
a = np.array([[1, 2, 3], [4, 5, 6]])
assert_array_equal(x, a)
d = TextIO()
d.write('0,1,2')
d.seek(0)
x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=2)
assert_(x.shape == (1, 3))
d.seek(0)
x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=1)
assert_(x.shape == (3,))
d.seek(0)
x = np.loadtxt(d, dtype=int, delimiter=',', ndmin=0)
assert_(x.shape == (3,))
e = TextIO()
e.write('0\n1\n2')
e.seek(0)
x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=2)
assert_(x.shape == (3, 1))
e.seek(0)
x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=1)
assert_(x.shape == (3,))
e.seek(0)
x = np.loadtxt(e, dtype=int, delimiter=',', ndmin=0)
assert_(x.shape == (3,))
with pytest.warns(UserWarning, match="input contained no data"):
f = TextIO()
assert_(np.loadtxt(f, ndmin=2).shape == (0, 1,))
assert_(np.loadtxt(f, ndmin=1).shape == (0,))
def test_generator_source(self):
def count():
for i in range(10):
yield "%d" % i
res = np.loadtxt(count())
assert_array_equal(res, np.arange(10))
def test_bad_line(self):
c = TextIO()
c.write('1 2 3\n4 5 6\n2 3')
c.seek(0)
assert_raises_regex(ValueError, "3", np.loadtxt, c)
def test_none_as_string(self):
c = TextIO()
c.write('100,foo,200\n300,None,400')
c.seek(0)
dt = np.dtype([('x', int), ('a', 'S10'), ('y', int)])
np.loadtxt(c, delimiter=',', dtype=dt, comments=None)
@pytest.mark.skipif(locale.getpreferredencoding() == 'ANSI_X3.4-1968',
reason="Wrong preferred encoding")
def test_binary_load(self):
butf8 = b"5,6,7,\xc3\x95scarscar\r\n15,2,3,hello\r\n"\
b"20,2,3,\xc3\x95scar\r\n"
sutf8 = butf8.decode("UTF-8").replace("\r", "").splitlines()
with temppath() as path:
with open(path, "wb") as f:
f.write(butf8)
with open(path, "rb") as f:
x = np.loadtxt(f, encoding="UTF-8", dtype=np.str_)
assert_array_equal(x, sutf8)
with open(path, "rb") as f:
x = np.loadtxt(f, encoding="UTF-8", dtype="S")
x = [b'5,6,7,\xc3\x95scarscar', b'15,2,3,hello', b'20,2,3,\xc3\x95scar']
assert_array_equal(x, np.array(x, dtype="S"))
def test_max_rows(self):
c = TextIO()
c.write('1,2,3,5\n4,5,7,8\n2,1,4,5')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
max_rows=1)
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
def test_max_rows_with_skiprows(self):
c = TextIO()
c.write('comments\n1,2,3,5\n4,5,7,8\n2,1,4,5')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
skiprows=1, max_rows=1)
a = np.array([1, 2, 3, 5], int)
assert_array_equal(x, a)
c = TextIO()
c.write('comment\n1,2,3,5\n4,5,7,8\n2,1,4,5')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
skiprows=1, max_rows=2)
a = np.array([[1, 2, 3, 5], [4, 5, 7, 8]], int)
assert_array_equal(x, a)
def test_max_rows_with_read_continuation(self):
c = TextIO()
c.write('1,2,3,5\n4,5,7,8\n2,1,4,5')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
max_rows=2)
a = np.array([[1, 2, 3, 5], [4, 5, 7, 8]], int)
assert_array_equal(x, a)
x = np.loadtxt(c, dtype=int, delimiter=',')
a = np.array([2, 1, 4, 5], int)
assert_array_equal(x, a)
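# Illustrative aside (not from the test file): because the handle keeps its
# position, repeated calls with `max_rows` read a text source in batches.
import io
import numpy as np
f = io.StringIO("1,2\n3,4\n5,6\n7,8\n")
first = np.loadtxt(f, dtype=int, delimiter=",", max_rows=2)  # first two rows
rest = np.loadtxt(f, dtype=int, delimiter=",")               # remaining rows
print(first.tolist(), rest.tolist())  # [[1, 2], [3, 4]] [[5, 6], [7, 8]]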
def test_max_rows_larger(self):
c = TextIO()
c.write('comment\n1,2,3,5\n4,5,7,8\n2,1,4,5')
c.seek(0)
x = np.loadtxt(c, dtype=int, delimiter=',',
skiprows=1, max_rows=6)
a = np.array([[1, 2, 3, 5], [4, 5, 7, 8], [2, 1, 4, 5]], int)
assert_array_equal(x, a)
@pytest.mark.parametrize(["skip", "data"], [
(1, ["ignored\n", "1,2\n", "\n", "3,4\n"]),
(1, ["ignored", "1,2", "", "3,4"]),
(1, StringIO("ignored\n1,2\n\n3,4")),
(0, ["-1,0\n", "1,2\n", "\n", "3,4\n"]),
(0, ["-1,0", "1,2", "", "3,4"]),
(0, StringIO("-1,0\n1,2\n\n3,4"))])
def test_max_rows_empty_lines(self, skip, data):
with pytest.warns(UserWarning,
match=f"Input line 3.*max_rows={3-skip}"):
res = np.loadtxt(data, dtype=int, skiprows=skip, delimiter=",",
max_rows=3-skip)
assert_array_equal(res, [[-1, 0], [1, 2], [3, 4]][skip:])
if isinstance(data, StringIO):
data.seek(0)
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
with pytest.raises(UserWarning):
np.loadtxt(data, dtype=int, skiprows=skip, delimiter=",",
max_rows=3-skip)
class Testfromregex:
def test_record(self):
c = TextIO()
c.write('1.312 foo\n1.534 bar\n4.444 qux')
c.seek(0)
dt = [('num', np.float64), ('val', 'S3')]
x = np.fromregex(c, r"([0-9.]+)\s+(...)", dt)
a = np.array([(1.312, 'foo'), (1.534, 'bar'), (4.444, 'qux')],
dtype=dt)
assert_array_equal(x, a)
def test_record_2(self):
c = TextIO()
c.write('1312 foo\n1534 bar\n4444 qux')
c.seek(0)
dt = [('num', np.int32), ('val', 'S3')]
x = np.fromregex(c, r"(\d+)\s+(...)", dt)
a = np.array([(1312, 'foo'), (1534, 'bar'), (4444, 'qux')],
dtype=dt)
assert_array_equal(x, a)
def test_record_3(self):
c = TextIO()
c.write('1312 foo\n1534 bar\n4444 qux')
c.seek(0)
dt = [('num', np.float64)]
x = np.fromregex(c, r"(\d+)\s+...", dt)
a = np.array([(1312,), (1534,), (4444,)], dtype=dt)
assert_array_equal(x, a)
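# Illustrative aside (not from the test file): np.fromregex builds a structured
# array from every match of the pattern, one regex group per dtype field.
import io
import numpy as np
buf = io.StringIO("id=7 score=1.5\nid=9 score=2.25")
dt = [('id', np.int64), ('score', np.float64)]
rec = np.fromregex(buf, r"id=(\d+) score=([0-9.]+)", dt)
print(rec['id'], rec['score'])  # [7 9] [1.5  2.25]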
@pytest.mark.parametrize("path_type", [str, Path])
def test_record_unicode(self, path_type):
utf8 = b'\xcf\x96'
with temppath() as str_path:
path = path_type(str_path)
with open(path, 'wb') as f:
f.write(b'1.312 foo' + utf8 + b' \n1.534 bar\n4.444 qux')
dt = [('num', np.float64), ('val', 'U4')]
x = np.fromregex(path, r"(?u)([0-9.]+)\s+(\w+)", dt, encoding='UTF-8')
a = np.array([(1.312, 'foo' + utf8.decode('UTF-8')), (1.534, 'bar'),
(4.444, 'qux')], dtype=dt)
assert_array_equal(x, a)
regexp = re.compile(r"([0-9.]+)\s+(\w+)", re.UNICODE)
x = np.fromregex(path, regexp, dt, encoding='UTF-8')
assert_array_equal(x, a)
def test_compiled_bytes(self):
regexp = re.compile(b'(\\d)')
c = BytesIO(b'123')
dt = [('num', np.float64)]
a = np.array([1, 2, 3], dtype=dt)
x = np.fromregex(c, regexp, dt)
assert_array_equal(x, a)
def test_bad_dtype_not_structured(self):
regexp = re.compile(b'(\\d)')
c = BytesIO(b'123')
with pytest.raises(TypeError, match='structured datatype'):
np.fromregex(c, regexp, dtype=np.float64)
class TestFromTxt(LoadTxtBase):
def test_record(self):
data = TextIO('1 2\n3 4')
test = np.genfromtxt(data, dtype=[('x', np.int32), ('y', np.int32)])
control = np.array([(1, 2), (3, 4)], dtype=[('x', 'i4'), ('y', 'i4')])
assert_equal(test, control)
data = TextIO('M 64.0 75.0\nF 25.0 60.0')
descriptor = {'names': ('gender', 'age', 'weight'),
'formats': ('S1', 'i4', 'f4')}
control = np.array([('M', 64.0, 75.0), ('F', 25.0, 60.0)],
dtype=descriptor)
test = np.genfromtxt(data, dtype=descriptor)
assert_equal(test, control)
def test_array(self):
data = TextIO('1 2\n3 4')
control = np.array([[1, 2], [3, 4]], dtype=int)
test = np.genfromtxt(data, dtype=int)
assert_array_equal(test, control)
data.seek(0)
control = np.array([[1, 2], [3, 4]], dtype=float)
test = np.loadtxt(data, dtype=float)
assert_array_equal(test, control)
def test_1D(self):
control = np.array([1, 2, 3, 4], int)
data = TextIO('1\n2\n3\n4\n')
test = np.genfromtxt(data, dtype=int)
assert_array_equal(test, control)
data = TextIO('1,2,3,4\n')
test = np.genfromtxt(data, dtype=int, delimiter=',')
assert_array_equal(test, control)
def test_comments(self):
control = np.array([1, 2, 3, 5], int)
data = TextIO('# comment\n1,2,3,5\n')
test = np.genfromtxt(data, dtype=int, delimiter=',', comments='#')
assert_equal(test, control)
data = TextIO('1,2,3,5# comment\n')
test = np.genfromtxt(data, dtype=int, delimiter=',', comments='#')
assert_equal(test, control)
def test_skiprows(self):
control = np.array([1, 2, 3, 5], int)
kwargs = dict(dtype=int, delimiter=',')
data = TextIO('comment\n1,2,3,5\n')
test = np.genfromtxt(data, skip_header=1, **kwargs)
assert_equal(test, control)
data = TextIO('# comment\n1,2,3,5\n')
test = np.loadtxt(data, skiprows=1, **kwargs)
assert_equal(test, control)
def test_skip_footer(self):
data = ["# %i" % i for i in range(1, 6)]
data.append("A, B, C")
data.extend(["%i,%3.1f,%03s" % (i, i, i) for i in range(51)])
data[-1] = "99,99"
kwargs = dict(delimiter=",", names=True, skip_header=5, skip_footer=10)
test = np.genfromtxt(TextIO("\n".join(data)), **kwargs)
ctrl = np.array([("%f" % i, "%f" % i, "%f" % i) for i in range(41)],
dtype=[(_, float) for _ in "ABC"])
assert_equal(test, ctrl)
def test_skip_footer_with_invalid(self):
with suppress_warnings() as sup:
sup.filter(ConversionWarning)
basestr = '1 1\n2 2\n3 3\n4 4\n5 \n6 \n7 \n'
assert_raises(ValueError, np.genfromtxt, TextIO(basestr), skip_footer=1)
a = np.genfromtxt(TextIO(basestr), skip_footer=1, invalid_raise=False)
assert_equal(a, np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]))
a = np.genfromtxt(TextIO(basestr), skip_footer=3)
assert_equal(a, np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]]))
basestr = '1 1\n2 \n3 3\n4 4\n5 \n6 6\n7 7\n'
a = np.genfromtxt(TextIO(basestr), skip_footer=1, invalid_raise=False)
assert_equal(a, np.array([[1., 1.], [3., 3.], [4., 4.], [6., 6.]]))
a = np.genfromtxt(TextIO(basestr), skip_footer=3, invalid_raise=False)
assert_equal(a, np.array([[1., 1.], [3., 3.], [4., 4.]]))
def test_header(self):
data = TextIO('gender age weight\nM 64.0 75.0\nF 25.0 60.0')
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
test = np.genfromtxt(data, dtype=None, names=True, encoding='bytes')
assert_(w[0].category is VisibleDeprecationWarning)
control = {'gender': np.array([b'M', b'F']),
'age': np.array([64.0, 25.0]),
'weight': np.array([75.0, 60.0])}
assert_equal(test['gender'], control['gender'])
assert_equal(test['age'], control['age'])
assert_equal(test['weight'], control['weight'])
def test_auto_dtype(self):
data = TextIO('A 64 75.0 3+4j True\nBCD 25 60.0 5+6j False')
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
test = np.genfromtxt(data, dtype=None, encoding='bytes')
assert_(w[0].category is VisibleDeprecationWarning)
control = [np.array([b'A', b'BCD']),
np.array([64, 25]),
np.array([75.0, 60.0]),
np.array([3 + 4j, 5 + 6j]),
np.array([True, False]), ]
assert_equal(test.dtype.names, ['f0', 'f1', 'f2', 'f3', 'f4'])
for (i, ctrl) in enumerate(control):
assert_equal(test['f%i' % i], ctrl)
def test_auto_dtype_uniform(self):
data = TextIO('1 2 3 4\n5 6 7 8\n')
test = np.genfromtxt(data, dtype=None)
control = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
assert_equal(test, control)
def test_fancy_dtype(self):
data = TextIO('1,2,3.0\n4,5,6.0\n')
fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
test = np.genfromtxt(data, dtype=fancydtype, delimiter=',')
control = np.array([(1, (2, 3.0)), (4, (5, 6.0))], dtype=fancydtype)
assert_equal(test, control)
def test_names_overwrite(self):
descriptor = {'names': ('g', 'a', 'w'),
'formats': ('S1', 'i4', 'f4')}
data = TextIO(b'M 64.0 75.0\nF 25.0 60.0')
names = ('gender', 'age', 'weight')
test = np.genfromtxt(data, dtype=descriptor, names=names)
descriptor['names'] = names
control = np.array([('M', 64.0, 75.0),
('F', 25.0, 60.0)], dtype=descriptor)
assert_equal(test, control)
def test_bad_fname(self):
with pytest.raises(TypeError, match='fname must be a string,'):
np.genfromtxt(123)
def test_commented_header(self):
data = TextIO("""
def test_names_and_comments_hash(self):
# Test the case where the data contains '#' used as a comment marker
data = TextIO(b"""
# gender age weight
M 21 72.100000
F 35 58.330000
M 33 21.99
""")
# Catch warnings and verify that a VisibleDeprecationWarning is raised
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# Read the text data and parse it into a structured array
test = np.genfromtxt(data, names=True, dtype=None,
encoding="bytes")
# Assert that a VisibleDeprecationWarning was raised
assert_(w[0].category is VisibleDeprecationWarning)
# Expected control array
ctrl = np.array([('M', 21, 72.1), ('F', 35, 58.33), ('M', 33, 21.99)],
dtype=[('gender', '|S1'), ('age', int), ('weight', float)])
# Assert that the parsed result equals the expected control array
assert_equal(test, ctrl)
def test_names_and_comments_none(self):
# Test names=True when comments is None (gh-10780)
data = TextIO('col1 col2\n 1 2\n 3 4')
# Read the text data into a structured array with named fields and no comment character
test = np.genfromtxt(data, dtype=(int, int), comments=None, names=True)
# Expected control array
control = np.array([(1, 2), (3, 4)], dtype=[('col1', int), ('col2', int)])
# Assert that the parsed result equals the expected control array
assert_equal(test, control)
def test_file_is_closed_on_error(self):
# Check that the file is closed properly when an error occurs (gh-13200)
with tempdir() as tmpdir:
fpath = os.path.join(tmpdir, "test.csv")
with open(fpath, "wb") as f:
f.write('\N{GREEK PI SYMBOL}'.encode())
# ResourceWarnings are raised from destructors, so they are not caught by
# normal error propagation
with assert_no_warnings():
# Reading the file as ASCII is expected to raise a UnicodeDecodeError
with pytest.raises(UnicodeDecodeError):
np.genfromtxt(fpath, encoding="ascii")
def test_autonames_and_usecols(self):
# Test automatic names combined with usecols
data = TextIO('A B C D\n aaaa 121 45 9.1')
# Catch warnings and verify that a VisibleDeprecationWarning is raised
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# Parse into a structured array, selecting columns by name and taking the
# field names from the header line
test = np.genfromtxt(data, usecols=('A', 'C', 'D'),
names=True, dtype=None, encoding="bytes")
# Assert that a VisibleDeprecationWarning was raised
assert_(w[0].category is VisibleDeprecationWarning)
# Expected control array
control = np.array(('aaaa', 45, 9.1),
dtype=[('A', '|S4'), ('C', int), ('D', float)])
# Assert that the parsed result equals the expected control array
assert_equal(test, control)
def test_converters_with_usecols(self):
# Test the combination of a custom converter and usecols
# Build a text stream with the test data
data = TextIO('1,2,3,,5\n6,7,8,9,10\n')
# Read the stream, applying the converter and the column selection
test = np.genfromtxt(data, dtype=int, delimiter=',',
converters={3: lambda s: int(s or -999)},
usecols=(1, 3,))
# Expected control array for the assertion
control = np.array([[2, -999], [7, 9]], int)
# Assert that the result equals the control array
assert_equal(test, control)
def test_converters_with_usecols_and_names(self):
# Test names together with usecols
# Build a text stream with the test data
data = TextIO('A B C D\n aaaa 121 45 9.1')
# Record warnings so they can be inspected
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# Read the stream, selecting columns by name, taking names from the
# header and applying a converter to column 'C'
test = np.genfromtxt(data, usecols=('A', 'C', 'D'), names=True,
dtype=None, encoding="bytes",
converters={'C': lambda s: 2 * int(s)})
# Assert that the first recorded warning is a VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# Expected control array for the assertion
control = np.array(('aaaa', 90, 9.1),
dtype=[('A', '|S4'), ('C', int), ('D', float)])
# Assert that the result equals the control array
assert_equal(test, control)
def test_converters_cornercases(self):
# Test datetime conversion via a converter
# Converter dict turning the date string into a datetime object
converter = {
'date': lambda s: strptime(s, '%Y-%m-%d %H:%M:%SZ')}
# Build a text stream with the test data
data = TextIO('2009-02-03 12:00:00Z, 72214.0')
# Read the stream with the given delimiter and converter
test = np.genfromtxt(data, delimiter=',', dtype=None,
names=['date', 'stid'], converters=converter)
# Expected control array for the assertion
control = np.array((datetime(2009, 2, 3), 72214.),
dtype=[('date', np.object_), ('stid', float)])
# Assert that the result equals the control array
assert_equal(test, control)
def test_converters_cornercases2(self):
# Test conversion to datetime64
# Converter dict turning the date string into a numpy datetime64 value
converter = {
'date': lambda s: np.datetime64(strptime(s, '%Y-%m-%d %H:%M:%SZ'))}
# Build a text stream with the test data
data = TextIO('2009-02-03 12:00:00Z, 72214.0')
# Read the stream with the given delimiter and converter
test = np.genfromtxt(data, delimiter=',', dtype=None,
names=['date', 'stid'], converters=converter)
# Expected control array for the assertion
control = np.array((datetime(2009, 2, 3), 72214.),
dtype=[('date', 'datetime64[us]'), ('stid', float)])
# Assert that the result equals the control array
assert_equal(test, control)
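# Illustrative aside (not from the test file): with dtype=None the field type is
# taken from what the converter returns, so a converter can produce datetime64
# values directly.
import io
import numpy as np
conv = {0: lambda s: np.datetime64(s.strip(), "D")}
rec = np.genfromtxt(io.StringIO("2009-02-03, 72214.0"), delimiter=",",
                    dtype=None, names=["date", "stid"], converters=conv,
                    encoding="utf-8")
print(rec["date"], rec["stid"])  # 2009-02-03 72214.0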
def test_unused_converter(self):
# Check that an unused converter is simply ignored
# Build a text stream with the test data
data = TextIO("1 21\n 3 42\n")
# Only column 1 is read, so the converter for column 0 is never called
test = np.genfromtxt(data, usecols=(1,),
converters={0: lambda s: int(s, 16)})
# Assert that the result equals the expected array
assert_equal(test, [21, 42])
#
data.seek(0)
# Now the converter applies to the selected column, so values are parsed as hex
test = np.genfromtxt(data, usecols=(1,),
converters={1: lambda s: int(s, 16)})
# Assert that the result equals the expected array
assert_equal(test, [33, 66])
def test_invalid_converter(self):
# Helper that converts a string to a float, handling values that contain 'r'
strip_rand = lambda x: float((b'r' in x.lower() and x.split()[-1]) or
(b'r' not in x.lower() and x.strip() or 0.0))
# Helper that converts a string to a float, handling values that contain '%'
strip_per = lambda x: float((b'%' in x.lower() and x.split()[0]) or
(b'%' not in x.lower() and x.strip() or 0.0))
# Multi-line input stream for the converters
s = TextIO("D01N01,10/1/2003 ,1 %,R 75,400,600\r\n"
"L24U05,12/5/2003, 2 %,1,300, 150.5\r\n"
"D02N03,10/10/2004,R 1,,7,145.55")
# Keyword arguments for genfromtxt
kwargs = dict(
converters={2: strip_per, 3: strip_rand}, delimiter=",",
dtype=None, encoding="bytes")
# Assert that np.genfromtxt raises a ConverterError
assert_raises(ConverterError, np.genfromtxt, s, **kwargs)
def test_tricky_converter_bug1666(self):
# Test some corner cases
s = TextIO('q1,2\nq3,4')
# Converter lambda that strips the leading character and returns a float
cnv = lambda s: float(s[1:])
# Parse with ',' as delimiter, converting column 0 with cnv
test = np.genfromtxt(s, delimiter=',', converters={0: cnv})
# Expected control array
control = np.array([[1., 2.], [3., 4.]])
# Assert that test equals control
assert_equal(test, control)
def test_dtype_with_converters(self):
# Data string
dstr = "2009; 23; 46"
# Parse with ';' as delimiter, converting column 0 with bytes
test = np.genfromtxt(TextIO(dstr,),
delimiter=";", dtype=float, converters={0: bytes})
# Expected control array with an explicit dtype per column
control = np.array([('2009', 23., 46)],
dtype=[('f0', '|S4'), ('f1', float), ('f2', float)])
# Assert that test equals control
assert_equal(test, control)
# Parse again, this time converting column 0 with float
test = np.genfromtxt(TextIO(dstr,),
delimiter=";", dtype=float, converters={0: float})
# Expected control array containing only floats
control = np.array([2009., 23., 46],)
# Assert that test equals control
assert_equal(test, control)
@pytest.mark.filterwarnings("ignore:.*recfromcsv.*:DeprecationWarning")
def test_dtype_with_converters_and_usecols(self):
# Data string
dstr = "1,5,-1,1:1\n2,8,-1,1:n\n3,3,-2,m:n\n"
# Mapping from string codes to integers
dmap = {'1:1':0, '1:n':1, 'm:1':2, 'm:n':3}
# dtype giving each column a name and a type
dtyp = [('e1','i4'),('e2','i4'),('e3','i2'),('n', 'i1')]
# Converters for each column; the last one maps the code through dmap
conv = {0: int, 1: int, 2: int, 3: lambda r: dmap[r.decode()]}
# Parse with recfromcsv, using ',' as delimiter and the converters above
test = recfromcsv(TextIO(dstr,), dtype=dtyp, delimiter=',',
names=None, converters=conv, encoding="bytes")
# Expected control record array
control = np.rec.array([(1,5,-1,0), (2,8,-1,1), (3,3,-2,3)], dtype=dtyp)
# Assert that test equals control
assert_equal(test, control)
# Redefine the dtype with only a subset of the columns and parse again
dtyp = [('e1', 'i4'), ('e2', 'i4'), ('n', 'i1')]
test = recfromcsv(TextIO(dstr,), dtype=dtyp, delimiter=',',
usecols=(0, 1, 3), names=None, converters=conv,
encoding="bytes")
# Expected control record array
control = np.rec.array([(1,5,0), (2,8,1), (3,3,3)], dtype=dtyp)
# Assert that test equals control
assert_equal(test, control)
def test_dtype_with_object(self):
# Test using an explicit dtype with an object
data = """ 1; 2001-01-01
2; 2002-01-31 """
ndtype = [('idx', int), ('code', object)]
func = lambda s: strptime(s.strip(), "%Y-%m-%d")
converters = {1: func}
# 从文本数据创建结构化数组,指定字段数据类型和转换器
test = np.genfromtxt(TextIO(data), delimiter=";", dtype=ndtype,
converters=converters)
# 创建控制用的数组,以验证结果
control = np.array(
[(1, datetime(2001, 1, 1)), (2, datetime(2002, 1, 31))],
dtype=ndtype)
# 断言测试结果与控制结果相等
assert_equal(test, control)
ndtype = [('nest', [('idx', int), ('code', object)])]
# 检测嵌套字段的情况是否抛出预期的异常
with assert_raises_regex(NotImplementedError,
'Nested fields.* not supported.*'):
test = np.genfromtxt(TextIO(data), delimiter=";",
dtype=ndtype, converters=converters)
# 嵌套字段为空时也不支持,检测是否抛出预期的异常
ndtype = [('idx', int), ('code', object), ('nest', [])]
with assert_raises_regex(NotImplementedError,
'Nested fields.* not supported.*'):
test = np.genfromtxt(TextIO(data), delimiter=";",
dtype=ndtype, converters=converters)
def test_dtype_with_object_no_converter(self):
# Object without a converter uses bytes:
parsed = np.genfromtxt(TextIO("1"), dtype=object)
assert parsed[()] == b"1"
parsed = np.genfromtxt(TextIO("string"), dtype=object)
assert parsed[()] == b"string"
def test_userconverters_with_explicit_dtype(self):
# Test user_converters w/ explicit (standard) dtype
data = TextIO('skip,skip,2001-01-01,1.0,skip')
# 使用用户定义的转换器解析数据,验证结果
test = np.genfromtxt(data, delimiter=",", names=None, dtype=float,
usecols=(2, 3), converters={2: bytes})
control = np.array([('2001-01-01', 1.)],
dtype=[('', '|S10'), ('', float)])
assert_equal(test, control)
def test_utf8_userconverters_with_explicit_dtype(self):
utf8 = b'\xcf\x96'
with temppath() as path:
with open(path, 'wb') as f:
f.write(b'skip,skip,2001-01-01' + utf8 + b',1.0,skip')
# 使用 UTF-8 编码解析包含 UTF-8 数据的文件
test = np.genfromtxt(path, delimiter=",", names=None, dtype=float,
usecols=(2, 3), converters={2: str},
encoding='UTF-8')
control = np.array([('2001-01-01' + utf8.decode('UTF-8'), 1.)],
dtype=[('', '|U11'), ('', float)])
assert_equal(test, control)
def test_spacedelimiter(self):
# Test space delimiter
data = TextIO("1 2 3 4 5\n6 7 8 9 10")
# 使用空格作为分隔符解析数据
test = np.genfromtxt(data)
control = np.array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 10.]])
assert_equal(test, control)
def test_integer_delimiter(self):
# Test using an integer as the delimiter (fixed field width)
data = "  1  2  3\n  4  5 67\n890123  4"
# Read with np.genfromtxt, splitting each line into fields of width 3
test = np.genfromtxt(TextIO(data), delimiter=3)
# Expected result array
control = np.array([[1, 2, 3], [4, 5, 67], [890, 123, 4]])
# Assert that the result equals the expected array
assert_equal(test, control)
def test_missing(self):
data = TextIO('1,2,3,,5\n')
# 使用 np.genfromtxt 从 TextIO 对象读取数据,指定数据类型为整数,使用 ',' 作为分隔符,
# 并使用转换器来处理第 3 列的缺失值
test = np.genfromtxt(data, dtype=int, delimiter=',',
converters={3: lambda s: int(s or - 999)})
# 预期的结果数组
control = np.array([1, 2, 3, -999, 5], int)
# 断言测试结果与预期结果相等
assert_equal(test, control)
def test_missing_with_tabs(self):
# 使用制表符作为分隔符进行测试
txt = "1\t2\t3\n\t2\t\n1\t\t3"
# 使用 np.genfromtxt 从 TextIO 对象读取数据,启用掩码,并且不指定数据类型
test = np.genfromtxt(TextIO(txt), delimiter="\t",
usemask=True,)
# 预期的数据数组和掩码数组
ctrl_d = np.array([(1, 2, 3), (np.nan, 2, np.nan), (1, np.nan, 3)],)
ctrl_m = np.array([(0, 0, 0), (1, 0, 1), (0, 1, 0)], dtype=bool)
# 断言测试结果的数据数组和掩码数组与预期结果相等
assert_equal(test.data, ctrl_d)
assert_equal(test.mask, ctrl_m)
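# Illustrative aside (not from the test file): with usemask=True genfromtxt
# returns a numpy.ma.MaskedArray whose mask flags every entry it could not parse.
import io
import numpy as np
m = np.genfromtxt(io.StringIO("1\t2\t3\n\t5\t\n"), delimiter="\t", usemask=True)
print(m.mask)  # [[False False False]
               #  [ True False  True]]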
def test_usecols(self):
# Test the selection of columns
# Select a single column
control = np.array([[1, 2], [3, 4]], float)
data = TextIO()
# 将控制数据写入 TextIO 对象
np.savetxt(data, control)
data.seek(0)
# 使用 np.genfromtxt 从 TextIO 对象读取数据,指定数据类型为浮点数,并选择使用列 (1,)
test = np.genfromtxt(data, dtype=float, usecols=(1,))
# 断言测试结果与预期结果相等
assert_equal(test, control[:, 1])
#
control = np.array([[1, 2, 3], [3, 4, 5]], float)
data = TextIO()
np.savetxt(data, control)
data.seek(0)
# 使用 np.genfromtxt 从 TextIO 对象读取数据,指定数据类型为浮点数,并选择使用列 (1, 2)
test = np.genfromtxt(data, dtype=float, usecols=(1, 2))
# 断言测试结果与预期结果相等
assert_equal(test, control[:, 1:])
# 使用数组而非元组进行测试
data.seek(0)
test = np.genfromtxt(data, dtype=float, usecols=np.array([1, 2]))
# 断言测试结果与预期结果相等
assert_equal(test, control[:, 1:])
def test_usecols_as_css(self):
# Test usecols given as a comma-separated string
data = "1 2 3\n4 5 6"
# 使用 np.genfromtxt 从 TextIO 对象读取数据,指定列名为 'a, b, c',并选择使用列 'a, c'
test = np.genfromtxt(TextIO(data),
names="a, b, c", usecols="a, c")
# 预期的结果数组
ctrl = np.array([(1, 3), (4, 6)], dtype=[(_, float) for _ in "ac"])
# 断言测试结果与预期结果相等
assert_equal(test, ctrl)
def test_usecols_with_structured_dtype(self):
# Test usecols with an explicit structured dtype
data = TextIO("JOE 70.1 25.3\nBOB 60.5 27.9")
names = ['stid', 'temp']
dtypes = ['S4', 'f8']
# 使用 np.genfromtxt 从 TextIO 对象读取数据,指定列使用 (0, 2),并使用给定的结构化数据类型
test = np.genfromtxt(
data, usecols=(0, 2), dtype=list(zip(names, dtypes)))
# 断言测试结果的 'stid' 列和 'temp' 列与预期结果相等
assert_equal(test['stid'], [b"JOE", b"BOB"])
assert_equal(test['temp'], [25.3, 27.9])
def test_usecols_with_integer(self):
# Test usecols given as a single integer
test = np.genfromtxt(TextIO(b"1 2 3\n4 5 6"), usecols=0)
# 断言测试结果与预期结果相等
assert_equal(test, np.array([1., 4.]))
def test_usecols_with_named_columns(self):
# Test usecols with named columns
ctrl = np.array([(1, 3), (4, 6)], dtype=[('a', float), ('c', float)])
data = "1 2 3\n4 5 6"
kwargs = dict(names="a, b, c")
# 使用 genfromtxt 从文本数据创建 NumPy 数组,仅选择指定列 ('a' 和 'c')
test = np.genfromtxt(TextIO(data), usecols=(0, -1), **kwargs)
assert_equal(test, ctrl)
# 再次使用 genfromtxt,但这次使用列名 ('a' 和 'c') 替代索引位置
test = np.genfromtxt(TextIO(data),
usecols=('a', 'c'), **kwargs)
assert_equal(test, ctrl)
def test_empty_file(self):
# Test that an empty file raises the proper warning.
with suppress_warnings() as sup:
sup.filter(message="genfromtxt: Empty input file:")
data = TextIO()
# 读取空文本时,验证 genfromtxt 是否返回空数组
test = np.genfromtxt(data)
assert_equal(test, np.array([]))
# 当 skip_header > 0 时,再次验证空文本情况
test = np.genfromtxt(data, skip_header=1)
assert_equal(test, np.array([]))
def test_fancy_dtype_alt(self):
# Check that a nested dtype isn't MIA
data = TextIO('1,2,3.0\n4,5,6.0\n')
# 定义一个复杂的 dtype,包含嵌套的字段
fancydtype = np.dtype([('x', int), ('y', [('t', int), ('s', float)])])
# 使用 genfromtxt 读取数据,并验证是否正确创建了复杂 dtype 的数组
test = np.genfromtxt(data, dtype=fancydtype, delimiter=',', usemask=True)
control = ma.array([(1, (2, 3.0)), (4, (5, 6.0))], dtype=fancydtype)
assert_equal(test, control)
def test_shaped_dtype(self):
c = TextIO("aaaa 1.0 8.0 1 2 3 4 5 6")
# 定义一个结构化 dtype,包含一个形状为 (2, 3) 的数组字段
dt = np.dtype([('name', 'S4'), ('x', float), ('y', float),
('block', int, (2, 3))])
# 使用 genfromtxt 读取数据,并验证是否正确创建了结构化 dtype 的数组
x = np.genfromtxt(c, dtype=dt)
a = np.array([('aaaa', 1.0, 8.0, [[1, 2, 3], [4, 5, 6]])],
dtype=dt)
assert_array_equal(x, a)
def test_withmissing(self):
data = TextIO('A,B\n0,1\n2,N/A')
kwargs = dict(delimiter=",", missing_values="N/A", names=True)
# 使用 genfromtxt 读取数据,处理缺失值并创建带有掩码的结构化数组
test = np.genfromtxt(data, dtype=None, usemask=True, **kwargs)
control = ma.array([(0, 1), (2, -1)],
mask=[(False, False), (False, True)],
dtype=[('A', int), ('B', int)])
assert_equal(test, control)
assert_equal(test.mask, control.mask)
#
data.seek(0)
# 再次使用 genfromtxt,处理不同数据类型的列,并生成带有掩码的结构化数组
test = np.genfromtxt(data, usemask=True, **kwargs)
control = ma.array([(0, 1), (2, -1)],
mask=[(False, False), (False, True)],
dtype=[('A', float), ('B', float)])
assert_equal(test, control)
assert_equal(test.mask, control.mask)
def test_user_missing_values(self):
# Test data string containing missing values
data = "A, B, C\n0, 0., 0j\n1, N/A, 1j\n-9, 2.2, N/A\n3, -99, 3j"
# 设置基础参数字典
basekwargs = dict(dtype=None, delimiter=",", names=True,)
# 设置数据类型元组
mdtype = [('A', int), ('B', float), ('C', complex)]
# 使用 np.genfromtxt 从数据流中读取数据,并设置 N/A 为缺失值
test = np.genfromtxt(TextIO(data), missing_values="N/A",
**basekwargs)
# 创建控制组数组,用于比较结果
control = ma.array([(0, 0.0, 0j), (1, -999, 1j),
(-9, 2.2, -999j), (3, -99, 3j)],
mask=[(0, 0, 0), (0, 1, 0), (0, 0, 1), (0, 0, 0)],
dtype=mdtype)
# 断言测试结果与控制组相等
assert_equal(test, control)
# 更新 basekwargs 中的 dtype 为 mdtype
basekwargs['dtype'] = mdtype
# Read the data again with per-column missing values and usemask=True
test = np.genfromtxt(TextIO(data),
missing_values={0: -9, 1: -99, 2: -999j}, usemask=True, **basekwargs)
# 更新控制组数组,用于比较结果
control = ma.array([(0, 0.0, 0j), (1, -999, 1j),
(-9, 2.2, -999j), (3, -99, 3j)],
mask=[(0, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 0)],
dtype=mdtype)
# 断言测试结果与控制组相等
assert_equal(test, control)
# Read the data once more, mixing column indices and names for missing_values
test = np.genfromtxt(TextIO(data),
missing_values={0: -9, 'B': -99, 'C': -999j},
usemask=True,
**basekwargs)
# 更新控制组数组,用于比较结果
control = ma.array([(0, 0.0, 0j), (1, -999, 1j),
(-9, 2.2, -999j), (3, -99, 3j)],
mask=[(0, 0, 0), (0, 1, 0), (1, 0, 1), (0, 1, 0)],
dtype=mdtype)
# 断言测试结果与控制组相等
assert_equal(test, control)
def test_user_filling_values(self):
# Test the combination of missing values and filling values
ctrl = np.array([(0, 3), (4, -999)], dtype=[('a', int), ('b', int)])
# 创建包含缺失值的测试数据字符串
data = "N/A, 2, 3\n4, ,???"
# 设置关键字参数字典
kwargs = dict(delimiter=",",
dtype=int,
names="a,b,c",
missing_values={0: "N/A", 'b': " ", 2: "???"},
filling_values={0: 0, 'b': 0, 2: -999})
# 使用 np.genfromtxt 从数据流中读取数据,并设置缺失值和填充值
test = np.genfromtxt(TextIO(data), **kwargs)
# 创建控制组数组,用于比较结果
ctrl = np.array([(0, 2, 3), (4, 0, -999)],
dtype=[(_, int) for _ in "abc"])
# 断言测试结果与控制组相等
assert_equal(test, ctrl)
# 使用 np.genfromtxt 从数据流中读取数据,只选择部分列,并设置缺失值和填充值
test = np.genfromtxt(TextIO(data), usecols=(0, -1), **kwargs)
# 创建控制组数组,用于比较结果
ctrl = np.array([(0, 3), (4, -999)], dtype=[(_, int) for _ in "ac"])
# 断言测试结果与控制组相等
assert_equal(test, ctrl)
# 创建另一个包含缺失值的测试数据字符串
data2 = "1,2,*,4\n5,*,7,8\n"
# 使用 np.genfromtxt 从数据流中读取数据,并设置特定缺失值和填充值
test = np.genfromtxt(TextIO(data2), delimiter=',', dtype=int,
missing_values="*", filling_values=0)
# 创建控制组数组,用于比较结果
ctrl = np.array([[1, 2, 0, 4], [5, 0, 7, 8]])
# 断言测试结果与控制组相等
assert_equal(test, ctrl)
# 使用 np.genfromtxt 从数据流中读取数据,并设置特定缺失值和填充值
test = np.genfromtxt(TextIO(data2), delimiter=',', dtype=int,
missing_values="*", filling_values=-1)
# 创建控制组数组,用于比较结果
ctrl = np.array([[1, 2, -1, 4], [5, -1, 7, 8]])
# 断言测试结果与控制组相等
assert_equal(test, ctrl)
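# Illustrative aside (not from the test file): `missing_values` declares which
# tokens count as missing and `filling_values` provides the replacement, either
# globally or per column.
import io
import numpy as np
arr = np.genfromtxt(io.StringIO("1,2,*,4\n5,*,7,8\n"), delimiter=",",
                    dtype=int, missing_values="*", filling_values=-1)
print(arr)  # [[ 1  2 -1  4]
            #  [ 5 -1  7  8]]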
def test_withmissing_float(self):
# 创建一个文本输入对象,包含特定的数据
data = TextIO('A,B\n0,1.5\n2,-999.00')
# 使用 np.genfromtxt 从文本输入对象中读取数据,指定参数如下:
# dtype=None 表示数据类型为自动推断
# delimiter=',' 指定字段分隔符为逗号
# missing_values='-999.0' 指定缺失值为 '-999.0'
# names=True 表示第一行包含字段名
# usemask=True 表示使用掩码数组来标记缺失值
test = np.genfromtxt(data, dtype=None, delimiter=',',
missing_values='-999.0', names=True, usemask=True)
# 创建一个控制数据的掩码数组,表示缺失值情况
control = ma.array([(0, 1.5), (2, -1.)],
mask=[(False, False), (False, True)],
dtype=[('A', int), ('B', float)])
# 断言测试结果与控制数据相等
assert_equal(test, control)
# 断言测试结果的掩码数组与控制数据的掩码数组相等
assert_equal(test.mask, control.mask)
def test_with_masked_column_uniform(self):
# 测试具有掩码列的情况
data = TextIO('1 2 3\n4 5 6\n')
# 使用 np.genfromtxt 从文本输入对象中读取数据,指定参数如下:
# dtype=None 表示数据类型为自动推断
# missing_values='2,5' 指定多个缺失值为 '2' 和 '5'
# usemask=True 表示使用掩码数组来标记缺失值
test = np.genfromtxt(data, dtype=None,
missing_values='2,5', usemask=True)
# 创建一个控制数据的掩码数组,表示缺失值情况
control = ma.array([[1, 2, 3], [4, 5, 6]], mask=[[0, 1, 0], [0, 1, 0]])
# 断言测试结果与控制数据相等
assert_equal(test, control)
def test_with_masked_column_various(self):
# 测试具有掩码列的情况
data = TextIO('True 2 3\nFalse 5 6\n')
# 使用 np.genfromtxt 从文本输入对象中读取数据,指定参数如下:
# dtype=None 表示数据类型为自动推断
# missing_values='2,5' 指定多个缺失值为 '2' 和 '5'
# usemask=True 表示使用掩码数组来标记缺失值
test = np.genfromtxt(data, dtype=None,
missing_values='2,5', usemask=True)
# 创建一个控制数据的掩码数组,表示缺失值情况和字段类型
control = ma.array([(1, 2, 3), (0, 5, 6)],
mask=[(0, 1, 0), (0, 1, 0)],
dtype=[('f0', bool), ('f1', bool), ('f2', int)])
# 断言测试结果与控制数据相等
assert_equal(test, control)
def test_invalid_raise(self):
# Test that invalid data raises an error (or only a warning when requested)
data = ["1, 1, 1, 1, 1"] * 50
for i in range(5):
data[10 * i] = "2, 2, 2, 2 2"
data.insert(0, "a, b, c, d, e")
# 创建一个包含指定数据的文本输入对象
mdata = TextIO("\n".join(data))
# 定义关键字参数字典
kwargs = dict(delimiter=",", dtype=None, names=True)
# Helper that calls np.genfromtxt on the data;
# invalid_raise=False means invalid lines are skipped with a warning instead of raising
def f():
return np.genfromtxt(mdata, invalid_raise=False, **kwargs)
# 断言函数 f 会引发 ConversionWarning 警告
mtest = assert_warns(ConversionWarning, f)
# 断言测试结果长度为 45
assert_equal(len(mtest), 45)
# 断言测试结果与控制数据相等,数据类型为每个字段都是整数 'abcde'
assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'abcde']))
#
mdata.seek(0)
# 断言调用 np.genfromtxt 会引发 ValueError 异常
assert_raises(ValueError, np.genfromtxt, mdata,
delimiter=",", names=True)
def test_invalid_raise_with_usecols(self):
# Test invalid data together with the usecols argument
data = ["1, 1, 1, 1, 1"] * 50
for i in range(5):
data[10 * i] = "2, 2, 2, 2 2"
data.insert(0, "a, b, c, d, e")
# 创建一个包含指定数据的文本输入对象
mdata = TextIO("\n".join(data))
# 定义关键字参数字典
kwargs = dict(delimiter=",", dtype=None, names=True,
invalid_raise=False)
# 定义一个函数 f,该函数调用 np.genfromtxt 从文本输入对象中读取数据
# usecols=(0, 4) 表示仅使用第 0 和第 4 列数据
def f():
return np.genfromtxt(mdata, usecols=(0, 4), **kwargs)
# 断言函数 f 会引发 ConversionWarning 警告
mtest = assert_warns(ConversionWarning, f)
# 断言测试结果长度为 45
assert_equal(len(mtest), 45)
# 断言测试结果与控制数据相等,数据类型为每个字段都是整数 'ae'
assert_equal(mtest, np.ones(45, dtype=[(_, int) for _ in 'ae']))
#
mdata.seek(0)
# 调用 np.genfromtxt 读取指定列数据,无异常引发
mtest = np.genfromtxt(mdata, usecols=(0, 1), **kwargs)
# 断言测试结果长度为 50
assert_equal(len(mtest), 50)
# 创建一个控制数据,包含指定的数据和类型
control = np.ones(50, dtype=[(_, int) for _ in 'ab'])
control[[10 * _ for _ in range(5)]] = (2, 2)
# 断言测试结果与控制数据相等
assert_equal(mtest, control)
def test_inconsistent_dtype(self):
# Test handling of an inconsistent dtype
# List of repeated data rows
data = ["1, 1, 1, 1, -1.1"] * 50
# 将数据列表连接成一个文本流对象
mdata = TextIO("\n".join(data))
# 定义转换器字典,将第4列的数据进行特定格式的转换
converters = {4: lambda x: "(%s)" % x.decode()}
# 构建参数字典,包括分隔符、转换器、数据类型、编码方式等
kwargs = dict(delimiter=",", converters=converters,
dtype=[(_, int) for _ in 'abcde'], encoding="bytes")
# 断言调用 genfromtxt 方法时会引发 ValueError 异常
assert_raises(ValueError, np.genfromtxt, mdata, **kwargs)
def test_default_field_format(self):
# Test the default field-name format
# Data string
data = "0, 1, 2.3\n4, 5, 6.7"
# 创建包含数据的文本流对象
mtest = np.genfromtxt(TextIO(data),
delimiter=",", dtype=None, defaultfmt="f%02i")
# 创建预期的 NumPy 数组对象
ctrl = np.array([(0, 1, 2.3), (4, 5, 6.7)],
dtype=[("f00", int), ("f01", int), ("f02", float)])
# 断言生成的数组与预期的数组相等
assert_equal(mtest, ctrl)
def test_single_dtype_wo_names(self):
# Test a single dtype without field names
# Data string
data = "0, 1, 2.3\n4, 5, 6.7"
# 创建包含数据的文本流对象
mtest = np.genfromtxt(TextIO(data),
delimiter=",", dtype=float, defaultfmt="f%02i")
# 创建预期的 NumPy 数组对象
ctrl = np.array([[0., 1., 2.3], [4., 5., 6.7]], dtype=float)
# 断言生成的数组与预期的数组相等
assert_equal(mtest, ctrl)
def test_single_dtype_w_explicit_names(self):
# Test a single dtype with explicit field names
# Data string
data = "0, 1, 2.3\n4, 5, 6.7"
# 创建包含数据的文本流对象
mtest = np.genfromtxt(TextIO(data),
delimiter=",", dtype=float, names="a, b, c")
# 创建预期的 NumPy 结构化数组对象
ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
dtype=[(_, float) for _ in "abc"])
# 断言生成的结构化数组与预期的结构化数组相等
assert_equal(mtest, ctrl)
def test_single_dtype_w_implicit_names(self):
# Test a single dtype with field names taken from the header
# Data string
data = "a, b, c\n0, 1, 2.3\n4, 5, 6.7"
# 创建包含数据的文本流对象
mtest = np.genfromtxt(TextIO(data),
delimiter=",", dtype=float, names=True)
# 创建预期的 NumPy 结构化数组对象
ctrl = np.array([(0., 1., 2.3), (4., 5., 6.7)],
dtype=[(_, float) for _ in "abc"])
# 断言生成的结构化数组与预期的结构化数组相等
assert_equal(mtest, ctrl)
def test_easy_structured_dtype(self):
# Test an easy structured dtype
# Data string
data = "0, 1, 2.3\n4, 5, 6.7"
# 创建包含数据的文本流对象
mtest = np.genfromtxt(TextIO(data), delimiter=",",
dtype=(int, float, float), defaultfmt="f_%02i")
# 创建预期的 NumPy 结构化数组对象
ctrl = np.array([(0, 1., 2.3), (4, 5., 6.7)],
dtype=[("f_00", int), ("f_01", float), ("f_02", float)])
# 断言生成的结构化数组与预期的结构化数组相等
assert_equal(mtest, ctrl)
def test_autostrip(self):
# Test the autostrip option
data = "01/01/2003 , 1.3, abcde"
kwargs = dict(delimiter=",", dtype=None, encoding="bytes")
# 捕获警告并记录
with warnings.catch_warnings(record=True) as w:
# 总是警告,不管警告内容
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 np.genfromtxt 从 TextIO 对象中读取数据
mtest = np.genfromtxt(TextIO(data), **kwargs)
# 断言第一个警告类型为 VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# 控制数组,用于断言检查
ctrl = np.array([('01/01/2003 ', 1.3, ' abcde')],
dtype=[('f0', '|S12'), ('f1', float), ('f2', '|S8')])
# 断言 mtest 和 ctrl 相等
assert_equal(mtest, ctrl)
# 再次捕获警告并记录
with warnings.catch_warnings(record=True) as w:
# 总是警告,不管警告内容
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 np.genfromtxt 从 TextIO 对象中读取数据,启用自动去除空白
mtest = np.genfromtxt(TextIO(data), autostrip=True, **kwargs)
# 断言第一个警告类型为 VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# 控制数组,用于断言检查
ctrl = np.array([('01/01/2003', 1.3, 'abcde')],
dtype=[('f0', '|S10'), ('f1', float), ('f2', '|S5')])
# 断言 mtest 和 ctrl 相等
assert_equal(mtest, ctrl)
def test_replace_space(self):
# Test the 'replace_space' option
txt = "A.A, B (B), C:C\n1, 2, 3.14"
# Test the default: spaces are replaced by '_' and non-alphanumeric characters are removed
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=None)
# 控制数据类型
ctrl_dtype = [("AA", int), ("B_B", int), ("CC", float)]
# 控制数组,用于断言检查
ctrl = np.array((1, 2, 3.14), dtype=ctrl_dtype)
# 断言 test 和 ctrl 相等
assert_equal(test, ctrl)
# Test with no space replacement and no character deletion
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=None,
replace_space='', deletechars='')
# 控制数据类型
ctrl_dtype = [("A.A", int), ("B (B)", int), ("C:C", float)]
# 控制数组,用于断言检查
ctrl = np.array((1, 2, 3.14), dtype=ctrl_dtype)
# 断言 test 和 ctrl 相等
assert_equal(test, ctrl)
# Test with no character deletion (spaces are still replaced by '_')
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=None,
deletechars='')
# 控制数据类型
ctrl_dtype = [("A.A", int), ("B_(B)", int), ("C:C", float)]
# 控制数组,用于断言检查
ctrl = np.array((1, 2, 3.14), dtype=ctrl_dtype)
# 断言 test 和 ctrl 相等
assert_equal(test, ctrl)
def test_replace_space_known_dtype(self):
# Test 'replace_space' (and related options) when dtype != None
txt = "A.A, B (B), C:C\n1, 2, 3"
# Default: spaces replaced by '_', non-alphanumeric characters removed
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=int)
ctrl_dtype = [("AA", int), ("B_B", int), ("CC", int)]
ctrl = np.array((1, 2, 3), dtype=ctrl_dtype)
assert_equal(test, ctrl)
# No replacement, no deletion
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=int,
replace_space='', deletechars='')
ctrl_dtype = [("A.A", int), ("B (B)", int), ("C:C", int)]
ctrl = np.array((1, 2, 3), dtype=ctrl_dtype)
assert_equal(test, ctrl)
# No deletion (spaces are still replaced by '_')
test = np.genfromtxt(TextIO(txt),
delimiter=",", names=True, dtype=int,
deletechars='')
ctrl_dtype = [("A.A", int), ("B_(B)", int), ("C:C", int)]
ctrl = np.array((1, 2, 3), dtype=ctrl_dtype)
assert_equal(test, ctrl)
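# Illustrative aside (not from the test file): header names are sanitised before
# they become field names -- spaces turn into `replace_space` (default '_') and
# characters in `deletechars` are dropped.
import io
import numpy as np
rec = np.genfromtxt(io.StringIO("A.A, B (B), C:C\n1, 2, 3"),
                    delimiter=",", names=True, dtype=int)
print(rec.dtype.names)  # ('AA', 'B_B', 'CC')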
def test_incomplete_names(self):
# Test with an incomplete set of names
data = "A,,C\n0,1,2\n3,4,5"
kwargs = dict(delimiter=",", names=True)
# 使用 dtype=None
ctrl = np.array([(0, 1, 2), (3, 4, 5)],
dtype=[(_, int) for _ in ('A', 'f0', 'C')])
test = np.genfromtxt(TextIO(data), dtype=None, **kwargs)
assert_equal(test, ctrl)
# With the default (float) dtype
ctrl = np.array([(0, 1, 2), (3, 4, 5)],
dtype=[(_, float) for _ in ('A', 'f0', 'C')])
test = np.genfromtxt(TextIO(data), **kwargs)
assert_equal(test, ctrl)
def test_names_auto_completion(self):
# Make sure names are auto-completed
data = "1 2 3\n 4 5 6"
test = np.genfromtxt(TextIO(data),
dtype=(int, float, int), names="a")
ctrl = np.array([(1, 2, 3), (4, 5, 6)],
dtype=[('a', int), ('f0', float), ('f1', int)])
assert_equal(test, ctrl)
def test_names_with_usecols_bug1636(self):
# Make sure the right column names are picked when usecols is used
data = "A,B,C,D,E\n0,1,2,3,4\n0,1,2,3,4\n0,1,2,3,4"
# 控制用于比较的列名列表
ctrl_names = ("A", "C", "E")
# 使用 genfromtxt 函数从文本数据中加载数据,并指定数据类型为整数
test = np.genfromtxt(TextIO(data),
dtype=(int, int, int), delimiter=",",
usecols=(0, 2, 4), names=True)
# 断言加载数据的列名与控制列表相同
assert_equal(test.dtype.names, ctrl_names)
#
# 重新加载数据,这次使用列名字符串而不是索引
test = np.genfromtxt(TextIO(data),
dtype=(int, int, int), delimiter=",",
usecols=("A", "C", "E"), names=True)
# 再次断言加载数据的列名与控制列表相同
assert_equal(test.dtype.names, ctrl_names)
#
# 再次重新加载数据,这次只指定整数数据类型而不指定列名数据类型
test = np.genfromtxt(TextIO(data),
dtype=int, delimiter=",",
usecols=("A", "C", "E"), names=True)
# 最后一次断言加载数据的列名与控制列表相同
assert_equal(test.dtype.names, ctrl_names)
def test_fixed_width_names(self):
# Test loading fixed-width data while keeping the column names
data = " A B C\n 0 1 2.3\n 45 67 9."
kwargs = dict(delimiter=(5, 5, 4), names=True, dtype=None)
# 控制用于比较的 NumPy 数组
ctrl = np.array([(0, 1, 2.3), (45, 67, 9.)],
dtype=[('A', int), ('B', int), ('C', float)])
# 使用 genfromtxt 函数加载数据
test = np.genfromtxt(TextIO(data), **kwargs)
# 断言加载的数据与控制数组相同
assert_equal(test, ctrl)
#
# 再次加载数据,这次仅指定一个整数作为分隔符宽度
kwargs = dict(delimiter=5, names=True, dtype=None)
# 重新定义控制数组
ctrl = np.array([(0, 1, 2.3), (45, 67, 9.)],
dtype=[('A', int), ('B', int), ('C', float)])
# 再次使用 genfromtxt 函数加载数据
test = np.genfromtxt(TextIO(data), **kwargs)
# 再次断言加载的数据与控制数组相同
assert_equal(test, ctrl)
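# Illustrative aside (not from the test file): an integer delimiter (or a tuple
# of integers) makes genfromtxt cut each line into fixed-width fields instead of
# splitting on a separator character.
import io
import numpy as np
arr = np.genfromtxt(io.StringIO("  1  2  3\n 10 20 30"), delimiter=3, dtype=int)
print(arr)  # [[ 1  2  3]
            #  [10 20 30]]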
def test_filling_values(self):
# Test handling of missing values via filling_values
data = b"1, 2, 3\n1, , 5\n0, 6, \n"
kwargs = dict(delimiter=",", dtype=None, filling_values=-999)
# 控制用于比较的 NumPy 数组
ctrl = np.array([[1, 2, 3], [1, -999, 5], [0, 6, -999]], dtype=int)
# 使用 genfromtxt 函数加载数据
test = np.genfromtxt(TextIO(data), **kwargs)
# 断言加载的数据与控制数组相同
assert_equal(test, ctrl)
def test_comments_is_none(self):
# Test handling of comments=None
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 genfromtxt 函数加载数据,指定注释为 None
test = np.genfromtxt(TextIO("test1,testNonetherestofthedata"),
dtype=None, comments=None, delimiter=',',
encoding="bytes")
assert_(w[0].category is VisibleDeprecationWarning)
# 断言加载的数据中的第二个元素为字节字符串 b'testNonetherestofthedata'
assert_equal(test[1], b'testNonetherestofthedata')
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 重新加载数据,这次包含一个空格来触发警告
test = np.genfromtxt(TextIO("test1, testNonetherestofthedata"),
dtype=None, comments=None, delimiter=',',
encoding="bytes")
assert_(w[0].category is VisibleDeprecationWarning)
# 再次断言加载的数据中的第二个元素为字节字符串 b' testNonetherestofthedata'
assert_equal(test[1], b' testNonetherestofthedata')
def test_latin1(self):
# Latin-1 encoded bytes
latin1 = b'\xf6\xfc\xf6'
# Plain ASCII bytes
norm = b"norm1,norm2,norm3\n"
# A line mixing ASCII and Latin-1 bytes
enc = b"test1,testNonethe" + latin1 + b",test3\n"
# Full test data
s = norm + enc + norm
# 捕获警告信息
with warnings.catch_warnings(record=True) as w:
# 设置警告过滤器
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 np.genfromtxt 从 TextIO 流中读取数据,指定参数
test = np.genfromtxt(TextIO(s),
dtype=None, comments=None, delimiter=',',
encoding="bytes")
# 断言捕获到的第一个警告是 VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# 断言测试结果的特定元素与预期值相等
assert_equal(test[1, 0], b"test1")
assert_equal(test[1, 1], b"testNonethe" + latin1)
assert_equal(test[1, 2], b"test3")
# 使用 Latin-1 编码重新进行数据解析
test = np.genfromtxt(TextIO(s),
dtype=None, comments=None, delimiter=',',
encoding='latin1')
# 断言测试结果的特定元素与预期值相等
assert_equal(test[1, 0], "test1")
assert_equal(test[1, 1], "testNonethe" + latin1.decode('latin1'))
assert_equal(test[1, 2], "test3")
# 再次捕获警告信息
with warnings.catch_warnings(record=True) as w:
# 设置警告过滤器
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 np.genfromtxt 从 TextIO 流中读取数据,指定参数
test = np.genfromtxt(TextIO(b"0,testNonethe" + latin1),
dtype=None, comments=None, delimiter=',',
encoding="bytes")
# 断言捕获到的第一个警告是 VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# 断言测试结果的特定字段与预期值相等
assert_equal(test['f0'], 0)
assert_equal(test['f1'], b"testNonethe" + latin1)
def test_binary_decode_autodtype(self):
# 定义 UTF-16 编码的字节序列
utf16 = b'\xff\xfeh\x04 \x00i\x04 \x00j\x04'
# 调用被测函数,加载数据并指定参数
v = self.loadfunc(BytesIO(utf16), dtype=None, encoding='UTF-16')
# 断言加载后的数据数组与预期结果相等
assert_array_equal(v, np.array(utf16.decode('UTF-16').split()))
def test_utf8_byte_encoding(self):
# UTF-8 encoded bytes
utf8 = b"\xcf\x96"
# Plain ASCII bytes
norm = b"norm1,norm2,norm3\n"
# A line mixing ASCII and UTF-8 bytes
enc = b"test1,testNonethe" + utf8 + b",test3\n"
# Full test data
s = norm + enc + norm
# 捕获警告信息
with warnings.catch_warnings(record=True) as w:
# 设置警告过滤器
warnings.filterwarnings('always', '', VisibleDeprecationWarning)
# 使用 np.genfromtxt 从 TextIO 流中读取数据,指定参数
test = np.genfromtxt(TextIO(s),
dtype=None, comments=None, delimiter=',',
encoding="bytes")
# 断言捕获到的第一个警告是 VisibleDeprecationWarning
assert_(w[0].category is VisibleDeprecationWarning)
# 定义预期的控制数组
ctl = np.array([
[b'norm1', b'norm2', b'norm3'],
[b'test1', b'testNonethe' + utf8, b'test3'],
[b'norm1', b'norm2', b'norm3']])
# 断言测试结果与预期的控制数组相等
assert_array_equal(test, ctl)
def test_utf8_file(self):
# 定义 UTF-8 编码的特殊字符
utf8 = b"\xcf\x96"
# 使用临时路径创建文件,并写入重复的测试数据行
with temppath() as path:
with open(path, "wb") as f:
f.write((b"test1,testNonethe" + utf8 + b",test3\n") * 2)
# 从文件中读取数据到 NumPy 数组,指定 UTF-8 编码
test = np.genfromtxt(path, dtype=None, comments=None,
delimiter=',', encoding="UTF-8")
# 创建控制数组 ctl 作为预期输出结果
ctl = np.array([
["test1", "testNonethe" + utf8.decode("UTF-8"), "test3"],
["test1", "testNonethe" + utf8.decode("UTF-8"), "test3"]],
dtype=np.str_)
# 断言测试结果与控制数组相等
assert_array_equal(test, ctl)
# 测试包含混合数据类型的情况
with open(path, "wb") as f:
f.write(b"0,testNonethe" + utf8)
# 重新读取文件到 NumPy 数组,再次指定 UTF-8 编码
test = np.genfromtxt(path, dtype=None, comments=None,
delimiter=',', encoding="UTF-8")
# 断言字段 'f0' 的值为 0
assert_equal(test['f0'], 0)
# 断言字段 'f1' 的值为 "testNonethe" 加上 UTF-8 解码的特殊字符
assert_equal(test['f1'], "testNonethe" + utf8.decode("UTF-8"))
def test_utf8_file_nodtype_unicode(self):
# 使用 Unicode 字符代表 UTF-8 编码的特殊字符
utf8 = '\u03d6'
latin1 = '\xf6\xfc\xf6'
# Skip the test if the UTF-8 test string cannot be encoded in the preferred encoding
try:
encoding = locale.getpreferredencoding()
utf8.encode(encoding)
except (UnicodeError, ImportError):
pytest.skip('Skipping test_utf8_file_nodtype_unicode, '
'unable to encode utf8 in preferred encoding')
# 使用临时路径创建文件,并写入多行文本数据
with temppath() as path:
with open(path, "wt") as f:
f.write("norm1,norm2,norm3\n")
f.write("norm1," + latin1 + ",norm3\n")
f.write("test1,testNonethe" + utf8 + ",test3\n")
# 忽略有关 recfromtxt 的警告信息
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '',
VisibleDeprecationWarning)
# 从文件中读取数据到 NumPy 数组,使用 bytes 编码
test = np.genfromtxt(path, dtype=None, comments=None,
delimiter=',', encoding="bytes")
# 检查是否出现编码未指定警告
assert_(w[0].category is VisibleDeprecationWarning)
# 创建控制数组 ctl 作为预期输出结果
ctl = np.array([
["norm1", "norm2", "norm3"],
["norm1", latin1, "norm3"],
["test1", "testNonethe" + utf8, "test3"]],
dtype=np.str_)
# 断言测试结果与控制数组相等
assert_array_equal(test, ctl)
@pytest.mark.filterwarnings("ignore:.*recfromtxt.*:DeprecationWarning")
# Test the `recfromtxt` convenience function
def test_recfromtxt(self):
# 创建包含数据的文本流对象
data = TextIO('A,B\n0,1\n2,3')
# 设置关键字参数字典
kwargs = dict(delimiter=",", missing_values="N/A", names=True)
# 调用 `recfromtxt` 函数,使用给定的参数
test = recfromtxt(data, **kwargs)
# 创建期望结果的 NumPy 数组
control = np.array([(0, 1), (2, 3)],
dtype=[('A', int), ('B', int)])
# 断言 `test` 是一个 `np.recarray` 类型的对象
assert_(isinstance(test, np.recarray))
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
# 创建包含新数据的文本流对象
data = TextIO('A,B\n0,1\n2,N/A')
# 使用额外的关键字参数调用 `recfromtxt` 函数
test = recfromtxt(data, dtype=None, usemask=True, **kwargs)
# 创建期望结果的掩码数组
control = ma.array([(0, 1), (2, -1)],
mask=[(False, False), (False, True)],
dtype=[('A', int), ('B', int)])
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
# 断言 `test.mask` 和 `control.mask` 数组相等
assert_equal(test.mask, control.mask)
# 断言 `test.A` 数组中的值与期望的一致
assert_equal(test.A, [0, 2])
# 使用 pytest 标记忽略特定警告,针对 `recfromcsv` 函数的测试
@pytest.mark.filterwarnings("ignore:.*recfromcsv.*:DeprecationWarning")
def test_recfromcsv(self):
# 创建包含数据的文本流对象
data = TextIO('A,B\n0,1\n2,3')
# 设置关键字参数字典
kwargs = dict(missing_values="N/A", names=True, case_sensitive=True,
encoding="bytes")
# 使用给定参数调用 `recfromcsv` 函数
test = recfromcsv(data, dtype=None, **kwargs)
# 创建期望结果的 NumPy 数组
control = np.array([(0, 1), (2, 3)],
dtype=[('A', int), ('B', int)])
# 断言 `test` 是一个 `np.recarray` 类型的对象
assert_(isinstance(test, np.recarray))
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
# 创建包含新数据的文本流对象
data = TextIO('A,B\n0,1\n2,N/A')
# 使用额外的关键字参数调用 `recfromcsv` 函数
test = recfromcsv(data, dtype=None, usemask=True, **kwargs)
# 创建期望结果的掩码数组
control = ma.array([(0, 1), (2, -1)],
mask=[(False, False), (False, True)],
dtype=[('A', int), ('B', int)])
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
# 断言 `test.mask` 和 `control.mask` 数组相等
assert_equal(test.mask, control.mask)
# 创建包含数据的文本流对象
data = TextIO('A,B\n0,1\n2,3')
# 使用单个关键字参数调用 `recfromcsv` 函数
test = recfromcsv(data, missing_values='N/A',)
# 创建期望结果的 NumPy 数组
control = np.array([(0, 1), (2, 3)],
dtype=[('a', int), ('b', int)])
# 断言 `test` 是一个 `np.recarray` 类型的对象
assert_(isinstance(test, np.recarray))
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
# 创建包含数据的文本流对象
data = TextIO('A,B\n0,1\n2,3')
# 定义新的数据类型
dtype = [('a', int), ('b', float)]
# 使用额外的关键字参数调用 `recfromcsv` 函数
test = recfromcsv(data, missing_values='N/A', dtype=dtype)
# 创建期望结果的 NumPy 数组
control = np.array([(0, 1), (2, 3)],
dtype=dtype)
# 断言 `test` 是一个 `np.recarray` 类型的对象
assert_(isinstance(test, np.recarray))
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
#gh-10394
# 创建包含数据的文本流对象,用于测试特定的转换器
data = TextIO('color\n"red"\n"blue"')
# Call `recfromcsv` with a custom converter that strips the quotes
test = recfromcsv(data, converters={0: lambda x: x.strip('\"')})
# 创建期望结果的 NumPy 数组
control = np.array([('red',), ('blue',)], dtype=[('color', (str, 4))])
# 断言 `test.dtype` 和 `control.dtype` 数组相等
assert_equal(test.dtype, control.dtype)
# 断言 `test` 和 `control` 数组相等
assert_equal(test, control)
def test_max_rows(self):
# Test the `max_rows` keyword argument.
data = '1 2\n3 4\n5 6\n7 8\n9 10\n'
# 创建一个 TextIO 对象,用于模拟数据输入流
txt = TextIO(data)
# Read from the stream, limiting the number of rows to 3
a1 = np.genfromtxt(txt, max_rows=3)
# 从同一个文本输入流中继续读取数据,未指定 max_rows,因此读取剩余的行
a2 = np.genfromtxt(txt)
# 断言 a1 的结果与预期值相等
assert_equal(a1, [[1, 2], [3, 4], [5, 6]])
# 断言 a2 的结果与预期值相等
assert_equal(a2, [[7, 8], [9, 10]])
# max_rows 参数必须至少为 1,验证是否会引发 ValueError 异常
assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=0)
# 包含多个无效行的输入
data = '1 1\n2 2\n0 \n3 3\n4 4\n5 \n6 \n7 \n'
# 从文本输入流中读取最多 2 行数据
test = np.genfromtxt(TextIO(data), max_rows=2)
control = np.array([[1., 1.], [2., 2.]])
assert_equal(test, control)
# Test conflicting keywords
assert_raises(ValueError, np.genfromtxt, TextIO(data), skip_footer=1,
max_rows=4)
# 测试无效值情况
assert_raises(ValueError, np.genfromtxt, TextIO(data), max_rows=4)
# 测试无效值但不抛出异常的情况
with suppress_warnings() as sup:
sup.filter(ConversionWarning)
test = np.genfromtxt(TextIO(data), max_rows=4, invalid_raise=False)
control = np.array([[1., 1.], [2., 2.], [3., 3.], [4., 4.]])
assert_equal(test, control)
test = np.genfromtxt(TextIO(data), max_rows=5, invalid_raise=False)
assert_equal(test, control)
# 带有字段名的结构化数组
data = 'a b\n#c d\n1 1\n2 2\n#0 \n3 3\n4 4\n5 5\n'
# Test with a header, field names and comments
txt = TextIO(data)
test = np.genfromtxt(txt, skip_header=1, max_rows=3, names=True)
control = np.array([(1.0, 1.0), (2.0, 2.0), (3.0, 3.0)],
dtype=[('c', '<f8'), ('d', '<f8')])
assert_equal(test, control)
# 继续读取相同的 "文件",不使用 skip_header 或 names,并使用之前确定的 dtype。
test = np.genfromtxt(txt, max_rows=None, dtype=test.dtype)
control = np.array([(4.0, 4.0), (5.0, 5.0)],
dtype=[('c', '<f8'), ('d', '<f8')])
assert_equal(test, control)
def test_gft_using_filename(self):
# 测试能够从文件名以及文件对象中加载数据
tgt = np.arange(6).reshape((2, 3))
linesep = ('\n', '\r\n', '\r')
for sep in linesep:
data = '0 1 2' + sep + '3 4 5'
# 使用临时文件路径来创建文件,并将数据写入文件中
with temppath() as name:
with open(name, 'w') as f:
f.write(data)
# 从文件中读取数据并使用 numpy 进行处理
res = np.genfromtxt(name)
# 断言读取的结果与目标值相等
assert_array_equal(res, tgt)
def test_gft_from_gzip(self):
# Test loading data from a gzip file
wanted = np.arange(6).reshape((2, 3))
linesep = ('\n', '\r\n', '\r')
for sep in linesep:
# 构造包含不同换行符的数据字符串
data = '0 1 2' + sep + '3 4 5'
s = BytesIO()
# 使用 gzip 将数据写入 BytesIO 对象
with gzip.GzipFile(fileobj=s, mode='w') as g:
g.write(asbytes(data))
# 创建临时文件,并将数据写入文件
with temppath(suffix='.gz2') as name:
with open(name, 'w') as f:
f.write(data)
# 使用 np.genfromtxt 读取临时文件,并验证结果与期望值相等
assert_array_equal(np.genfromtxt(name), wanted)
def test_gft_using_generator(self):
# gft 不能处理 Unicode 数据
def count():
for i in range(10):
yield asbytes("%d" % i)
# 使用生成器对象作为输入,验证 np.genfromtxt 的输出结果
res = np.genfromtxt(count())
assert_array_equal(res, np.arange(10))
def test_auto_dtype_largeint(self):
# Regression test for numpy/numpy#5635: large integers could raise OverflowError
# Test the automatically determined output dtype
#
# 2**66 = 73786976294838206464 => 应转换为 float
# 2**34 = 17179869184 => 应转换为 int64
# 2**10 = 1024 => 应转换为 int (在 32 位系统上为 int32,在 64 位系统上为 int64)
data = TextIO('73786976294838206464 17179869184 1024')
# 使用 np.genfromtxt 读取文本数据,不指定 dtype
test = np.genfromtxt(data, dtype=None)
# 验证生成的 dtype 的字段名
assert_equal(test.dtype.names, ['f0', 'f1', 'f2'])
# 验证每个字段的 dtype
assert_(test.dtype['f0'] == float)
assert_(test.dtype['f1'] == np.int64)
assert_(test.dtype['f2'] == np.int_)
# 验证字段数据是否正确转换
assert_allclose(test['f0'], 73786976294838206464.)
assert_equal(test['f1'], 17179869184)
assert_equal(test['f2'], 1024)
def test_unpack_float_data(self):
txt = TextIO("1,2,3\n4,5,6\n7,8,9\n0.0,1.0,2.0")
# Parse the comma-separated text with np.loadtxt and unpack it into one array per column
a, b, c = np.loadtxt(txt, delimiter=",", unpack=True)
assert_array_equal(a, np.array([1.0, 4.0, 7.0, 0.0]))
assert_array_equal(b, np.array([2.0, 5.0, 8.0, 1.0]))
assert_array_equal(c, np.array([3.0, 6.0, 9.0, 2.0]))
def test_unpack_structured(self):
# Regression test for gh-4341: unpacking should work for structured arrays
txt = TextIO("M 21 72\nF 35 58")
dt = {'names': ('a', 'b', 'c'), 'formats': ('S1', 'i4', 'f4')}
# Parse the text with np.genfromtxt using the given dtype and unpack it
a, b, c = np.genfromtxt(txt, dtype=dt, unpack=True)
assert_equal(a.dtype, np.dtype('S1'))
assert_equal(b.dtype, np.dtype('i4'))
assert_equal(c.dtype, np.dtype('f4'))
assert_array_equal(a, np.array([b'M', b'F']))
assert_array_equal(b, np.array([21, 35]))
assert_array_equal(c, np.array([72., 58.]))
def test_unpack_auto_dtype(self):
# Regression test for gh-4341
# Unpacking should work when dtype=None
txt = TextIO("M 21 72.\nF 35 58.")
# Expected result: one string array and two numeric arrays
expected = (np.array(["M", "F"]), np.array([21, 35]), np.array([72., 58.]))
# Read the data with genfromtxt, dtype=None, unpacking enabled, utf-8 encoding
test = np.genfromtxt(txt, dtype=None, unpack=True, encoding="utf-8")
# Compare each expected array with the corresponding result, both values and dtypes
for arr, result in zip(expected, test):
assert_array_equal(arr, result)
assert_equal(arr.dtype, result.dtype)
def test_unpack_single_name(self):
# Regression test for gh-4341
# Unpacking should work when the structured dtype has only one field
txt = TextIO("21\n35")
# Structured dtype with a single field
dt = {'names': ('a',), 'formats': ('i4',)}
# Expected result: an integer array
expected = np.array([21, 35], dtype=np.int32)
# Read the data with genfromtxt using dtype dt, unpacking enabled
test = np.genfromtxt(txt, dtype=dt, unpack=True)
# Compare values and dtypes of the expected and actual results
assert_array_equal(expected, test)
assert_equal(expected.dtype, test.dtype)
def test_squeeze_scalar(self):
# Regression test for gh-4341
# Unpacking a scalar should give zero-dim output,
# even if dtype is structured
txt = TextIO("1")
# Structured dtype with a single field
dt = {'names': ('a',), 'formats': ('i4',)}
# Expected result: an integer array
expected = np.array((1,), dtype=np.int32)
# Read the data with genfromtxt using dtype dt, unpacking enabled
test = np.genfromtxt(txt, dtype=dt, unpack=True)
# Check that the values match
assert_array_equal(expected, test)
# Check that the result has a zero-dimensional shape
assert_equal((), test.shape)
# Check that the dtypes match
assert_equal(expected.dtype, test.dtype)
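All of the unpack tests above rest on the same behaviour: with `unpack=True`, `np.loadtxt` and `np.genfromtxt` return one array per column (or per field for a structured dtype) instead of a single two-dimensional or structured array. A small illustrative sketch with made-up data:

from io import StringIO
import numpy as np

# Plain columns: unpack=True transposes the result, one 1-D array per column
x, y = np.loadtxt(StringIO("1 10\n2 20\n3 30\n"), unpack=True)
# x -> [1. 2. 3.],  y -> [10. 20. 30.]

# Structured dtype: unpack=True yields one array per field
dt = [('sex', 'S1'), ('age', 'i4')]
sex, age = np.genfromtxt(StringIO("M 21\nF 35"), dtype=dt, unpack=True)
# sex -> [b'M' b'F'],  age -> [21 35]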
@pytest.mark.parametrize("ndim", [0, 1, 2])
def test_ndmin_keyword(self, ndim: int):
# lets have the same behaviour of ndmin as loadtxt
# 让ndmin的行为与loadtxt相同
# as they should be the same for non-missing values
# 因为对于非缺失值,它们应该是相同的
txt = "42"
# 使用loadtxt和genfromtxt分别加载文本输入txt,指定ndmin参数为ndim
a = np.loadtxt(StringIO(txt), ndmin=ndim)
b = np.genfromtxt(StringIO(txt), ndmin=ndim)
# 断言两者结果相等
assert_array_equal(a, b)
class TestPathUsage:
# Test that pathlib.Path instances can be used
def test_loadtxt(self):
# Create a temporary file with a '.txt' suffix
with temppath(suffix='.txt') as path:
# Wrap the path in a pathlib.Path object
path = Path(path)
# Build a 2-D array
a = np.array([[1.1, 2], [3, 4]])
# Save array a to the file at this path
np.savetxt(path, a)
# Load the data back from the file into x
x = np.loadtxt(path)
# Check that x equals a
assert_array_equal(x, a)
def test_save_load(self):
# Test that pathlib.Path instances can be used with save
with temppath(suffix='.npy') as path:
path = Path(path)
a = np.array([[1, 2], [3, 4]], int)
# Save array a to the file at this path
np.save(path, a)
# Load the data back from the file
data = np.load(path)
# Check that data equals a
assert_array_equal(data, a)
def test_save_load_memmap(self):
# Test that pathlib.Path instances can be used to load a memmap
with temppath(suffix='.npy') as path:
path = Path(path)
a = np.array([[1, 2], [3, 4]], int)
# Save array a to the file at this path
np.save(path, a)
# Load the data as a read-only memory map
data = np.load(path, mmap_mode='r')
# Check that data equals a
assert_array_equal(data, a)
# Release the memory-mapped file
del data
if IS_PYPY:
break_cycles()
break_cycles()
@pytest.mark.xfail(IS_WASM, reason="memmap doesn't work correctly")
@pytest.mark.parametrize("filename_type", [Path, str])
def test_save_load_memmap_readwrite(self, filename_type):
# Test that pathlib.Path instances can be used for read/write memmaps
with temppath(suffix='.npy') as path:
path = filename_type(path)
a = np.array([[1, 2], [3, 4]], int)
# Save array a to the file at this path
np.save(path, a)
# Load the data as a read/write memory map
b = np.load(path, mmap_mode='r+')
# Modify the first element of both a and the memmap b
a[0][0] = 5
b[0][0] = 5
# Release the memory-mapped file
del b
if IS_PYPY:
break_cycles()
break_cycles()
# Reload the data from the file
data = np.load(path)
# Check that data equals a
assert_array_equal(data, a)
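The two memmap tests depend on `np.load`'s `mmap_mode`: `'r'` maps the `.npy` file read-only, while `'r+'` maps it read/write, so in-place assignments end up in the file once the map is released. A minimal sketch (the temporary file name is illustrative):

import os
import tempfile
import numpy as np

path = os.path.join(tempfile.mkdtemp(), "demo.npy")  # illustrative temporary path
np.save(path, np.zeros((2, 2), dtype=int))

m = np.load(path, mmap_mode='r+')   # writable memory map backed by the file
m[0, 0] = 7                         # modifies the mapped buffer in place
del m                               # release the map so the change is written back

# np.load(path) now returns [[7 0] [0 0]]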
@pytest.mark.parametrize("filename_type", [Path, str])
def test_savez_load(self, filename_type):
# 测试 pathlib.Path 实例能否与 savez 方法一起使用
with temppath(suffix='.npz') as path:
path = filename_type(path)
# 保存带有 'lab' 键的数据到路径对应的文件中
np.savez(path, lab='place holder')
# 使用 with 语句加载路径对应的数据到变量 data
with np.load(path) as data:
# 断言变量 data 的 'lab' 键的值与 'place holder' 相等
assert_array_equal(data['lab'], 'place holder')
@pytest.mark.parametrize("filename_type", [Path, str])
def test_savez_compressed_load(self, filename_type):
# 测试 pathlib.Path 实例能否与 savez_compressed 方法一起使用
with temppath(suffix='.npz') as path:
path = filename_type(path)
# 压缩保存带有 'lab' 键的数据到路径对应的文件中
np.savez_compressed(path, lab='place holder')
# 加载路径对应的数据到变量 data
data = np.load(path)
# 断言变量 data 的 'lab' 键的值与 'place holder' 相等
assert_array_equal(data['lab'], 'place holder')
# 关闭文件数据
data.close()
@pytest.mark.parametrize("filename_type", [Path, str])
def test_genfromtxt(self, filename_type):
# 测试 pathlib.Path 实例能否与 genfromtxt 方法一起使用
with temppath(suffix='.txt') as path:
path = filename_type(path)
a = np.array([(1, 2), (3, 4)])
# 将数组 a 保存到路径对应的文件中
np.savetxt(path, a)
# 从文件中加载数据到变量 data
data = np.genfromtxt(path)
# 断言数组 a 和变量 data 相等
assert_array_equal(a, data)
@pytest.mark.parametrize("filename_type", [Path, str])
@pytest.mark.filterwarnings("ignore:.*recfromtxt.*:DeprecationWarning")
# 定义一个测试方法,用于测试从文本文件中读取结构化数据
def test_recfromtxt(self, filename_type):
# 使用临时路径创建一个以'.txt'结尾的文件
with temppath(suffix='.txt') as path:
# 将路径转换为指定类型(Path对象或字符串)
path = filename_type(path)
# 打开文件并写入数据'A,B\n0,1\n2,3'
with open(path, 'w') as f:
f.write('A,B\n0,1\n2,3')
# 定义参数字典,指定分隔符为逗号,缺失值标记为"N/A",使用列名
kwargs = dict(delimiter=",", missing_values="N/A", names=True)
# 调用recfromtxt函数读取数据文件,并传入参数kwargs
test = recfromtxt(path, **kwargs)
# 创建预期结果,一个包含元组的NumPy数组,指定列A和B的数据类型为整数
control = np.array([(0, 1), (2, 3)], dtype=[('A', int), ('B', int)])
# 断言测试结果是一个np.recarray结构
assert_(isinstance(test, np.recarray))
# 断言测试结果与预期结果相等
assert_equal(test, control)
# Parametrized test: read structured data from a CSV file with recfromcsv
@pytest.mark.parametrize("filename_type", [Path, str])
# Ignore the DeprecationWarning related to 'recfromcsv'
@pytest.mark.filterwarnings("ignore:.*recfromcsv.*:DeprecationWarning")
def test_recfromcsv(self, filename_type):
# Create a temporary file ending in '.txt'
with temppath(suffix='.txt') as path:
# Convert the path to the requested type (Path object or string)
path = filename_type(path)
# Write the data 'A,B\n0,1\n2,3' to the file
with open(path, 'w') as f:
f.write('A,B\n0,1\n2,3')
# Keyword arguments: "N/A" marks missing values, use column names, case sensitive
kwargs = dict(
missing_values="N/A", names=True, case_sensitive=True
)
# Read the CSV file with recfromcsv, passing kwargs
test = recfromcsv(path, dtype=None, **kwargs)
# Expected result: a NumPy array of tuples with integer columns A and B
control = np.array([(0, 1), (2, 3)], dtype=[('A', int), ('B', int)])
# Check that the result is an np.recarray
assert_(isinstance(test, np.recarray))
# Check that the result equals the expected array
assert_equal(test, control)
# Test loading an array from a gzip-compressed stream
def test_gzip_load():
# Build a 5x5 random array
a = np.random.random((5, 5))
# BytesIO object for handling binary data in memory
s = BytesIO()
# Open a GzipFile in write mode on top of s and save the array into it
f = gzip.GzipFile(fileobj=s, mode="w")
np.save(f, a)  # save array a into the compressed stream
f.close()  # close the file stream
s.seek(0)  # rewind to the beginning of the stream
# Open a GzipFile in read mode on top of s, load the data and compare with a
f = gzip.GzipFile(fileobj=s, mode="r")
assert_array_equal(np.load(f), a)  # the loaded array must equal the original
# The two classes below provide the minimal API needed to save (save()) and
# load (load()) arrays; `test_ducktyping` checks that they work.
class JustWriter:
def __init__(self, base):
self.base = base
def write(self, s):
return self.base.write(s)
def flush(self):
return self.base.flush()
class JustReader:
def __init__(self, base):
self.base = base
def read(self, n):
return self.base.read(n)
def seek(self, off, whence=0):
return self.base.seek(off, whence)
# Test duck typing: JustWriter and JustReader must be able to save and load arrays correctly
def test_ducktyping():
# Build a 5x5 random array
a = np.random.random((5, 5))
# BytesIO object for handling binary data in memory
s = BytesIO()
# Wrap s in JustWriter and save array a through it
f = JustWriter(s)
np.save(f, a)
f.flush()  # flush the buffered data
s.seek(0)  # rewind to the beginning of the stream
# Wrap s in JustReader, load the data and compare with the original array
f = JustReader(s)
assert_array_equal(np.load(f), a)  # the loaded array must equal the original
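`JustWriter` and `JustReader` show the smallest interface `np.save` and `np.load` actually rely on here: `write`/`flush` on the writing side and `read`/`seek` on the reading side. As a hedged illustration of the same duck typing, the hypothetical wrapper below additionally counts how many bytes `np.save` pushes through `write`:

from io import BytesIO
import numpy as np

class CountingWriter:
    # Hypothetical wrapper: exposes only write() and flush(), and counts bytes written
    def __init__(self, base):
        self.base = base
        self.nbytes = 0
    def write(self, chunk):
        self.nbytes += len(chunk)
        return self.base.write(chunk)
    def flush(self):
        return self.base.flush()

buf = BytesIO()
w = CountingWriter(buf)
np.save(w, np.arange(4))
w.flush()
# w.nbytes is the total number of bytes np.save wrote (header + data)
buf.seek(0)
# np.load(buf) -> array([0, 1, 2, 3])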
# Load data from a gzip-compressed file with loadtxt
def test_gzip_loadtxt():
# BytesIO object for handling binary data in memory
s = BytesIO()
# Open a GzipFile in write mode on top of s and write a short text line
g = gzip.GzipFile(fileobj=s, mode='w')
g.write(b'1 2 3\n')
g.close()
s.seek(0)  # rewind to the beginning of the stream
# Write the compressed bytes to a temporary '.gz' file, then load it with np.loadtxt
with temppath(suffix='.gz') as name:
with open(name, 'wb') as f:
f.write(s.read())
res = np.loadtxt(name)
s.close()  # close the BytesIO object
assert_array_equal(res, [1, 2, 3])  # the loaded data must equal the expected array
# Load gzip-compressed data from an in-memory stream
def test_gzip_loadtxt_from_string():
# BytesIO object for handling binary data in memory
s = BytesIO()
# Open a GzipFile in write mode on top of s and write a short text line
f = gzip.GzipFile(fileobj=s, mode="w")
f.write(b'1 2 3\n')
f.close()
s.seek(0)  # rewind to the beginning of the stream
# Open a GzipFile in read mode on top of s and load the data with np.loadtxt
f = gzip.GzipFile(fileobj=s, mode="r")
assert_array_equal(np.loadtxt(f), [1, 2, 3])  # the loaded data must equal the expected array
# Test the dict-like interface of an npz file
def test_npzfile_dict():
# BytesIO object for handling binary data in memory
s = BytesIO()
# Build two 3x3 zero matrices
x = np.zeros((3, 3))
y = np.zeros((3, 3))
# Save x and y into s in npz format
np.savez(s, x=x, y=y)
s.seek(0)  # rewind to the beginning of the stream
# Load the data from s
z = np.load(s)
# Check that both x and y are present in z
assert_('x' in z)
assert_('y' in z)
assert_('x' in z.keys())
assert_('y' in z.keys())
# Iterate over the items in z and check that each array has shape (3, 3)
for f, a in z.items():
assert_(f in ['x', 'y'])
assert_equal(a.shape, (3, 3))
# z must contain exactly two items
assert_(len(z.items()) == 2)
# Every key in z must be one of 'x' and 'y'
for f in z:
assert_(f in ['x', 'y'])
# z.get('x') must return the same array as z['x']
assert (z.get('x') == z['x']).all()
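The checks above exercise the mapping-like interface of the `NpzFile` object that `np.load` returns for `.npz` archives: membership tests, `keys()`, `items()`, iteration, and `get()` behave like a read-only dict whose values are arrays read from the archive on access. A short sketch:

from io import BytesIO
import numpy as np

buf = BytesIO()
np.savez(buf, x=np.eye(2), y=np.ones(3))
buf.seek(0)

with np.load(buf) as z:             # NpzFile also supports the context-manager protocol
    names = sorted(z.keys())        # ['x', 'y']
    has_x = 'x' in z                # True
    shapes = {k: v.shape for k, v in z.items()}   # {'x': (2, 2), 'y': (3,)}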
# Skip this test on platforms without reference counting
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
def test_load_refcount():
# BytesIO object for handling binary data in memory
f = BytesIO()
# Save a small array into f
np.savez(f, [1, 2, 3])
f.seek(0)  # rewind to the beginning of the stream
# assert_no_gc_cycles() checks that the loaded objects are freed directly,
# without relying on the garbage collector
with assert_no_gc_cycles():
np.load(f)
f.seek(0)  # rewind to the beginning of the stream
dt = [("a", 'u1', 2), ("b", 'u1', 2)]
# assert_no_gc_cycles() checks that no reference cycles are created here either
with assert_no_gc_cycles():
# Load data from text input with np.loadtxt using the structured dtype dt
x = np.loadtxt(TextIO("0 1 2 3"), dtype=dt)
# Check that x matches the expected array
assert_equal(x, np.array([((0, 1), (2, 3))], dtype=dt))
# Test loading multiple arrays from a single stream until EOF
def test_load_multiple_arrays_until_eof():
# Create a byte stream
f = BytesIO()
# Save the value 1 into the stream
np.save(f, 1)
# Save the value 2 into the stream
np.save(f, 2)
# Rewind the stream to the beginning
f.seek(0)
# The first array loaded from the stream must equal 1
assert np.load(f) == 1
# The second array loaded from the stream must equal 2
assert np.load(f) == 2
# A further load must raise EOFError
with pytest.raises(EOFError):
np.load(f)
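This last test works because each `np.save` call appends one self-contained `.npy` record to the stream, and `np.load` raises `EOFError` once the stream holds no further record. A minimal sketch for reading an unknown number of arrays from one stream:

from io import BytesIO
import numpy as np

buf = BytesIO()
for arr in (np.arange(3), np.eye(2)):
    np.save(buf, arr)               # two consecutive .npy records
buf.seek(0)

arrays = []
while True:
    try:
        arrays.append(np.load(buf))
    except EOFError:                # stream exhausted, no more records
        break
# len(arrays) == 2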