NumPy 源码解析(十二)
.\numpy\numpy\fft\_helper.pyi
from typing import Any, TypeVar, overload, Literal as L
from numpy import generic, integer, floating, complexfloating
from numpy._typing import (
NDArray,
ArrayLike,
_ShapeLike,
_ArrayLike,
_ArrayLikeFloat_co,
_ArrayLikeComplex_co,
)
# Scalar-type variable, bounded by ``numpy.generic``; the overloads below use
# it to carry the input array's scalar type through to the return type.
_SCT = TypeVar("_SCT", bound=generic)

# The stub only declares the type of ``__all__``; no value is given here.
__all__: list[str]
# Typed array-like input: the result array keeps the same scalar type.
@overload
def fftshift(x: _ArrayLike[_SCT], axes: None | _ShapeLike = ...) -> NDArray[_SCT]: ...
# Fallback for any other array-like: the scalar type of the result is untracked.
@overload
def fftshift(x: ArrayLike, axes: None | _ShapeLike = ...) -> NDArray[Any]: ...
# Typed array-like input: the result array keeps the same scalar type.
@overload
def ifftshift(x: _ArrayLike[_SCT], axes: None | _ShapeLike = ...) -> NDArray[_SCT]: ...
# Fallback for any other array-like: the scalar type of the result is untracked.
@overload
def ifftshift(x: ArrayLike, axes: None | _ShapeLike = ...) -> NDArray[Any]: ...
# Float-compatible spacing ``d``: the result is a floating-point array.
# ``device`` only admits ``None`` or the literal ``"cpu"``.
@overload
def fftfreq(
    n: int | integer[Any],
    d: _ArrayLikeFloat_co = ...,
    device: None | L["cpu"] = ...,
) -> NDArray[floating[Any]]: ...
# Complex-compatible spacing ``d``: the result is a complex array.
# Overloads are tried in order, so this one is reached only when the
# float-compatible overload above does not match.
@overload
def fftfreq(
    n: int | integer[Any],
    d: _ArrayLikeComplex_co = ...,
    device: None | L["cpu"] = ...,
) -> NDArray[complexfloating[Any, Any]]: ...
# Float-compatible spacing ``d``: the result is a floating-point array.
# ``device`` only admits ``None`` or the literal ``"cpu"``.
@overload
def rfftfreq(
    n: int | integer[Any],
    d: _ArrayLikeFloat_co = ...,
    device: None | L["cpu"] = ...,
) -> NDArray[floating[Any]]: ...
# Complex-compatible spacing ``d``: the result is a complex array.
# Overloads are tried in order, so this one is reached only when the
# float-compatible overload above does not match.
@overload
def rfftfreq(
    n: int | integer[Any],
    d: _ArrayLikeComplex_co = ...,
    device: None | L["cpu"] = ...,
) -> NDArray[complexfloating[Any, Any]]: ...
.\numpy\numpy\fft\_pocketfft.py
# Dispatch decorator factory with ``module='numpy.fft'`` pre-bound; the public
# transforms in this file apply it together with a dispatcher function.
array_function_dispatch = functools.partial(
    overrides.array_function_dispatch, module='numpy.fft')
def _raw_fft(a, n, axis, is_real, is_forward, norm, out=None):
    """
    Run one 1-D transform of `a` along `axis` using the ``pfu`` ufuncs.

    Parameters
    ----------
    a : ndarray
        Input array.
    n : int
        Number of points of the transform; must be at least 1.
    axis : int
        Axis to transform. It is normalized against ``a.ndim``.
    is_real : bool
        Use the real-input (forward) or real-output (inverse) kernels
        instead of the complex ones.
    is_forward : bool
        Forward transform if true, inverse transform otherwise.
    norm : {None, "backward", "ortho", "forward"}
        Normalization mode. For inverse transforms the direction is swapped
        first, so the scaling factor is always computed as for a forward
        transform.
    out : ndarray, optional
        Array to receive the result. If omitted, a new array is allocated.

    Returns
    -------
    out : ndarray
        Whatever the selected ufunc returns for the given `out`.

    Raises
    ------
    ValueError
        If `n` is smaller than 1, if `norm` is not a recognized mode, or if
        `out` exposes a ``shape`` whose rank or transformed-axis length does
        not match the expected output.
    """
    if n < 1:
        raise ValueError(f"Invalid number of FFT data points ({n}) specified.")

    # An inverse transform with a given norm scales like the forward
    # transform with the opposite norm.
    if not is_forward:
        norm = _swap_direction(norm)

    # Compute the factor in the array's own real precision so that the
    # sqrt/reciprocal below do not lose accuracy.
    real_dtype = result_type(a.real.dtype, 1.0)
    if norm is None or norm == "backward":
        fct = 1
    elif norm == "ortho":
        fct = reciprocal(sqrt(n, dtype=real_dtype))
    elif norm == "forward":
        fct = reciprocal(n, dtype=real_dtype)
    else:
        raise ValueError(f'Invalid norm value {norm}; should be "backward",'
                         '"ortho" or "forward".')

    # Length of the transformed axis in the output. Only the forward real
    # transform shortens it; every other case keeps `n` points. Without this
    # default the complex and inverse-real paths would use an unbound name.
    n_out = n
    if is_real:
        if is_forward:
            ufunc = pfu.rfft_n_even if n % 2 == 0 else pfu.rfft_n_odd
            n_out = n // 2 + 1
        else:
            ufunc = pfu.irfft
    else:
        ufunc = pfu.fft if is_forward else pfu.ifft

    axis = normalize_axis_index(axis, a.ndim)

    if out is None:
        if is_real and not is_forward:
            # Inverse real transform: complex input, real output.
            out_dtype = real_dtype
        else:
            # All other transforms produce complex output.
            out_dtype = result_type(a.dtype, 1j)
        out = empty(a.shape[:axis] + (n_out,) + a.shape[axis+1:],
                    dtype=out_dtype)
    elif ((shape := getattr(out, "shape", None)) is not None
          and (len(shape) != a.ndim or shape[axis] != n_out)):
        # Objects without a ``shape`` attribute are passed through unchecked.
        raise ValueError("output array has wrong shape.")

    return ufunc(a, fct, axes=[(axis,), (), (axis,)], out=out)
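上面 `_raw_fft` 的最后一部分负责准备输出数组 `out`:如果调用方没有提供 `out`,就按输入数组的形状(变换轴的长度换成 `n_out`)新建一个数组;如果提供了 `out` 且它带有 `shape` 属性,则检查其维数和变换轴长度是否匹配,不匹配时抛出 `ValueError`。最后调用前面选定的 `ufunc` 完成变换并返回结果。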
# Maps each normalization mode to the mode that applies the same scaling to
# the transform in the opposite direction. ``None`` is treated as "backward".
_SWAP_DIRECTION_MAP = {"backward": "forward", None: "forward",
                       "ortho": "ortho", "forward": "backward"}


def _swap_direction(norm):
    """
    Return the normalization mode for the opposite transform direction.

    Parameters
    ----------
    norm : {None, "backward", "ortho", "forward"}
        Normalization mode to swap.

    Returns
    -------
    str
        The entry of ``_SWAP_DIRECTION_MAP`` for `norm`.

    Raises
    ------
    ValueError
        If `norm` is not one of the recognized modes.
    """
    try:
        return _SWAP_DIRECTION_MAP[norm]
    except KeyError:
        # ``from None`` hides the internal KeyError from the traceback.
        raise ValueError(f'Invalid norm value {norm}; should be "backward", '
                         '"ortho" or "forward".') from None
def _fft_dispatcher(a, n=None, axis=None, norm=None, out=None):
    """Return the arguments of the 1-D transforms relevant for dispatch.

    Only `a` and `out` are returned; the other parameters exist so that the
    signature lines up with the transforms this dispatcher is paired with.
    """
    return (a, out)
@array_function_dispatch(_fft_dispatcher)
def fft(a, n=None, axis=-1, norm=None, out=None):
    """
    Compute the one-dimensional discrete Fourier Transform.

    This function computes the one-dimensional *n*-point discrete Fourier
    Transform (DFT) with the efficient Fast Fourier Transform (FFT)
    algorithm [CT]_.

    Parameters
    ----------
    a : array_like
        Input array, can be complex.
    n : int, optional
        Length of the transformed axis of the output.
        If `n` is smaller than the length of the input, the input is cropped.
        If it is larger, the input is padded with zeros.  If `n` is not given,
        the length of the input along the axis specified by `axis` is used.
    axis : int, optional
        Axis over which to compute the FFT.  If not given, the last axis is
        used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.

    Raises
    ------
    IndexError
        If `axis` is not a valid axis of `a`.

    See Also
    --------
    numpy.fft : for definition of the DFT and conventions used.
    ifft : The inverse of `fft`.
    fft2 : The two-dimensional FFT.
    fftn : The *n*-dimensional FFT.
    rfftn : The *n*-dimensional FFT of real input.
    fftfreq : Frequency bins for given FFT parameters.

    Notes
    -----
    FFT (Fast Fourier Transform) refers to a way the discrete Fourier
    Transform (DFT) can be calculated efficiently, by using symmetries in the
    calculated terms.  The symmetry is highest when `n` is a power of 2, and
    the transform is therefore most efficient for these sizes.

    The DFT is defined, with the conventions used in this implementation, in
    the documentation for the `numpy.fft` module.

    References
    ----------
    .. [CT] Cooley, James W., and John W. Tukey, 1965, "An algorithm for the
            machine calculation of complex Fourier series," *Math. Comput.*
            19: 297-301.
    """
    a = asarray(a)
    if n is None:
        # Default to the full length of the transformed axis.
        n = a.shape[axis]
    # Positional flags: is_real=False, is_forward=True.
    output = _raw_fft(a, n, axis, False, True, norm, out)
    return output
@array_function_dispatch(_fft_dispatcher)
def ifft(a, n=None, axis=-1, norm=None, out=None):
    """
    Compute the one-dimensional inverse discrete Fourier Transform.

    This function computes the inverse of the one-dimensional *n*-point
    discrete Fourier transform computed by `fft`.  In other words,
    ``ifft(fft(a)) == a`` to within numerical accuracy.
    For a general description of the algorithm and definitions,
    see `numpy.fft`.

    The input should be ordered in the same way as is returned by `fft`,
    i.e.,

    * ``a[0]`` should contain the zero frequency term,
    * ``a[1:n//2]`` should contain the positive-frequency terms,
    * ``a[n//2 + 1:]`` should contain the negative-frequency terms, in
      increasing order starting from the most negative frequency.

    For an even number of input points, ``A[n//2]`` represents the sum of
    the values at the positive and negative Nyquist frequencies, as the two
    are aliased together. See `numpy.fft` for details.

    Parameters
    ----------
    a : array_like
        Input array, can be complex.
    n : int, optional
        Length of the transformed axis of the output.
        If `n` is smaller than the length of the input, the input is cropped.
        If it is larger, the input is padded with zeros.  If `n` is not given,
        the length of the input along the axis specified by `axis` is used.
        See notes about padding issues.
    axis : int, optional
        Axis over which to compute the inverse DFT.  If not given, the last
        axis is used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.

    Raises
    ------
    IndexError
        If `axis` is not a valid axis of `a`.

    See Also
    --------
    numpy.fft : An introduction, with definitions and general explanations.
    fft : The one-dimensional (forward) FFT, of which `ifft` is the inverse
    ifft2 : The two-dimensional inverse FFT.
    ifftn : The n-dimensional inverse FFT.

    Notes
    -----
    If the input parameter `n` is larger than the size of the input, the input
    is padded by appending zeros at the end.  Even though this is the common
    approach, it might lead to surprising results.  If a different padding is
    desired, it must be performed before calling `ifft`.
    """
    a = asarray(a)
    if n is None:
        # Default to the full length of the transformed axis.
        n = a.shape[axis]
    # Positional flags: is_real=False, is_forward=False (inverse transform).
    output = _raw_fft(a, n, axis, False, False, norm, out=out)
    return output
@array_function_dispatch(_fft_dispatcher)
def rfft(a, n=None, axis=-1, norm=None, out=None):
    """
    Compute the one-dimensional discrete Fourier Transform for real input.

    This function computes the one-dimensional *n*-point discrete Fourier
    Transform (DFT) of a real-valued array by means of an efficient algorithm
    called the Fast Fourier Transform (FFT).

    Parameters
    ----------
    a : array_like
        Input array
    n : int, optional
        Number of points along transformation axis in the input to use.
        If `n` is smaller than the length of the input, the input is cropped.
        If it is larger, the input is padded with zeros. If `n` is not given,
        the length of the input along the axis specified by `axis` is used.
    axis : int, optional
        Axis over which to compute the FFT. If not given, the last axis is
        used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        If `n` is even, the length of the transformed axis is ``(n/2)+1``.
        If `n` is odd, the length is ``(n+1)/2``.

    Raises
    ------
    IndexError
        If `axis` is not a valid axis of `a`.

    See Also
    --------
    numpy.fft : For definition of the DFT and conventions used.
    irfft : The inverse of `rfft`.
    fft : The one-dimensional FFT of general (complex) input.
    fftn : The *n*-dimensional FFT.
    rfftn : The *n*-dimensional FFT of real input.

    Notes
    -----
    When the DFT is computed for purely real input, the output is
    Hermitian-symmetric, i.e. the negative frequency terms are just the complex
    conjugates of the corresponding positive-frequency terms, and the
    negative-frequency terms are therefore redundant.  This function does not
    compute the negative frequency terms, and the length of the transformed
    axis of the output is therefore ``n//2 + 1``.

    When ``A = rfft(a)`` and fs is the sampling frequency, ``A[0]`` contains
    the zero-frequency term 0*fs, which is real due to Hermitian symmetry.

    If `n` is even, ``A[-1]`` contains the term representing both positive
    and negative Nyquist frequency (+fs/2 and -fs/2), and must also be purely
    real. If `n` is odd, there is no term at fs/2; ``A[-1]`` contains
    the largest positive frequency (fs/2*(n-1)/n), and is complex in the
    general case.
    """
    a = asarray(a)
    if n is None:
        # Default to the full length of the transformed axis.
        n = a.shape[axis]
    # Positional flags: is_real=True, is_forward=True.
    output = _raw_fft(a, n, axis, True, True, norm, out=out)
    return output
@array_function_dispatch(_fft_dispatcher)
def irfft(a, n=None, axis=-1, norm=None, out=None):
    """
    Computes the inverse of `rfft`.

    This function computes the inverse of the one-dimensional *n*-point
    discrete Fourier Transform of real input computed by `rfft`.
    In other words, ``irfft(rfft(a), len(a)) == a`` to within numerical
    accuracy. (See Notes below for why ``len(a)`` is necessary here.)

    The input is expected to be in the form returned by `rfft`, i.e. the
    real zero-frequency term followed by the complex positive frequency terms
    in order of increasing frequency.  Since the discrete Fourier Transform of
    real input is Hermitian-symmetric, the negative frequency terms are taken
    to be the complex conjugates of the corresponding positive frequency terms.

    Parameters
    ----------
    a : array_like
        The input array.
    n : int, optional
        Length of the transformed axis of the output.
        For `n` output points, ``n//2+1`` input points are necessary.  If the
        input is longer than this, it is cropped.  If it is shorter than this,
        it is padded with zeros.  If `n` is not given, it is taken to be
        ``2*(m-1)`` where ``m`` is the length of the input along the axis
        specified by `axis`.
    axis : int, optional
        Axis over which to compute the inverse FFT. If not given, the last
        axis is used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        The length of the transformed axis is `n`, or, if `n` is not given,
        ``2*(m-1)`` where ``m`` is the length of the transformed axis of the
        input. To get an odd number of output points, `n` must be specified.

    Raises
    ------
    IndexError
        If `axis` is not a valid axis of `a`.

    See Also
    --------
    numpy.fft : For definition of the DFT and conventions used.
    rfft : The one-dimensional FFT of real input, of which `irfft` is inverse.
    fft : The one-dimensional FFT.
    irfft2 : The inverse of the two-dimensional FFT of real input.
    irfftn : The inverse of the *n*-dimensional FFT of real input.

    Notes
    -----
    Returns the real valued `n`-point inverse discrete Fourier transform
    of `a`, where `a` contains the non-negative frequency terms of a
    Hermitian-symmetric sequence. `n` is the length of the result, not the
    input.
    """
    a = asarray(a)
    if n is None:
        # With m input points the default output length is the even value
        # 2*(m-1); an odd length has to be requested explicitly.
        n = (a.shape[axis] - 1) * 2
    # Positional flags: is_real=True, is_forward=False (inverse transform).
    output = _raw_fft(a, n, axis, True, False, norm, out=out)
    return output
@array_function_dispatch(_fft_dispatcher)
def hfft(a, n=None, axis=-1, norm=None, out=None):
    """
    Compute the FFT of a signal that has Hermitian symmetry, i.e., a real
    spectrum.

    Parameters
    ----------
    a : array_like
        The input array.
    n : int, optional
        Length of the transformed axis of the output. For `n` output
        points, ``n//2 + 1`` input points are necessary.  If the input is
        longer than this, it is cropped.  If it is shorter than this, it is
        padded with zeros.  If `n` is not given, it is taken to be ``2*(m-1)``
        where ``m`` is the length of the input along the axis specified by
        `axis`.
    axis : int, optional
        Axis over which to compute the FFT. If not given, the last
        axis is used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        The length of the transformed axis is `n`, or, if `n` is not given,
        ``2*m - 2`` where ``m`` is the length of the transformed axis of
        the input. To get an odd number of output points, `n` must be
        specified, for instance as ``2*m - 1`` in the typical case.

    Raises
    ------
    IndexError
        If `axis` is not a valid axis of `a`.

    See Also
    --------
    rfft : Compute the one-dimensional FFT for real input.
    ihfft : The inverse of `hfft`.

    Notes
    -----
    `hfft`/`ihfft` are a pair analogous to `rfft`/`irfft`, but for the
    opposite case: here the signal has Hermitian symmetry in the time
    domain and is real in the frequency domain. So here it's `hfft` for
    which you must supply the length of the result if it is to be odd.

    * even: ``ihfft(hfft(a, 2*len(a) - 2)) == a``, within roundoff error,
    * odd: ``ihfft(hfft(a, 2*len(a) - 1)) == a``, within roundoff error.

    The correct interpretation of the hermitian input depends on the length of
    the original data, as given by `n`. This is because each input shape could
    correspond to either an odd or even length signal. By default, `hfft`
    assumes an even output length which puts the last entry at the Nyquist
    frequency; aliasing with its symmetric counterpart. By Hermitian symmetry,
    the value is thus treated as purely real. To avoid losing information, the
    shape of the full signal **must** be given.
    """
    a = asarray(a)
    if n is None:
        # With m input points the default output length is 2*(m-1).
        n = (a.shape[axis] - 1) * 2
    # hfft is computed as irfft of the conjugated input, so the
    # normalization direction has to be swapped.
    new_norm = _swap_direction(norm)
    # Forward the caller's `out`; passing ``None`` here would silently
    # ignore the documented `out` argument.
    output = irfft(conjugate(a), n, axis, norm=new_norm, out=out)
    return output
@array_function_dispatch(_fft_dispatcher)
def ihfft(a, n=None, axis=-1, norm=None, out=None):
    """
    Compute the inverse FFT of a signal that has Hermitian symmetry.

    Parameters
    ----------
    a : array_like
        Input array.
    n : int, optional
        Length of the inverse FFT, the number of points along
        transformation axis in the input to use.  If `n` is smaller than
        the length of the input, the input is cropped.  If it is larger,
        the input is padded with zeros. If `n` is not given, the length of
        the input along the axis specified by `axis` is used.
    axis : int, optional
        Axis over which to compute the inverse FFT. If not given, the last
        axis is used.
    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.
    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axis
        indicated by `axis`, or the last one if `axis` is not specified.
        The length of the transformed axis is ``n//2 + 1``.

    See Also
    --------
    hfft, irfft

    Notes
    -----
    `hfft`/`ihfft` are a pair analogous to `rfft`/`irfft`, but for the
    opposite case: here the signal has Hermitian symmetry in the time
    domain and is real in the frequency domain. So here it's `hfft` for
    which you must supply the length of the result if it is to be odd:

    * even: ``ihfft(hfft(a, 2*len(a) - 2)) == a``, within roundoff error,
    * odd: ``ihfft(hfft(a, 2*len(a) - 1)) == a``, within roundoff error.

    Examples
    --------
    >>> spectrum = np.array([ 15, -4, 0, -1, 0, -4])
    >>> np.fft.ifft(spectrum)
    array([1.+0.j,  2.+0.j,  3.+0.j,  4.+0.j,  3.+0.j,  2.+0.j])
    >>> np.fft.ihfft(spectrum)
    array([ 1.-0.j,  2.-0.j,  3.-0.j,  4.-0.j])
    """
    a = asarray(a)
    if n is None:
        # Default to the full length of the transformed axis.
        n = a.shape[axis]
    # ihfft is computed as the conjugate of rfft, so the normalization
    # direction has to be swapped.
    new_norm = _swap_direction(norm)
    out = rfft(a, n, axis, norm=new_norm, out=out)
    # Conjugate in place so no extra array is allocated for the result.
    return conjugate(out, out=out)
def _cook_nd_args(a, s=None, axes=None, invreal=0):
    """
    Resolve the `s` and `axes` arguments shared by the N-D transforms.

    Parameters
    ----------
    a : ndarray
        Input array; only its ``shape`` is read.
    s : sequence of ints, optional
        Requested length of each transformed axis. If omitted, the lengths
        are taken from ``a.shape`` (for all axes, or for `axes` if given).
        An entry of ``-1`` is replaced by the input length along the
        matching axis.
    axes : sequence of ints, optional
        Axes to transform. If omitted, the last ``len(s)`` axes are used.
    invreal : bool or int, optional
        If true and `s` was omitted, the last entry of `s` is set to
        ``2*(m-1)``, where ``m`` is the input length along the last axis in
        `axes`.

    Returns
    -------
    s : list
        Resolved length for each transformed axis.
    axes : sequence of ints
        The axes to transform: the caller's `axes` unchanged, or a new list
        of negative indices if `axes` was omitted.

    Raises
    ------
    ValueError
        If `s` and `axes` have different lengths.

    Warns
    -----
    DeprecationWarning
        If `s` is given without `axes`, or if `s` contains ``None``.
    """
    if s is None:
        shapeless = True
        if axes is None:
            s = list(a.shape)
        else:
            s = take(a.shape, axes)
    else:
        shapeless = False
    # Work on a fresh list so the entries can be assigned below without
    # touching the caller's sequence.
    s = list(s)
    if axes is None:
        if not shapeless:
            msg = ("`axes` should not be `None` if `s` is not `None` "
                   "(Deprecated in NumPy 2.0). In a future version of NumPy, "
                   "this will raise an error and `s[i]` will correspond to "
                   "the size along the transformed axis specified by "
                   "`axes[i]`. To retain current behaviour, pass a sequence "
                   "[0, ..., k-1] to `axes` for an array of dimension k.")
            warnings.warn(msg, DeprecationWarning, stacklevel=3)
        # Default to the last len(s) axes, expressed as negative indices.
        axes = list(range(-len(s), 0))
    if len(s) != len(axes):
        raise ValueError("Shape and axes have different lengths.")
    if invreal and shapeless:
        # Inverse real transform without an explicit shape: m input points
        # along the last transformed axis map to 2*(m-1) output points.
        s[-1] = (a.shape[axes[-1]] - 1) * 2
    if None in s:
        msg = ("Passing an array containing `None` values to `s` is "
               "deprecated in NumPy 2.0 and will raise an error in "
               "a future version of NumPy. To use the default behaviour "
               "of the corresponding 1-D transform, pass the value matching "
               "the default for its `n` parameter. To use the default "
               "behaviour for every axis, the `s` argument can be omitted.")
        warnings.warn(msg, DeprecationWarning, stacklevel=3)
    # An entry of -1 means "use the whole input along that axis".
    s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)]
    return s, axes
def _raw_fftnd(a, s=None, axes=None, function=fft, norm=None, out=None):
    """
    Apply a 1-D transform along each of several axes in turn.

    Parameters
    ----------
    a : array_like
        Input array.
    s : sequence of ints, optional
        Length of each transformed axis; resolved by `_cook_nd_args`.
    axes : sequence of ints, optional
        Axes to transform; resolved by `_cook_nd_args`.
    function : callable, optional
        1-D transform, called as
        ``function(a, n=..., axis=..., norm=..., out=...)``.
    norm : {"backward", "ortho", "forward"}, optional
        Normalization mode, forwarded unchanged to `function`.
    out : ndarray, optional
        Output array, forwarded unchanged to every call of `function`.

    Returns
    -------
    ndarray
        The result of the last 1-D transform.
    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)
    # Transform the axes from last to first; each pass consumes the result
    # of the previous one.
    for ii in reversed(range(len(axes))):
        a = function(a, n=s[ii], axis=axes[ii], norm=norm, out=out)
    return a
def _fftn_dispatcher(a, s=None, axes=None, norm=None, out=None):
    """Return the arguments of the N-D transforms relevant for dispatch.

    Only `a` and `out` are returned; the other parameters exist so that the
    signature lines up with the transforms this dispatcher is paired with.
    """
    return (a, out)
@array_function_dispatch(_fftn_dispatcher)
def fftn(a, s=None, axes=None, norm=None, out=None):
    """
    Compute the N-dimensional discrete Fourier Transform.

    This function computes the *N*-dimensional discrete Fourier Transform over
    any number of axes in an *M*-dimensional array by means of the Fast Fourier
    Transform (FFT).

    Parameters
    ----------
    a : array_like
        Input array, can be complex.
    s : sequence of ints, optional
        Shape (length of each transformed axis) of the output
        (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
        This corresponds to ``n`` for ``fft(x, n)``.
        Along any axis, if the given shape is smaller than that of the input,
        the input is cropped. If it is larger, the input is padded with zeros.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        If `s` is not given, the shape of the input along the axes specified
        by `axes` is used.

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        Axes over which to compute the FFT. If not given, the last ``len(s)``
        axes are used, or all axes if `s` is also not specified.
        Repeated indices in `axes` means that the transform over that axis is
        performed multiple times.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must be explicitly specified too.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for all axes (and hence is
        incompatible with passing in all but the trivial ``s``).

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axes
        indicated by `axes`, or by a combination of `s` and `a`,
        as explained in the parameters section above.

    Raises
    ------
    ValueError
        If `s` and `axes` have different length.
    """
    # Delegate to the shared N-D driver, using the 1-D forward transform
    # `fft` along each axis.
    return _raw_fftnd(a, s, axes, fft, norm, out=out)
@array_function_dispatch(_fftn_dispatcher)
def ifftn(a, s=None, axes=None, norm=None, out=None):
    """
    Compute the N-dimensional inverse discrete Fourier Transform.

    This function computes the inverse of the N-dimensional discrete
    Fourier Transform over any number of axes in an M-dimensional array by
    means of the Fast Fourier Transform (FFT).  In other words,
    ``ifftn(fftn(a)) == a`` to within numerical accuracy.
    For a description of the definitions and conventions used, see `numpy.fft`.

    The input, analogously to `ifft`, should be ordered in the same way as is
    returned by `fftn`, i.e. it should have the term for zero frequency
    in all axes in the low-order corner, the positive frequency terms in the
    first half of all axes, the term for the Nyquist frequency in the middle
    of all axes and the negative frequency terms in the second half of all
    axes, in order of decreasingly negative frequency.

    Parameters
    ----------
    a : array_like
        Input array, can be complex.
    s : sequence of ints, optional
        Shape (length of each transformed axis) of the output
        (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
        This corresponds to ``n`` for ``ifft(x, n)``.
        Along any axis, if the given shape is smaller than that of the input,
        the input is cropped. If it is larger, the input is padded with zeros.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        If `s` is not given, the shape of the input along the axes specified
        by `axes` is used. See notes for issue on `ifft` zero padding.

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        Axes over which to compute the IFFT.  If not given, the last ``len(s)``
        axes are used, or all axes if `s` is also not specified.
        Repeated indices in `axes` means that the inverse transform over that
        axis is performed multiple times.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must be explicitly specified too.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for all axes (and hence is
        incompatible with passing in all but the trivial ``s``).

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axes
        indicated by `axes`, or by a combination of `s` or `a`,
        as explained in the parameters section above.

    Raises
    ------
    ValueError
        If `s` and `axes` have different length.
    """
    # Delegate to the shared N-D driver, using the 1-D inverse transform
    # `ifft` along each axis.
    return _raw_fftnd(a, s, axes, ifft, norm, out=out)
@array_function_dispatch(_fftn_dispatcher)
def fft2(a, s=None, axes=(-2, -1), norm=None, out=None):
    """
    Compute the 2-dimensional discrete Fourier Transform.

    This function computes the *n*-dimensional discrete Fourier Transform
    over any axes in an *M*-dimensional array by means of the
    Fast Fourier Transform (FFT).  By default, the transform is computed over
    the last two axes of the input array, i.e., a 2-dimensional FFT.

    Parameters
    ----------
    a : array_like
        Input array, can be complex
    s : sequence of ints, optional
        Shape (length of each transformed axis) of the output
        (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
        This corresponds to ``n`` for ``fft(x, n)``.
        Along each axis, if the given shape is smaller than that of the input,
        the input is cropped. If it is larger, the input is padded with zeros.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        If `s` is not given, the shape of the input along the axes specified
        by `axes` is used.

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        Axes over which to compute the FFT. If not given, the last two
        axes are used. A repeated index in `axes` means the transform over
        that axis is performed multiple times. A one-element sequence means
        that a one-dimensional FFT is performed. Default: ``(-2, -1)``.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must not be ``None``.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for all axes (and hence only the
        last axis can have ``s`` not equal to the shape at that axis).

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axes
        indicated by `axes`, or the last two axes if `axes` is not given.

    Raises
    ------
    ValueError
        If `s` and `axes` have different length, or `axes` not given and
        ``len(s) != 2``.
    IndexError
        If an element of `axes` is larger than the number of axes of `a`.

    See Also
    --------
    numpy.fft : Overall view of discrete Fourier transforms, with definitions
         and conventions used.
    ifft2 : The inverse two-dimensional FFT.
    fft : The one-dimensional FFT.
    fftn : The *n*-dimensional FFT.
    fftshift : Shifts zero-frequency terms to the center of the array.
        For two-dimensional input, swaps first and third quadrants, and second
        and fourth quadrants.

    Notes
    -----
    `fft2` is just `fftn` with a different default for `axes`.

    The output, analogously to `fft`, contains the term for zero frequency in
    the low-order corner of the transformed axes, the positive frequency terms
    in the first half of these axes, the term for the Nyquist frequency in the
    middle of the axes and the negative frequency terms in the second half of
    the axes, in order of decreasingly negative frequency.

    See `fftn` for details and a plotting example, and `numpy.fft` for
    definitions and conventions used.

    Examples
    --------
    >>> a = np.mgrid[:5, :5][0]
    >>> np.fft.fft2(a)
    array([[ 50.  +0.j        ,   0.  +0.j        ,   0.  +0.j        , # may vary
              0.  +0.j        ,   0.  +0.j        ],
           [-12.5+17.20477401j,   0.  +0.j        ,   0.  +0.j        ,
              0.  +0.j        ,   0.  +0.j        ],
           [-12.5 +4.0614962j ,   0.  +0.j        ,   0.  +0.j        ,
              0.  +0.j        ,   0.  +0.j        ],
           [-12.5 -4.0614962j ,   0.  +0.j        ,   0.  +0.j        ,
              0.  +0.j        ,   0.  +0.j        ],
           [-12.5-17.20477401j,   0.  +0.j        ,   0.  +0.j        ,
              0.  +0.j        ,   0.  +0.j        ]])
    """
    # Delegate to the shared N-D driver with the 1-D forward transform `fft`.
    # `norm` is a mode string (or None), not a boolean.
    return _raw_fftnd(a, s, axes, fft, norm, out=out)
@array_function_dispatch(_fftn_dispatcher)
def ifft2(a, s=None, axes=(-2, -1), norm=None, out=None):
    """
    Compute the 2-dimensional inverse discrete Fourier Transform.

    This function computes the inverse of the 2-dimensional discrete Fourier
    Transform over any number of axes in an M-dimensional array by means of
    the Fast Fourier Transform (FFT).  In other words, ``ifft2(fft2(a)) == a``
    to within numerical accuracy.  By default, the inverse transform is
    computed over the last two axes of the input array.

    The input, analogously to `ifft`, should be ordered in the same way as is
    returned by `fft2`, i.e. it should have the term for zero frequency
    in the low-order corner of the two axes, the positive frequency terms in
    the first half of these axes, the term for the Nyquist frequency in the
    middle of the axes and the negative frequency terms in the second half of
    both axes, in order of decreasingly negative frequency.

    Parameters
    ----------
    a : array_like
        Input array, can be complex.
    s : sequence of ints, optional
        Shape (length of each axis) of the output (``s[0]`` refers to axis 0,
        ``s[1]`` to axis 1, etc.).  This corresponds to `n` for ``ifft(x, n)``.
        Along each axis, if the given shape is smaller than that of the input,
        the input is cropped. If it is larger, the input is padded with zeros.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        If `s` is not given, the shape of the input along the axes specified
        by `axes` is used.  See notes for issue on `ifft` zero padding.

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        Axes over which to compute the FFT.  If not given, the last two
        axes are used.  A repeated index in `axes` means the transform over
        that axis is performed multiple times.  A one-element sequence means
        that a one-dimensional FFT is performed. Default: ``(-2, -1)``.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must not be ``None``.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for all axes (and hence only the
        last axis can have ``s`` not equal to the shape at that axis).

    Returns
    -------
    out : complex ndarray
        The truncated or zero-padded input, transformed along the axes
        indicated by `axes`, or the last two axes if `axes` is not given.

    Raises
    ------
    ValueError
        If `s` and `axes` have different length.
    """
    # Delegate to the shared N-D driver with the 1-D inverse transform
    # `ifft`. The caller's `out` is forwarded; hard-coding ``None`` here
    # would silently ignore it.
    return _raw_fftnd(a, s, axes, ifft, norm, out=out)
这段代码是一个函数的返回语句,调用了名为 `_raw_fftnd` 的函数,用于执行 n 维的 FFT(快速傅里叶变换)或逆 FFT 变换。函数的参数解释如下:
- `a`: 输入数组,进行 FFT 或逆 FFT 变换的原始数据。
- `s`: 可选参数,用于指定输出数组的形状。如果提供了 `out` 参数,`s` 应与 `out` 的形状兼容。
- `axes`: 可选参数,指定沿着哪些轴进行变换。如果未给出,则默认为最后两个轴。
- `ifft`: 作为 `function` 参数传入的一维变换函数(这里是一维逆变换 `ifft`,不是布尔值);`_raw_fftnd` 会沿 `axes` 中的每个轴依次调用它。
- `norm`: 可选参数,归一化模式,取值为 `"backward"`、`"ortho"` 或 `"forward"`;默认的 `None` 等同于 `"backward"`。
- `out`: 可选参数,指定变换结果存储的目标数组。
函数返回变换后的结果数组。
# Dispatched through ``array_function_dispatch`` with ``_fftn_dispatcher``.
@array_function_dispatch(_fftn_dispatcher)
def rfftn(a, s=None, axes=None, norm=None, out=None):
    """
    Compute the N-dimensional discrete Fourier Transform for real input.

    This function computes the N-dimensional discrete Fourier Transform over
    any number of axes in an M-dimensional real array by means of the Fast
    Fourier Transform (FFT).  By default, all axes are transformed, with the
    real transform performed over the last axis, while the remaining
    transforms are complex.

    Parameters
    ----------
    a : array_like
        Input array, taken to be real.
    s : sequence of ints, optional
        Shape (length along each transformed axis) to use from the input.
        (``s[0]`` refers to axis 0, ``s[1]`` to axis 1, etc.).
        The final element of `s` corresponds to `n` for ``rfft(x, n)``, while
        for the remaining axes, it corresponds to `n` for ``fft(x, n)``.
        Along any axis, if the given shape is smaller than that of the input,
        the input is cropped. If it is larger, the input is padded with zeros.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        If `s` is not given, the shape of the input along the axes specified
        by `axes` is used.

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        Axes over which to compute the FFT. If not given, the last ``len(s)``
        axes are used, or all axes if `s` is also not specified.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must be explicitly specified too.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : complex ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for all axes (and hence is
        incompatible with passing in all but the trivial ``s``).

        .. versionadded:: 2.0.0

    Returns
    -------
    out : complex ndarray
        The input transformed with `rfft` along the last entry of `axes`
        and with `fft` along each of the remaining entries.

    """
    a = asarray(a)
    s, axes = _cook_nd_args(a, s, axes)
    # The real-input transform handles the last requested axis ...
    a = rfft(a, s[-1], axes[-1], norm, out=out)
    # ... and every other requested axis gets a complex transform.
    for ii in range(len(axes)-1):
        a = fft(a, s[ii], axes[ii], norm, out=out)
    return a
@array_function_dispatch(_fftn_dispatcher)
def rfft2(a, s=None, axes=(-2, -1), norm=None, out=None):
"""
Compute the 2-dimensional FFT of a real array.
Parameters
----------
a : array
Input array, taken to be real.
s : sequence of ints, optional
Shape of the FFT.
.. versionchanged:: 2.0
If it is ``-1``, the whole input is used (no padding/trimming).
.. deprecated:: 2.0
If `s` is not ``None``, `axes` must not be ``None`` either.
.. deprecated:: 2.0
`s` must contain only ``int`` s, not ``None`` values. ``None``
values currently mean that the default value for ``n`` is used
in the corresponding 1-D transform, but this behaviour is
deprecated.
axes : sequence of ints, optional
Axes over which to compute the FFT. Default: ``(-2, -1)``.
.. deprecated:: 2.0
If `s` is specified, the corresponding `axes` to be transformed
must not be ``None``.
norm : {"backward", "ortho", "forward"}, optional
.. versionadded:: 1.10.0
Normalization mode (see `numpy.fft`). Default is "backward".
Indicates which direction of the forward/backward pair of transforms
is scaled and with what normalization factor.
.. versionadded:: 1.20.0
The "backward", "forward" values were added.
out : complex ndarray, optional
If provided, the result will be placed in this array. It should be
of the appropriate shape and dtype for all axes (and hence is
incompatible with passing in all but the trivial ``s``).
.. versionadded:: 2.0.0
Returns
-------
out : ndarray
The result of the real 2-D FFT.
See Also
--------
rfftn : Compute the N-dimensional discrete Fourier Transform for real
input.
Notes
-----
This is really just `rfftn` with different default behavior.
For more details see `rfftn`.
Examples
--------
>>> a = np.mgrid[:5, :5][0]
>>> np.fft.rfft2(a)
array([[ 50. +0.j , 0. +0.j , 0. +0.j ],
[-12.5+17.20477401j, 0. +0.j , 0. +0.j ],
[-12.5 +4.0614962j , 0. +0.j , 0. +0.j ],
[-12.5 -4.0614962j , 0. +0.j , 0. +0.j ],
[-12.5-17.20477401j, 0. +0.j , 0. +0.j ]])
"""
# All arguments, including ``out``, are forwarded unchanged to ``rfftn``.
return rfftn(a, s, axes, norm, out=out)
@array_function_dispatch(_fftn_dispatcher)
def irfftn(a, s=None, axes=None, norm=None, out=None):
    """
    Compute the inverse of `rfftn`.

    This function computes the inverse of the N-dimensional discrete
    Fourier Transform for real input over any number of axes in an
    M-dimensional array by means of the Fast Fourier Transform (FFT).  In
    other words, ``irfftn(rfftn(a), a.shape) == a`` to within numerical
    accuracy.

    The input is converted with ``asarray`` and `s`/`axes` are normalised by
    ``_cook_nd_args(..., invreal=1)``.  Every axis except the last one in
    `axes` is then transformed with `ifft`; the last one is transformed with
    `irfft`, which is the only step that receives `out`.
    """
    arr = asarray(a)
    shape, axes = _cook_nd_args(arr, s, axes, invreal=1)
    # Complex inverse transforms over all requested axes but the last.
    n_complex = len(axes) - 1
    for idx in range(n_complex):
        arr = ifft(arr, shape[idx], axes[idx], norm)
    # The final, complex-to-real step writes into ``out`` when it is given.
    return irfft(arr, shape[-1], axes[-1], norm, out=out)
@array_function_dispatch(_fftn_dispatcher)
def irfft2(a, s=None, axes=(-2, -1), norm=None, out=None):
    """
    Computes the inverse of `rfft2`.

    Parameters
    ----------
    a : array_like
        The input array
    s : sequence of ints, optional
        Shape of the real output to the inverse FFT.

        .. versionchanged:: 2.0

            If it is ``-1``, the whole input is used (no padding/trimming).

        .. deprecated:: 2.0

            If `s` is not ``None``, `axes` must not be ``None`` either.

        .. deprecated:: 2.0

            `s` must contain only ``int`` s, not ``None`` values. ``None``
            values currently mean that the default value for ``n`` is used
            in the corresponding 1-D transform, but this behaviour is
            deprecated.

    axes : sequence of ints, optional
        The axes over which to compute the inverse fft.
        Default: ``(-2, -1)``, the last two axes.

        .. deprecated:: 2.0

            If `s` is specified, the corresponding `axes` to be transformed
            must not be ``None``.

    norm : {"backward", "ortho", "forward"}, optional
        .. versionadded:: 1.10.0

        Normalization mode (see `numpy.fft`). Default is "backward".
        Indicates which direction of the forward/backward pair of transforms
        is scaled and with what normalization factor.

        .. versionadded:: 1.20.0

            The "backward", "forward" values were added.

    out : ndarray, optional
        If provided, the result will be placed in this array. It should be
        of the appropriate shape and dtype for the last transformation.

        .. versionadded:: 2.0.0

    Returns
    -------
    out : ndarray
        The result of the inverse real 2-D FFT.

    See Also
    --------
    rfft2 : The forward two-dimensional FFT of real input,
            of which `irfft2` is the inverse.
    rfft : The one-dimensional FFT for real input.
    irfft : The inverse of the one-dimensional FFT of real input.
    irfftn : Compute the inverse of the N-dimensional FFT of real input.

    Notes
    -----
    This is really `irfftn` with different defaults.
    For more details see `irfftn`.

    Examples
    --------
    >>> a = np.mgrid[:5, :5][0]
    >>> A = np.fft.rfft2(a)
    >>> np.fft.irfft2(A, s=a.shape)
    array([[0., 0., 0., 0., 0.],
           [1., 1., 1., 1., 1.],
           [2., 2., 2., 2., 2.],
           [3., 3., 3., 3., 3.],
           [4., 4., 4., 4., 4.]])
    """
    # Forward the caller's ``out``: passing ``out=None`` here silently
    # discarded the documented output array.
    return irfftn(a, s, axes, norm, out=out)
.\numpy\numpy\fft\_pocketfft.pyi
from collections.abc import Sequence
from typing import Literal as L
from numpy import complex128, float64
from numpy._typing import ArrayLike, NDArray, _ArrayLikeNumber_co
# Accepted values for the ``norm`` argument of every transform below.
_NormKind = L[None, "backward", "ortho", "forward"]
__all__: list[str]
# --- One-dimensional transforms: (a, n, axis, norm, out) -------------------
# Annotated as returning complex128 arrays, except ``irfft`` and ``hfft``
# which are annotated as returning float64 arrays.
def fft(
a: ArrayLike,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def ifft(
a: ArrayLike,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def rfft(
a: ArrayLike,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def irfft(
a: ArrayLike,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[float64] = ...,
) -> NDArray[float64]: ...
# ``hfft`` is the only stub here whose input is typed ``_ArrayLikeNumber_co``.
def hfft(
a: _ArrayLikeNumber_co,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[float64] = ...,
) -> NDArray[float64]: ...
def ihfft(
a: ArrayLike,
n: None | int = ...,
axis: int = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
# --- N-dimensional transforms: (a, s, axes, norm, out) ---------------------
def fftn(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def ifftn(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def rfftn(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def irfftn(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[float64] = ...,
) -> NDArray[float64]: ...
# --- Two-dimensional transforms: same parameter list as the N-D stubs ------
def fft2(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def ifft2(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def rfft2(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[complex128] = ...,
) -> NDArray[complex128]: ...
def irfft2(
a: ArrayLike,
s: None | Sequence[int] = ...,
axes: None | Sequence[int] = ...,
norm: _NormKind = ...,
out: None | NDArray[float64] = ...,
) -> NDArray[float64]: ...
.\numpy\numpy\fft\_pocketfft_umath.cpp
/*
* This file is part of pocketfft.
* Licensed under a 3-clause BSD style license - see LICENSE.md
*/
/*
* Main implementation file.
*
* Copyright (C) 2004-2018 Max-Planck-Society
* \author Martin Reinecke
*/
// 定义宏,防止使用废弃的 NumPy API
// 清理 PY_SSIZE_T 类型
// 引入 NumPy 头文件
// 引入 NumPy 的配置文件
// 定义不使用多线程选项
/*
 * In order to ensure that C++ exceptions are converted to Python
 * ones before crossing over to the C machinery, we must catch them.
 * This template can be used to wrap a C++ written ufunc to do this via:
 * wrap_legacy_cpp_ufunc<cpp_ufunc>
 *
 * std::bad_alloc is reported through PyErr_NoMemory(); any other
 * std::exception becomes a Python RuntimeError carrying e.what().
 * Exceptions that do not derive from std::exception are not caught here.
 */
template<PyUFuncGenericFunction cpp_ufunc>
static void
wrap_legacy_cpp_ufunc(char **args, npy_intp const *dimensions,
ptrdiff_t const *steps, void *func)
{
// Declares whatever state NPY_ALLOW_C_API / NPY_DISABLE_C_API need below.
NPY_ALLOW_C_API_DEF
try {
// Forward all arguments untouched to the wrapped loop.
cpp_ufunc(args, dimensions, steps, func);
}
catch (std::bad_alloc& e) {
// Out of memory: set the Python MemoryError indicator.
NPY_ALLOW_C_API;
PyErr_NoMemory();
NPY_DISABLE_C_API;
}
catch (const std::exception& e) {
// Any other standard exception: surface its message as RuntimeError.
NPY_ALLOW_C_API;
PyErr_SetString(PyExc_RuntimeError, e.what());
NPY_DISABLE_C_API;
}
}
/*
* Transfer to and from a contiguous buffer.
* copy_input: copy min(nin, n) elements from input to buffer and zero rest.
* copy_output: copy n elements from buffer to output.
*/
template <typename T>
static inline void
copy_input(char *in, npy_intp step_in, size_t nin,
           T buff[], size_t n)
{
    // Gather up to n strided elements from `in` into the contiguous `buff`;
    // when fewer than n input elements exist, the tail of `buff` is zeroed.
    const size_t ncopy = (n < nin) ? n : nin;
    size_t filled = 0;
    for (char *src = in; filled < ncopy; ++filled, src += step_in) {
        buff[filled] = *(T *)src;
    }
    // Zero-pad whatever the input could not supply.
    while (filled < n) {
        buff[filled] = 0;
        ++filled;
    }
}
template <typename T>
static inline void
copy_output(T buff[], char *out, npy_intp step_out, size_t n)
{
    // Scatter n contiguous elements of `buff` to `out`, advancing the
    // destination by `step_out` bytes per element.
    char *dst = out;
    size_t k = 0;
    while (k < n) {
        *(T *)dst = buff[k];
        ++k;
        dst += step_out;
    }
}
/*
* Gufunc loops calling the pocketfft code.
*/
template <typename T>
static void
fft_loop(char **args, npy_intp const *dimensions, ptrdiff_t const *steps,
void *func)
{
// Operand pointers: input, the scalar handed to pocketfft as its final
// argument (presumably the normalization factor), and output.
char *ip = args[0], *fp = args[1], *op = args[2];
// Length of the outer loop and the per-operand strides across it.
size_t n_outer = (size_t)dimensions[0];
ptrdiff_t si = steps[0], sf = steps[1], so = steps[2];
// Core-dimension lengths of the input and of the output.
size_t nin = (size_t)dimensions[1], nout = (size_t)dimensions[2];
// Strides within the core dimension of the input and of the output.
ptrdiff_t step_in = steps[3], step_out = steps[4];
// Transform direction, read from the per-loop data pointer.
bool direction = *((bool *)func); /* pocketfft::FORWARD or BACKWARD */
// The output core dimension must be non-empty.
assert (nout > 0);
/*
 * For the common case of nin >= nout, fixed factor, and suitably sized
 * outer loop, we call pocketfft directly to benefit from its vectorization.
 * (For nin>nout, this just removes the extra input points, as required;
 * the vlen constraint avoids compiling extra code for longdouble, which
 * cannot be vectorized so does not benefit.)
 */
constexpr auto vlen = pocketfft::detail::VLEN<T>::val;
if (vlen > 1 && n_outer >= vlen && nin >= nout && sf == 0) {
std::vector<size_t> shape = { n_outer, nout };
std::vector<ptrdiff_t> strides_in = { si, step_in };
std::vector<ptrdiff_t> strides_out = { so, step_out};
std::vector<size_t> axes = { 1 };
pocketfft::c2c(shape, strides_in, strides_out, axes, direction,
(std::complex<T> *)ip, (std::complex<T> *)op, *(T *)fp);
return;
}
/*
 * Otherwise, use a non-vectorized loop in which we try to minimize copies.
 * We do still need a buffer if the output is not contiguous.
 */
// Complex-transform plan for nout points, obtained from pocketfft.
auto plan = pocketfft::detail::get_plan<pocketfft::detail::pocketfft_c<T>>(nout);
// A scratch buffer is needed when output elements are not densely packed.
auto buffered = (step_out != sizeof(std::complex<T>));
// Allocate the scratch buffer only when it will actually be used.
pocketfft::detail::arr<std::complex<T>> buff(buffered ? nout : 0);
// One transform per outer-loop element.
for (size_t i = 0; i < n_outer; i++, ip += si, fp += sf, op += so) {
// Work directly in the output unless a scratch buffer is required.
std::complex<T> *op_or_buff = buffered ? buff.data() : (std::complex<T> *)op;
// Skip the copy when the input already is the working array.
if (ip != (char*)op_or_buff) {
copy_input(ip, step_in, nin, op_or_buff, nout);
}
// Transform in place, passing the scalar read from fp.
plan->exec((pocketfft::detail::cmplx<T> *)op_or_buff, *(T *)fp, direction);
// Scatter the result to the strided output when a buffer was used.
if (buffered) {
copy_output(op_or_buff, op, step_out, nout);
}
}
return;
}
/*
 * Shared implementation of the forward real FFT loops.  `npts` is the
 * number of real points to transform; the caller derives it from the
 * output length because that alone does not determine whether npts is
 * even or odd.
 */
template <typename T>
static void
rfft_impl(char **args, npy_intp const *dimensions, npy_intp const *steps,
          void *func, size_t npts)
{
    // Operand pointers: input, the scalar handed to pocketfft as its final
    // argument (presumably the normalization factor), and output.
    char *ip = args[0], *fp = args[1], *op = args[2];
    size_t n_outer = (size_t)dimensions[0];
    ptrdiff_t si = steps[0], sf = steps[1], so = steps[2];
    size_t nin = (size_t)dimensions[1], nout = (size_t)dimensions[2];
    ptrdiff_t step_in = steps[3], step_out = steps[4];

    // The output must be non-empty and consistent with npts.
    assert (nout > 0 && nout == npts / 2 + 1);

    /*
     * Call pocketfft directly if vectorization is possible.
     */
    constexpr auto vlen = pocketfft::detail::VLEN<T>::val;
    if (vlen > 1 && n_outer >= vlen && nin >= npts && sf == 0) {
        std::vector<size_t> shape_in = { n_outer, npts };
        std::vector<ptrdiff_t> strides_in = { si, step_in };
        std::vector<ptrdiff_t> strides_out = { so, step_out};
        std::vector<size_t> axes = { 1 };
        pocketfft::r2c(shape_in, strides_in, strides_out, axes, pocketfft::FORWARD,
                       (T *)ip, (std::complex<T> *)op, *(T *)fp);
        return;
    }
    /*
     * Otherwise, use a non-vectorized loop in which we try to minimize copies.
     * We do still need a buffer if the output is not contiguous.
     */
    auto plan = pocketfft::detail::get_plan<pocketfft::detail::pocketfft_r<T>>(npts);
    // A scratch buffer is needed when output elements are not densely packed.
    auto buffered = (step_out != sizeof(std::complex<T>));
    pocketfft::detail::arr<std::complex<T>> buff(buffered ? nout : 0);
    // Never read more input points than the transform uses.
    auto nin_used = nin <= npts ? nin : npts;
    for (size_t i = 0; i < n_outer; i++, ip += si, fp += sf, op += so) {
        std::complex<T> *op_or_buff = buffered ? buff.data() : (std::complex<T> *)op;
        /*
         * The internal pocketfft routines work in-place and for real
         * transforms the frequency data thus needs to be compressed, using
         * that there will be no imaginary component for the zero-frequency
         * item (which is the sum of all inputs and thus has to be real), nor
         * one for the Nyquist frequency for even number of points.
         * Pocketfft uses FFTpack order, R0,R1,I1,...Rn-1,In-1,Rn[,In] (last
         * for npts odd only). To make unpacking easy, we place the real data
         * offset by one in the buffer, so that we just have to move R0 and
         * create I0=0. Note that copy_input will zero the In component for
         * even number of points.
         */
        copy_input(ip, step_in, nin_used, &((T *)op_or_buff)[1], nout*2 - 1);
        plan->exec(&((T *)op_or_buff)[1], *(T *)fp, pocketfft::FORWARD);
        // R0 was written into the imaginary slot of element 0; assigning the
        // real scalar moves it to the real part and makes I0 = 0.
        op_or_buff[0] = op_or_buff[0].imag();
        if (buffered) {
            copy_output(op_or_buff, op, step_out, nout);
        }
    }
    return;
}
/*
 * For the forward real, we cannot know what the requested number of points is
 * just based on the number of points in the complex output array (e.g., 10
 * and 11 real input points both lead to 6 complex output points), so we
 * define versions for both even and odd number of points.
 */
template <typename T>
static void
rfft_n_even_loop(char **args, npy_intp const *dimensions, npy_intp const *steps, void *func)
{
    // Number of complex points in the output core dimension.
    const size_t nout = static_cast<size_t>(dimensions[2]);
    assert (nout > 0);
    // Even case: the real transform length handed on is 2*nout - 2.
    rfft_impl<T>(args, dimensions, steps, func, 2 * nout - 2);
}
/*
 * Odd-length counterpart of rfft_n_even_loop: the complex output length
 * alone cannot distinguish an even from an odd number of real points
 * (e.g., 10 and 11 real input points both lead to 6 complex output points),
 * so a separate loop exists for each parity.
 */
template <typename T>
static void
rfft_n_odd_loop(char **args, npy_intp const *dimensions, npy_intp const *steps, void *func)
{
    // Number of complex points in the output core dimension.
    const size_t nout = static_cast<size_t>(dimensions[2]);
    assert (nout > 0);
    // Odd case: the real transform length handed on is 2*nout - 1.
    rfft_impl<T>(args, dimensions, steps, func, 2 * nout - 1);
}
/*
 * This function handles the inverse real FFT operation: complex
 * frequency-domain input is transformed into nout real output points.
 */
template <typename T>
static void
irfft_loop(char **args, npy_intp const *dimensions, npy_intp const *steps, void *func)
{
// Operand pointers: input, the scalar handed to pocketfft as its final
// argument (presumably the normalization factor), and output.
char *ip = args[0], *fp = args[1], *op = args[2];
// Length of the outer loop.
size_t n_outer = (size_t)dimensions[0];
// Per-operand strides across the outer loop.
ptrdiff_t si = steps[0], sf = steps[1], so = steps[2];
// Core-dimension lengths: complex input points and real output points.
size_t nin = (size_t)dimensions[1], nout = (size_t)dimensions[2];
// Strides within the core dimension of the input and of the output.
ptrdiff_t step_in = steps[3], step_out = steps[4];
// Number of complex input points needed to produce nout real points.
size_t npts_in = nout / 2 + 1;
assert(nout > 0); // the output core dimension must be non-empty
/*
 * Call pocketfft directly if vectorization is possible.
 */
constexpr auto vlen = pocketfft::detail::VLEN<T>::val;
if (vlen > 1 && n_outer >= vlen && nin >= npts_in && sf == 0) {
// Transform along axis 1 of the (n_outer, nout) output shape.
std::vector<size_t> axes = { 1 };
std::vector<size_t> shape_out = { n_outer, nout };
std::vector<ptrdiff_t> strides_in = { si, step_in };
std::vector<ptrdiff_t> strides_out = { so, step_out };
// Complex-to-real backward transform on the whole batch.
pocketfft::c2r(shape_out, strides_in, strides_out, axes, pocketfft::BACKWARD,
(std::complex<T> *)ip, (T *)op, *(T *)fp);
return;
}
/*
 * Otherwise, use a non-vectorized loop in which we try to minimize copies.
 * We do still need a buffer if the output is not contiguous.
 */
auto plan = pocketfft::detail::get_plan<pocketfft::detail::pocketfft_r<T>>(nout);
auto buffered = (step_out != sizeof(T));
// Allocate the scratch buffer only when the output is not densely packed.
pocketfft::detail::arr<T> buff(buffered ? nout : 0);
for (size_t i = 0; i < n_outer; i++, ip += si, fp += sf, op += so) {
// Work directly in the output unless a scratch buffer is required.
T *op_or_buff = buffered ? buff.data() : (T *)op;
/*
 * Pocket_fft is in-place and for inverse real transforms the
 * frequency data thus needs to be compressed, removing the imaginary
 * component of the zero-frequency term (which is the sum of all
 * inputs and thus has to be real), as well as the imaginary component
 * of the Nyquist frequency for even number of points.
 * We thus copy the data to the buffer in the following order (also
 * used by FFTpack): R0,R1,I1,...Rn-1,In-1,Rn[,In] (last for npts odd only).
 */
op_or_buff[0] = ((T *)ip)[0]; /* copy R0 */
// Anything beyond R0 is only needed for more than one output point.
if (nout > 1) {
/*
 * Copy R1,I1... up to Rn-1,In-1 if possible, stopping earlier
 * if not all the input points are needed or if the input is short
 * (in the latter case, zeroing after).
 */
copy_input(ip + step_in, step_in, nin - 1,
(std::complex<T> *)&op_or_buff[1], (nout - 1) / 2);
// For even nout, we still need to set Rn (zero if the input is too short).
if (nout % 2 == 0) {
op_or_buff[nout - 1] = (nout / 2 >= nin) ? (T)0 :
((T *)(ip + (nout / 2) * step_in))[0];
}
}
// Backward transform in place, passing the scalar read from fp.
plan->exec(op_or_buff, *(T *)fp, pocketfft::BACKWARD);
// Scatter the result to the strided output when a buffer was used.
if (buffered) {
copy_output(op_or_buff, op, step_out, nout);
}
}
return;
}
// Complex FFT loops, one per precision: double, float, long double.
// The same loops serve both fft and ifft; the direction comes from the
// per-loop data pointer (fft_data / ifft_data below).
static PyUFuncGenericFunction fft_functions[] = {
wrap_legacy_cpp_ufunc<fft_loop<npy_double>>,
wrap_legacy_cpp_ufunc<fft_loop<npy_float>>,
wrap_legacy_cpp_ufunc<fft_loop<npy_longdouble>>
};
// Type signature per loop, one row each: complex input, real second
// operand (the scalar `()` of the gufunc signature), complex output.
static const char fft_types[] = {
NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE,
NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT,
NPY_CLONGDOUBLE, NPY_LONGDOUBLE, NPY_CLONGDOUBLE
};
// Per-loop data for the forward complex FFT: every loop gets FORWARD.
static void *const fft_data[] = {
(void*)&pocketfft::FORWARD,
(void*)&pocketfft::FORWARD,
(void*)&pocketfft::FORWARD
};
// Per-loop data for the inverse complex FFT: every loop gets BACKWARD.
static void *const ifft_data[] = {
(void*)&pocketfft::BACKWARD,
(void*)&pocketfft::BACKWARD,
(void*)&pocketfft::BACKWARD
};
// Forward real FFT loops for an even number of points, one per precision.
static PyUFuncGenericFunction rfft_n_even_functions[] = {
wrap_legacy_cpp_ufunc<rfft_n_even_loop<npy_double>>,
wrap_legacy_cpp_ufunc<rfft_n_even_loop<npy_float>>,
wrap_legacy_cpp_ufunc<rfft_n_even_loop<npy_longdouble>>
};
// Forward real FFT loops for an odd number of points, one per precision.
static PyUFuncGenericFunction rfft_n_odd_functions[] = {
wrap_legacy_cpp_ufunc<rfft_n_odd_loop<npy_double>>,
wrap_legacy_cpp_ufunc<rfft_n_odd_loop<npy_float>>,
wrap_legacy_cpp_ufunc<rfft_n_odd_loop<npy_longdouble>>
};
// Type signature per real forward loop: real input, real second operand,
// complex output.
static const char rfft_types[] = {
NPY_DOUBLE, NPY_DOUBLE, NPY_CDOUBLE,
NPY_FLOAT, NPY_FLOAT, NPY_CFLOAT,
NPY_LONGDOUBLE, NPY_LONGDOUBLE, NPY_CLONGDOUBLE
};
// Inverse real FFT loops, one per precision.
static PyUFuncGenericFunction irfft_functions[] = {
wrap_legacy_cpp_ufunc<irfft_loop<npy_double>>,
wrap_legacy_cpp_ufunc<irfft_loop<npy_float>>,
wrap_legacy_cpp_ufunc<irfft_loop<npy_longdouble>>
};
// Type signature per inverse real loop: complex input, real second
// operand, real output.
static const char irfft_types[] = {
NPY_CDOUBLE, NPY_DOUBLE, NPY_DOUBLE,
NPY_CFLOAT, NPY_FLOAT, NPY_FLOAT,
NPY_CLONGDOUBLE, NPY_LONGDOUBLE, NPY_LONGDOUBLE
};
/*
 * Create the five gufuncs (fft, ifft, rfft_n_even, rfft_n_odd, irfft) and
 * store each one in `dictionary` under its own name.
 *
 * Every gufunc has 3 type-specific loops, 2 inputs and 1 output.
 * Returns 0 on success and -1 (with a Python exception set) as soon as a
 * gufunc cannot be created or cannot be stored in the dictionary.
 */
static int
add_gufuncs(PyObject *dictionary) {
    PyObject *f;
    int status;

    // Complex forward FFT.
    f = PyUFunc_FromFuncAndDataAndSignature(
        fft_functions, fft_data, fft_types, 3, 2, 1, PyUFunc_None,
        "fft", "complex forward FFT\n", 0, "(n),()->(m)");
    if (f == NULL) {
        return -1;
    }
    // The dictionary takes its own reference, so ours is dropped either way.
    status = PyDict_SetItemString(dictionary, "fft", f);
    Py_DECREF(f);
    if (status < 0) {
        return -1;
    }

    // Complex backward FFT: same loops, BACKWARD direction data.
    f = PyUFunc_FromFuncAndDataAndSignature(
        fft_functions, ifft_data, fft_types, 3, 2, 1, PyUFunc_None,
        "ifft", "complex backward FFT\n", 0, "(m),()->(n)");
    if (f == NULL) {
        return -1;
    }
    status = PyDict_SetItemString(dictionary, "ifft", f);
    Py_DECREF(f);
    if (status < 0) {
        return -1;
    }

    // Real forward FFT for an even number of points.
    f = PyUFunc_FromFuncAndDataAndSignature(
        rfft_n_even_functions, NULL, rfft_types, 3, 2, 1, PyUFunc_None,
        "rfft_n_even", "real forward FFT for even n\n", 0, "(n),()->(m)");
    if (f == NULL) {
        return -1;
    }
    status = PyDict_SetItemString(dictionary, "rfft_n_even", f);
    Py_DECREF(f);
    if (status < 0) {
        return -1;
    }

    // Real forward FFT for an odd number of points.
    f = PyUFunc_FromFuncAndDataAndSignature(
        rfft_n_odd_functions, NULL, rfft_types, 3, 2, 1, PyUFunc_None,
        "rfft_n_odd", "real forward FFT for odd n\n", 0, "(n),()->(m)");
    if (f == NULL) {
        return -1;
    }
    status = PyDict_SetItemString(dictionary, "rfft_n_odd", f);
    Py_DECREF(f);
    if (status < 0) {
        return -1;
    }

    // Real backward FFT.
    f = PyUFunc_FromFuncAndDataAndSignature(
        irfft_functions, NULL, irfft_types, 3, 2, 1, PyUFunc_None,
        "irfft", "real backward FFT\n", 0, "(m),()->(n)");
    if (f == NULL) {
        return -1;
    }
    status = PyDict_SetItemString(dictionary, "irfft", f);
    Py_DECREF(f);
    if (status < 0) {
        return -1;
    }
    return 0;
}
/*
 * Module definition.  The name matches the PyInit__pocketfft_umath entry
 * point; it previously read "_multiarray_umath", which does not correspond
 * to this extension module.
 */
static struct PyModuleDef moduledef = {
    PyModuleDef_HEAD_INIT,
    "_pocketfft_umath",  // m_name
    NULL,                // m_doc: no module docstring
    -1,                  // m_size: no per-module state
    NULL,                // m_methods: no module-level functions
    NULL,                // m_slots: unused
    NULL,                // m_traverse
    NULL,                // m_clear
    NULL                 // m_free
};
/*
 * Initialization function for the module.
 *
 * Returns the new module object, or NULL with a Python exception set.
 */
PyMODINIT_FUNC PyInit__pocketfft_umath(void)
{
    /*
     * Import the array and ufunc objects first: both macros return NULL
     * from this function on failure, so running them before the module
     * exists means nothing is leaked on that path.
     */
    import_array();
    import_ufunc();

    PyObject *m = PyModule_Create(&moduledef);
    if (m == NULL) {
        return NULL;
    }

    /*
     * PyModule_GetDict returns a borrowed reference, so `d` must not be
     * DECREF'ed; only the module itself is released on failure.
     */
    PyObject *d = PyModule_GetDict(m);
    if (add_gufuncs(d) < 0) {
        Py_DECREF(m);
        return NULL;
    }

    return m;
}
.\numpy\numpy\fft\__init__.py
"""
Discrete Fourier Transform (:mod:`numpy.fft`)
=============================================
.. currentmodule:: numpy.fft
The SciPy module `scipy.fft` is a more comprehensive superset
of ``numpy.fft``, which includes only a basic set of routines.
Standard FFTs
-------------
.. autosummary::
:toctree: generated/
fft Discrete Fourier transform.
ifft Inverse discrete Fourier transform.
fft2 Discrete Fourier transform in two dimensions.
ifft2 Inverse discrete Fourier transform in two dimensions.
fftn Discrete Fourier transform in N-dimensions.
ifftn Inverse discrete Fourier transform in N dimensions.
Real FFTs
---------
.. autosummary::
:toctree: generated/
rfft Real discrete Fourier transform.
irfft Inverse real discrete Fourier transform.
rfft2 Real discrete Fourier transform in two dimensions.
irfft2 Inverse real discrete Fourier transform in two dimensions.
rfftn Real discrete Fourier transform in N dimensions.
irfftn Inverse real discrete Fourier transform in N dimensions.
Hermitian FFTs
--------------
.. autosummary::
:toctree: generated/
hfft Hermitian discrete Fourier transform.
ihfft Inverse Hermitian discrete Fourier transform.
Helper routines
---------------
.. autosummary::
:toctree: generated/
fftfreq Discrete Fourier Transform sample frequencies.
rfftfreq DFT sample frequencies (for usage with rfft, irfft).
fftshift Shift zero-frequency component to center of spectrum.
ifftshift Inverse of fftshift.
Background information
----------------------
Fourier analysis is fundamentally a method for expressing a function as a
sum of periodic components, and for recovering the function from those
components. When both the function and its Fourier transform are
replaced with discretized counterparts, it is called the discrete Fourier
transform (DFT). The DFT has become a mainstay of numerical computing in
part because of a very fast algorithm for computing it, called the Fast
Fourier Transform (FFT), which was known to Gauss (1805) and was brought
to light in its current form by Cooley and Tukey [CT]_. Press et al. [NR]_
provide an accessible introduction to Fourier analysis and its
applications.
Because the discrete Fourier transform separates its input into
components that contribute at discrete frequencies, it has a great number
of applications in digital signal processing, e.g., for filtering, and in
this context the discretized input to the transform is customarily
referred to as a *signal*, which exists in the *time domain*. The output
is called a *spectrum* or *transform* and exists in the *frequency
domain*.
Implementation details
----------------------
There are many ways to define the DFT, varying in the sign of the
exponent, normalization, etc. In this implementation, the DFT is defined
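as

.. math::
   A_k = \\sum_{m=0}^{n-1} a_m \\exp\\left\\{-2\\pi i{mk \\over n}\\right\\}
   \\qquad k = 0,\\ldots,n-1.

A single-frequency component at linear frequency :math:`f` is represented
by the complex exponential
:math:`a_m = \\exp\\{2\\pi i\\,f m\\Delta t\\}`, where :math:`\\Delta t`
is the sampling interval.

The inverse DFT is defined as

.. math::
   a_m = \\frac{1}{n}\\sum_{k=0}^{n-1}A_k\\exp\\left\\{2\\pi i{mk\\over n}\\right\\}
   \\qquad m = 0,\\ldots,n-1.

Type Promotion
--------------
Normalization
-------------
Real and Hermitian transforms
-----------------------------
"""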
from . import _pocketfft, _helper
from . import helper
from ._pocketfft import *
from ._helper import *
# Public API: the names exported by ``_pocketfft`` followed by those exported
# by ``_helper``. The list is copied so ``_pocketfft.__all__`` is not mutated.
__all__ = _pocketfft.__all__.copy()
__all__ += _helper.__all__
# Expose ``test`` as a PytestTester created for this package's name, then
# remove the class itself from the package namespace.
from numpy._pytesttester import PytestTester
test = PytestTester(__name__)
del PytestTester
.\numpy\numpy\fft\__init__.pyi
from numpy._pytesttester import PytestTester
from numpy.fft._pocketfft import (
fft as fft,
ifft as ifft,
rfft as rfft,
irfft as irfft,
hfft as hfft,
ihfft as ihfft,
rfftn as rfftn,
irfftn as irfftn,
rfft2 as rfft2,
irfft2 as irfft2,
fft2 as fft2,
ifft2 as ifft2,
fftn as fftn,
ifftn as ifftn,
)
from numpy.fft._helper import (
fftshift as fftshift,
ifftshift as ifftshift,
fftfreq as fftfreq,
rfftfreq as rfftfreq,
)
# Declared without values: the stub only records the types of these
# module-level attributes.
__all__: list[str]
test: PytestTester
.\numpy\numpy\lib\array_utils.py
from ._array_utils_impl import (
__all__,
__doc__,
byte_bounds,
normalize_axis_index,
normalize_axis_tuple,
)
.\numpy\numpy\lib\array_utils.pyi
from ._array_utils_impl import (
__all__ as __all__,
byte_bounds as byte_bounds,
normalize_axis_index as normalize_axis_index,
normalize_axis_tuple as normalize_axis_tuple,
)
.\numpy\numpy\lib\format.py
"""
Binary serialization
NPY format
==========
A simple format for saving numpy arrays to disk with the full
information about them.
The ``.npy`` format is the standard binary file format in NumPy for
persisting a *single* arbitrary NumPy array on disk. The format stores all
of the shape and dtype information necessary to reconstruct the array
correctly even on another machine with a different architecture.
The format is designed to be as simple as possible while achieving
its limited goals.
The ``.npz`` format is the standard format for persisting *multiple* NumPy
arrays on disk. A ``.npz`` file is a zip file containing multiple ``.npy``
files, one for each array.
Capabilities
------------
- Can represent all NumPy arrays including nested record arrays and
object arrays.
- Represents the data in its native binary form.
- Supports Fortran-contiguous arrays directly.
- Stores all of the necessary information to reconstruct the array
including shape and dtype on a machine of a different
architecture. Both little-endian and big-endian arrays are
supported, and a file with little-endian numbers will yield
a little-endian array on any machine reading the file. The
types are described in terms of their actual sizes. For example,
if a machine with a 64-bit C "long int" writes out an array with
"long ints", a reading machine with 32-bit C "long ints" will yield
an array with 64-bit integers.
- Is straightforward to reverse engineer. Datasets often live longer than
the programs that created them. A competent developer should be
able to create a solution in their preferred programming language to
read most ``.npy`` files that they have been given without much
documentation.
- Allows memory-mapping of the data. See `open_memmap`.
- Can be read from a filelike stream object instead of an actual file.
- Stores object arrays, i.e. arrays containing elements that are arbitrary
Python objects. Files with object arrays are not to be mmapable, but
can be read and written to disk.
Limitations
-----------
- Arbitrary subclasses of numpy.ndarray are not completely preserved.
Subclasses will be accepted for writing, but only the array data will
be written out. A regular numpy.ndarray object will be created
upon reading the file.
.. warning::
Due to limitations in the interpretation of structured dtypes, dtypes
with fields with empty names will have the names replaced by 'f0', 'f1',
etc. Such arrays will not round-trip through the format entirely
accurately. The data is intact; only the field names will differ. We are
working on a fix for this. This fix will not require a change in the
file format. The arrays with such structures can still be saved and
restored, and the correct dtype may be restored by using the
``loadedarray.view(correct_dtype)`` method.
File extensions
---------------
We recommend using the ``.npy`` and ``.npz`` extensions for files saved
in this format.
Format Version 3.0
------------------
This version replaces the ASCII string (which in practice was latin1) with
a utf8-encoded string, so supports structured types with any unicode field
names.
Notes
-----
The ``.npy`` format, including motivation for creating it and a comparison of
alternatives, is described in the
:doc:`"npy-format" NEP <neps:nep-0001-npy-format>`, however details have
evolved with time and this document is more current.
"""
import io
import os
import pickle
import warnings
import numpy
from numpy.lib._utils_impl import drop_metadata
__all__ = []
# The exact set of keys a header dictionary must contain.
EXPECTED_KEYS = {'descr', 'fortran_order', 'shape'}
# Magic prefix that marks the start of an .npy file.
MAGIC_PREFIX = b'\x93NUMPY'
# Length of the full magic string: the prefix plus two version bytes.
MAGIC_LEN = len(MAGIC_PREFIX) + 2
# Alignment of the array data; 64 by default, typically a power of two
# between 16 and 4096.
ARRAY_ALIGN = 64
# Size, in bytes, of the buffer used for reading npz files.
BUFFER_SIZE = 2**18
# Allow growth within the address space of a 64 bit machine along one axis.
GROWTH_AXIS_MAX_DIGITS = 21 # = len(str(8*2**64-1)) hypothetical int1 dtype
# Versions 1.0 and 2.0 differ in the header-length field: 2 bytes (H) was
# extended to 4 bytes (I) to support storage of large structured arrays.
# Maps each format version to (struct format of the length field, encoding).
_header_size_info = {
    (1, 0): ('<H', 'latin1'),
    (2, 0): ('<I', 'latin1'),
    (3, 0): ('<I', 'utf8'),
}
# Python's literal_eval is not actually safe for large inputs, since parsing
# may become slow or even cause interpreter crashes.
# This is an arbitrary, low limit which should make it safe in practice.
_MAX_HEADER_SIZE = 10000
def _check_version(version):
"""
检查给定的文件格式版本是否受支持。
Parameters
----------
version : tuple
文件格式的主次版本号组成的元组
Raises
------
ValueError
如果版本不是(1,0),(2,0)或(3,0)中的一个
"""
if version not in [(1, 0), (2, 0), (3, 0), None]:
msg = "we only support format version (1,0), (2,0), and (3,0), not %s"
raise ValueError(msg % (version,))
def magic(major, minor):
    """
    Build the magic string for a given file format version.

    Parameters
    ----------
    major : int in [0, 255]
        Major version number.
    minor : int in [0, 255]
        Minor version number.

    Returns
    -------
    magic : bytes
        ``MAGIC_PREFIX`` followed by the two version bytes.

    Raises
    ------
    ValueError
        If either version number is outside ``[0, 255]``.
    """
    if major > 255 or major < 0:
        raise ValueError("major version must be 0 <= major < 256")
    if minor > 255 or minor < 0:
        raise ValueError("minor version must be 0 <= minor < 256")
    version_bytes = bytes((major, minor))
    return MAGIC_PREFIX + version_bytes
def read_magic(fp):
    """
    Read the magic string from a file and return the format version.

    Parameters
    ----------
    fp : filelike object
        Object to read the ``MAGIC_LEN`` leading bytes from.

    Returns
    -------
    major : int
    minor : int

    Raises
    ------
    ValueError
        If the bytes read do not start with ``MAGIC_PREFIX``.
    """
    raw = _read_bytes(fp, MAGIC_LEN, "magic string")
    # The last two bytes carry the version; everything before is the prefix.
    prefix, version_bytes = raw[:-2], raw[-2:]
    if prefix != MAGIC_PREFIX:
        raise ValueError(
            "the magic string is not correct; expected %r, got %r"
            % (MAGIC_PREFIX, prefix)
        )
    major, minor = version_bytes
    return major, minor
def dtype_to_descr(dtype):
    """
    Get a serializable descriptor from the dtype.

    The ``.descr`` attribute of a dtype object cannot be round-tripped
    through the ``dtype()`` constructor. Simple types, like
    ``dtype('float32')``, have a descr which looks like a record array with
    one field with '' as a name. The ``dtype()`` constructor interprets this
    as a request to give a default name. Instead, we construct a descriptor
    that can be passed to ``dtype()``.

    Parameters
    ----------
    dtype : dtype
        The dtype of the array that will be written to disk.

    Returns
    -------
    descr : object
        An object that can be passed to `numpy.dtype()` in order to
        replicate the input dtype.
    """
    # NOTE: drop_metadata may not return the right dtype, e.g. for user
    # dtypes. In that case the code below would fail the same way.
    new_dtype = drop_metadata(dtype)
    if new_dtype is not dtype:
        # A different object came back, i.e. metadata was stripped.
        warnings.warn("metadata on a dtype is not saved to an npy/npz. "
                      "Use another format (such as pickle) to store it.",
                      UserWarning, stacklevel=2)
    if dtype.names is not None:
        # This is a record array. The .descr is fine. XXX: parts of the
        # record array with an empty name, like padding bytes, still get
        # fiddled with. This needs to be fixed in the C implementation of
        # dtype().
        return dtype.descr
    elif not type(dtype)._legacy:
        # This must be a user-defined dtype, since numpy does not yet expose
        # any non-legacy dtypes in the public API.
        #
        # Non-legacy dtypes don't yet have __array_interface__ support.
        # As a stopgap the array is pickled and the header claims an object
        # dtype; on load the descriptor is unpickled with the array and the
        # object dtype in the header is discarded.
        #
        # A future NEP should define a way to serialize user-defined
        # descriptors.
        warnings.warn("Custom dtypes are saved as python objects using the "
                      "pickle protocol. Loading this file requires "
                      "allow_pickle=True to be set.",
                      UserWarning, stacklevel=2)
        return "|O"
    else:
        # Plain built-in dtype: its string form round-trips through dtype().
        return dtype.str
def _wrap_header(header, version):
    """
    Take a stringified header and attach the prefix and padding to it.

    Parameters
    ----------
    header : str
        The header text (a dict literal) to wrap.
    version : tuple of int (major, minor)
        Format version; selects the length-field format and text encoding
        from ``_header_size_info``. Must not be None.

    Returns
    -------
    bytes
        Magic string, packed header length, encoded header, space padding
        and a terminating newline.

    Raises
    ------
    ValueError
        If the header length does not fit the length field of `version`.
    UnicodeEncodeError
        If the header cannot be encoded with the version's encoding.
    """
    # Imported locally, like the other stdlib helpers in this module; the
    # module-level imports do not provide `struct`.
    import struct

    assert version is not None
    fmt, encoding = _header_size_info[version]
    header = header.encode(encoding)
    # +1 accounts for the terminating newline appended below.
    hlen = len(header) + 1
    # Pad so that magic + length field + header is a multiple of ARRAY_ALIGN.
    padlen = ARRAY_ALIGN - (
        (MAGIC_LEN + struct.calcsize(fmt) + hlen) % ARRAY_ALIGN
    )
    try:
        header_prefix = magic(*version) + struct.pack(fmt, hlen + padlen)
    except struct.error:
        msg = "Header length {} too big for version={}".format(hlen, version)
        raise ValueError(msg) from None

    # Pad the header with spaces and a final newline such that the magic
    # string, the header-length field and the header are aligned on an
    # ARRAY_ALIGN byte boundary. This supports memory mapping of dtypes
    # aligned up to ARRAY_ALIGN on systems like Linux where mmap()
    # offset must be page-aligned (i.e. the beginning of the file).
    return header_prefix + header + b' '*padlen + b'\n'
# 从文件头部读取数组的版本信息,封装了版本选择的逻辑
def _wrap_header_guess_version(header):
    """
    Like `_wrap_header`, but chooses an appropriate version given the contents.

    Versions are tried from oldest to newest: (1, 0), then (2, 0), then
    (3, 0). A UserWarning is emitted whenever a version newer than (1, 0)
    has to be used.

    Parameters
    ----------
    header : str
        The header text to wrap.

    Returns
    -------
    bytes
        The wrapped header as produced by `_wrap_header`.
    """
    try:
        return _wrap_header(header, (1, 0))
    except ValueError:
        # Too long for the 1.0 length field, or not encodable
        # (UnicodeEncodeError is a ValueError subclass).
        pass

    try:
        ret = _wrap_header(header, (2, 0))
    except UnicodeEncodeError:
        pass
    else:
        # The two literals are concatenated: the first must end in a space,
        # otherwise the message reads "can only beread".
        warnings.warn("Stored array in format 2.0. It can only be "
                      "read by NumPy >= 1.9", UserWarning, stacklevel=2)
        return ret

    header = _wrap_header(header, (3, 0))
    warnings.warn("Stored array in format 3.0. It can only be "
                  "read by NumPy >= 1.17", UserWarning, stacklevel=2)
    return header
def _write_array_header(fp, d, version=None):
    """
    Serialize a header dictionary and write it to `fp`.

    Parameters
    ----------
    fp : filelike object
        Destination; only its ``write`` method is used.
    d : dict
        Header entries. Must provide ``'shape'`` and ``'fortran_order'``;
        every value is written using its ``repr``.
    version : tuple or None
        Format version to use. None lets `_wrap_header_guess_version` pick
        one based on the header contents. Default: None

    Notes
    -----
    Nothing is returned.
    """
    # repr is needed here, since the header is eval'ed when read back.
    entries = "".join(
        "'%s': %s, " % (key, repr(value)) for key, value in sorted(d.items())
    )
    text = "{" + entries + "}"

    # Add some spare space so that the array header can be modified in-place
    # when changing the array size, e.g. when growing it by appending data
    # at the end.
    shape = d['shape']
    if len(shape) > 0:
        growth_axis = -1 if d['fortran_order'] else 0
        text += " " * (GROWTH_AXIS_MAX_DIGITS - len(repr(shape[growth_axis])))

    if version is None:
        wrapped = _wrap_header_guess_version(text)
    else:
        wrapped = _wrap_header(text, version)
    fp.write(wrapped)
def write_array_header_1_0(fp, d):
    """
    Write an array header to `fp` in the 1.0 format.

    Parameters
    ----------
    fp : filelike object
        Destination for the header.
    d : dict
        Header entries, forwarded unchanged to `_write_array_header`.
    """
    format_version = (1, 0)
    _write_array_header(fp, d, format_version)
def write_array_header_2_0(fp, d):
    """
    Write an array header to `fp` in the 2.0 format.

    The 2.0 format allows storing very large structured arrays.

    .. versionadded:: 1.9.0

    Parameters
    ----------
    fp : filelike object
        Destination for the header.
    d : dict
        Header entries, forwarded unchanged to `_write_array_header`.
    """
    format_version = (2, 0)
    _write_array_header(fp, d, format_version)
def read_array_header_1_0(fp, max_header_size=_MAX_HEADER_SIZE):
    """
    Read an array header from a filelike object using the 1.0 file format
    version.

    This will leave the file object located just after the header.

    Parameters
    ----------
    fp : filelike object
        A file object or something with a `.read()` method like a file.
    max_header_size : int, optional
        Maximum allowed size of the header. Large headers may not be safe
        to load securely and thus require explicitly passing a larger value.
        See :py:func:`ast.literal_eval()` for details.

    Returns
    -------
    shape : tuple of int
        The shape of the array.
    fortran_order : bool
        The ``fortran_order`` entry of the header.
    dtype : dtype
        The dtype of the file's data.

    Raises
    ------
    ValueError
        If the data is invalid.
    """
    format_version = (1, 0)
    return _read_array_header(
        fp, version=format_version, max_header_size=max_header_size)
# 从给定的文件对象中读取数组头部信息,使用版本为 2.0 的文件格式。
def read_array_header_2_0(fp, max_header_size=_MAX_HEADER_SIZE):
    """
    Read an array header from a filelike object using the 2.0 file format
    version.

    This will leave the file object located just after the header.

    .. versionadded:: 1.9.0

    Parameters
    ----------
    fp : filelike object
        A file object or something with a `.read()` method like a file.
    max_header_size : int, optional
        Maximum allowed size of the header. Large headers may not be safe
        to load securely and thus require explicitly passing a larger value.
        See :py:func:`ast.literal_eval()` for details.

    Returns
    -------
    shape : tuple of int
        The shape of the array.
    fortran_order : bool
        The ``fortran_order`` entry of the header.
    dtype : dtype
        The dtype of the file's data.

    Raises
    ------
    ValueError
        If the data is invalid.
    """
    format_version = (2, 0)
    return _read_array_header(
        fp, version=format_version, max_header_size=max_header_size)
# 清理 npz 文件头部字符串中的 'L',使得 Python 2 生成的头部可以在 Python 3 中读取
def _filter_header(s):
"""Clean up 'L' in npz header ints.
Cleans up the 'L' in strings representing integers. Needed to allow npz
headers produced in Python2 to be read in Python3.
Parameters
----------
s : string
Npy file header.
Returns
-------
header : str
Cleaned up header.
"""
# 导入 tokenize 和 StringIO,用于处理字符串中的 'L'
import tokenize
from io import StringIO
# 生成字符串的 token 流
tokens = []
last_token_was_number = False
for token in tokenize.generate_tokens(StringIO(s).readline):
token_type = token[0]
token_string = token[1]
# 如果上一个 token 是数字且当前 token 是名字且为 'L',则跳过
if (last_token_was_number and
token_type == tokenize.NAME and
token_string == "L"):
continue
else:
tokens.append(token)
last_token_was_number = (token_type == tokenize.NUMBER)
# 重新组合 token 流成为字符串头部
return tokenize.untokenize(tokens)
def _read_array_header(fp, version, max_header_size=_MAX_HEADER_SIZE):
    """
    Read and validate an array header for the given format version.

    Parameters
    ----------
    fp : filelike object
        Positioned at the header-length field, just after the magic string.
    version : tuple of int (major, minor)
        Must be a key of ``_header_size_info``.
    max_header_size : int, optional
        Headers whose decoded length exceeds this are rejected.

    Returns
    -------
    shape : tuple of int
    fortran_order : bool
    dtype : dtype

    Raises
    ------
    ValueError
        If the version is unknown, or the header is too large, cannot be
        parsed, is not a dict with exactly the expected keys, or holds an
        invalid shape, fortran_order or descr.
    """
    import ast
    import struct

    size_info = _header_size_info.get(version)
    if size_info is None:
        raise ValueError("Invalid version {!r}".format(version))
    length_fmt, encoding = size_info

    # The header length is stored as a little-endian unsigned integer.
    raw_length = _read_bytes(
        fp, struct.calcsize(length_fmt), "array header length")
    (header_length,) = struct.unpack(length_fmt, raw_length)
    header = _read_bytes(fp, header_length, "array header").decode(encoding)

    if len(header) > max_header_size:
        raise ValueError(
            f"Header info length ({len(header)}) is large and may not be safe "
            "to load securely.\n"
            "To allow loading, adjust `max_header_size` or fully trust "
            "the `.npy` file using `allow_pickle=True`.\n"
            "For safety against large resource use or crashes, sandboxing "
            "may be necessary.")

    # The header is a pretty-printed string representation of a literal
    # Python dictionary with trailing newlines padded to an ARRAY_ALIGN byte
    # boundary. The keys are strings.
    try:
        d = ast.literal_eval(header)
    except SyntaxError as e:
        if not (version <= (2, 0)):
            raise ValueError(
                "Cannot parse header: {!r}".format(header)) from e
        # Older files may carry Python 2 long literals; strip and retry.
        header = _filter_header(header)
        try:
            d = ast.literal_eval(header)
        except SyntaxError as e2:
            raise ValueError(
                "Cannot parse header: {!r}".format(header)) from e2
        warnings.warn(
            "Reading `.npy` or `.npz` file required additional "
            "header parsing as it was created on Python 2. Save the "
            "file again to speed up loading and avoid this warning.",
            UserWarning, stacklevel=4)

    if not isinstance(d, dict):
        raise ValueError("Header is not a dictionary: {!r}".format(d))

    if EXPECTED_KEYS != d.keys():
        keys = sorted(d.keys())
        raise ValueError(
            "Header does not contain the correct keys: {!r}".format(keys))

    # Sanity-check the values.
    shape = d['shape']
    shape_ok = (isinstance(shape, tuple) and
                all(isinstance(x, int) for x in shape))
    if not shape_ok:
        raise ValueError("shape is not valid: {!r}".format(d['shape']))

    if not isinstance(d['fortran_order'], bool):
        raise ValueError(
            "fortran_order is not a valid bool: {!r}".format(
                d['fortran_order']))

    try:
        dtype = descr_to_dtype(d['descr'])
    except TypeError as e:
        raise ValueError(
            "descr is not a valid dtype descriptor: {!r}".format(
                d['descr'])) from e

    return d['shape'], d['fortran_order'], dtype
# 检查并确保版本号是有效的
def write_array(fp, array, version=None, allow_pickle=True, pickle_kwargs=None):
    """
    Write an array to an NPY file, including a header.

    NOTE(review): the ``def`` line and docstring were missing from this copy
    of the file, leaving the body orphaned. The parameter list matches the
    stub declaration ``write_array(fp, array, version=..., allow_pickle=...,
    pickle_kwargs=...)``; the default values follow upstream NumPy and
    should be confirmed.

    Parameters
    ----------
    fp : file_like object
        Destination. When it is not a real file object (see `isfileobj`)
        the data is written chunk by chunk through ``fp.write``.
    array : ndarray
        The array to write.
    version : (int, int) or None, optional
        The version number of the format, validated by `_check_version`.
    allow_pickle : bool, optional
        Whether pickling is allowed for arrays whose dtype contains Python
        objects or is a non-legacy (user-defined) dtype.
    pickle_kwargs : dict, optional
        Additional keyword arguments to pass to ``pickle.dump``.

    Raises
    ------
    ValueError
        If `version` is unsupported, or if pickling would be required
        while ``allow_pickle=False``.
    """
    _check_version(version)
    _write_array_header(fp, header_data_from_array_1_0(array), version)

    if array.itemsize == 0:
        buffersize = 0
    else:
        # Set buffer size to 16 MiB to hide the Python loop overhead.
        buffersize = max(16 * 1024 ** 2 // array.itemsize, 1)

    dtype_class = type(array.dtype)

    if array.dtype.hasobject or not dtype_class._legacy:
        # Python objects (or a non-legacy dtype) cannot be written out
        # directly; the whole array is pickled instead.
        if not allow_pickle:
            if array.dtype.hasobject:
                raise ValueError("Object arrays cannot be saved when "
                                 "allow_pickle=False")
            if not dtype_class._legacy:
                raise ValueError("User-defined dtypes cannot be saved "
                                 "when allow_pickle=False")
        if pickle_kwargs is None:
            pickle_kwargs = {}
        pickle.dump(array, fp, protocol=4, **pickle_kwargs)
    elif array.flags.f_contiguous and not array.flags.c_contiguous:
        # Fortran-ordered data: emit the elements in 'F' order.
        if isfileobj(fp):
            array.T.tofile(fp)
        else:
            for chunk in numpy.nditer(
                    array, flags=['external_loop', 'buffered', 'zerosize_ok'],
                    buffersize=buffersize, order='F'):
                fp.write(chunk.tobytes('C'))
    else:
        if isfileobj(fp):
            array.tofile(fp)
        else:
            for chunk in numpy.nditer(
                    array, flags=['external_loop', 'buffered', 'zerosize_ok'],
                    buffersize=buffersize, order='C'):
                fp.write(chunk.tobytes('C'))
# 从一个NPY文件中读取一个数组
def read_array(fp, allow_pickle=False, pickle_kwargs=None, *,
               max_header_size=_MAX_HEADER_SIZE):
    """
    Read an array from an NPY file.

    Parameters
    ----------
    fp : file_like object
        If this is not a real file object, then this may take extra memory
        and time.
    allow_pickle : bool, optional
        Whether to allow reading pickled data. Default: False

        .. versionchanged:: 1.16.3
            Made default False in response to CVE-2019-6446.

    pickle_kwargs : dict
        Additional keyword arguments to pass to pickle.load. These are only
        useful when loading object arrays saved on Python 2 when using
        Python 3.
    max_header_size : int, optional
        Maximum allowed size of the header. Large headers may not be safe
        to load securely and thus require explicitly passing a larger value.
        See :py:func:`ast.literal_eval()` for details.
        This option is ignored when `allow_pickle` is passed. In that case
        the file is by definition trusted and the limit is unnecessary.

    Returns
    -------
    array : ndarray
        The array from the data on disk.

    Raises
    ------
    ValueError
        If the data is invalid, if allow_pickle=False and the file contains
        an object array, or if the file holds fewer elements than its
        header declares.
    """
    if allow_pickle:
        # Effectively ignore max_header_size, since the input is then
        # declared fully trusted.
        max_header_size = 2**64

    version = read_magic(fp)
    _check_version(version)
    shape, fortran_order, dtype = _read_array_header(
            fp, version, max_header_size=max_header_size)
    if len(shape) == 0:
        count = 1
    else:
        count = numpy.multiply.reduce(shape, dtype=numpy.int64)

    # Now read the actual data.
    if dtype.hasobject:
        # The array contained Python objects. We need to unpickle the data.
        if not allow_pickle:
            raise ValueError("Object arrays cannot be loaded when "
                             "allow_pickle=False")
        if pickle_kwargs is None:
            pickle_kwargs = {}
        try:
            # SECURITY: unpickling can execute arbitrary code; this path is
            # only reachable when the caller opted in via allow_pickle=True.
            array = pickle.load(fp, **pickle_kwargs)
        except UnicodeError as err:
            # Friendlier error message
            raise UnicodeError("Unpickling a python object failed: %r\n"
                               "You may need to pass the encoding= option "
                               "to numpy.load" % (err,)) from err
    else:
        if isfileobj(fp):
            # We can use the fast fromfile() function.
            array = numpy.fromfile(fp, dtype=dtype, count=count)
        else:
            # This is not a real file. We have to read it the
            # memory-intensive way.
            # crc32 module fails on reads greater than 2 ** 32 bytes,
            # breaking large reads from gzip streams. Chunk reads to
            # BUFFER_SIZE bytes to avoid issue and reduce memory overhead
            # of the read. In non-chunked case count < max_read_count, so
            # only one read is performed.

            # Use np.ndarray instead of np.empty since the latter does
            # not correctly instantiate zero-width string dtypes; see
            # https://github.com/numpy/numpy/pull/6430
            array = numpy.ndarray(count, dtype=dtype)

            if dtype.itemsize > 0:
                # If dtype.itemsize == 0 then there's nothing more to read
                max_read_count = BUFFER_SIZE // min(BUFFER_SIZE, dtype.itemsize)

                for i in range(0, count, max_read_count):
                    read_count = min(max_read_count, count - i)
                    read_size = int(read_count * dtype.itemsize)
                    data = _read_bytes(fp, read_size, "array data")
                    array[i:i+read_count] = numpy.frombuffer(data, dtype=dtype,
                                                             count=read_count)

        if array.size != count:
            # fromfile() hands back a short array for a truncated file;
            # report that directly instead of failing on the reshape below.
            raise ValueError(
                "Failed to read all data for array. "
                f"Expected {shape} = {count} elements, "
                f"could only read {array.size} elements. "
                "(file seems not fully written?)"
            )

        if fortran_order:
            array.shape = shape[::-1]
            array = array.transpose()
        else:
            array.shape = shape

    return array
# 定义一个函数,用于以内存映射方式打开 .npy 文件,并返回内存映射数组对象。
def open_memmap(filename, mode='r+', dtype=None, shape=None,
                fortran_order=False, version=None, *,
                max_header_size=_MAX_HEADER_SIZE):
    """
    Open a .npy file as a memory-mapped array.

    This may be used to read an existing file or create a new one.

    Parameters
    ----------
    filename : str or path-like
        The name of the file on disk. This may *not* be a file-like
        object.
    mode : str, optional
        The mode in which to open the file; the default is 'r+'. In
        addition to the standard file modes, 'c' is also accepted to mean
        "copy on write." See `memmap` for the available mode strings.
    dtype : data-type, optional
        The data type of the array if we are creating a new file in "write"
        mode, if not, `dtype` is ignored. The default value is None, which
        results in a data-type of `float64`.
    shape : tuple of int
        The shape of the array if we are creating a new file in "write"
        mode, in which case this parameter is required. Otherwise, this
        parameter is ignored and is thus optional.
    fortran_order : bool, optional
        Whether the array should be Fortran-contiguous (True) or
        C-contiguous (False, the default) if we are creating a new file in
        "write" mode.
    version : tuple of int (major, minor) or None
        If the mode is a "write" mode, then this is the version of the file
        format used to create the file. None means use the oldest
        supported version that is able to store the data. Default: None
    max_header_size : int, optional
        Maximum allowed size of the header. Large headers may not be safe
        to load securely and thus require explicitly passing a larger value.
        See :py:func:`ast.literal_eval()` for details.

    Returns
    -------
    marray : memmap
        The memory-mapped array.

    Raises
    ------
    ValueError
        If the data or the mode is invalid.
    OSError
        If the file is not found or cannot be opened correctly.

    See Also
    --------
    numpy.memmap
    """
    if isfileobj(filename):
        raise ValueError("Filename must be a string or a path-like object."
                         " Memmap cannot use existing file handles.")

    creating = 'w' in mode
    if creating:
        # We are creating the file, not reading it.
        _check_version(version)
        # Ensure that the given dtype is an authentic dtype object rather
        # than just something that can be interpreted as a dtype object.
        dtype = numpy.dtype(dtype)
        if dtype.hasobject:
            raise ValueError(
                "Array can't be memory-mapped: Python objects in dtype.")
        header_fields = {
            'descr': dtype_to_descr(dtype),
            'fortran_order': fortran_order,
            'shape': shape,
        }
        # If we got here, then it should be safe to create the file.
        with open(os.fspath(filename), mode+'b') as fp:
            _write_array_header(fp, header_fields, version)
            offset = fp.tell()
    else:
        # Read the header of the existing file first.
        with open(os.fspath(filename), 'rb') as fp:
            version = read_magic(fp)
            _check_version(version)
            shape, fortran_order, dtype = _read_array_header(
                    fp, version, max_header_size=max_header_size)
            if dtype.hasobject:
                raise ValueError(
                    "Array can't be memory-mapped: Python objects in dtype.")
            offset = fp.tell()

    order = 'F' if fortran_order else 'C'

    # We need to change a write-only mode to a read-write mode since we've
    # already written data to the file.
    if mode == 'w+':
        mode = 'r+'

    return numpy.memmap(filename, dtype=dtype, shape=shape, order=order,
                        mode=mode, offset=offset)
# 从文件对象 `fp` 中读取指定大小 `size` 的数据,直到读取完为止。
# 如果在读取完 `size` 字节之前遇到 EOF,则抛出 ValueError 异常。
# 对于非阻塞对象,仅支持继承自 io 对象的情况。
def _read_bytes(fp, size, error_template="ran out of data"):
"""
Read from file-like object until size bytes are read.
Raises ValueError if not EOF is encountered before size bytes are read.
Non-blocking objects only supported if they derive from io objects.
Required as e.g. ZipExtFile in python 2.6 can return less data than
requested.
"""
# 初始化一个空的 bytes 对象,用于存储读取的数据
data = bytes()
while True:
# 对于 io 文件(在 Python3 中是默认的),如果读取到末尾返回 None 或抛出异常
# 对于 Python2 的文件对象,可能会截断数据,无法处理非阻塞情况
try:
# 尝试从文件对象 `fp` 中读取剩余未读取的部分(直到 `size - len(data)` 字节)
r = fp.read(size - len(data))
# 将读取到的数据追加到 `data` 中
data += r
# 如果读取返回空数据(EOF)或者已经读取了 `size` 字节,停止循环
if len(r) == 0 or len(data) == size:
break
except BlockingIOError:
pass
# 检查实际读取的数据长度是否等于指定的 `size`
if len(data) != size:
# 如果实际读取长度不等于 `size`,抛出异常,显示预期读取的字节数和实际读取的字节数
msg = "EOF: reading %s, expected %d bytes got %d"
raise ValueError(msg % (error_template, size, len(data)))
else:
# 如果读取长度等于 `size`,返回读取到的数据
return data
def isfileobj(f):
    """
    Return True when `f` is an io.FileIO, io.BufferedReader or
    io.BufferedWriter whose ``fileno()`` call succeeds, else False.
    """
    real_file_types = (io.FileIO, io.BufferedReader, io.BufferedWriter)
    if isinstance(f, real_file_types):
        try:
            # BufferedReader/Writer may raise OSError when fetching
            # `fileno()` (e.g. when wrapping BytesIO).
            f.fileno()
        except OSError:
            return False
        return True
    return False
.\numpy\numpy\lib\format.pyi
from typing import Any, Literal, Final
__all__: list[str]
EXPECTED_KEYS: Final[set[str]]
MAGIC_PREFIX: Final[bytes]
MAGIC_LEN: Literal[8]
ARRAY_ALIGN: Literal[64]
BUFFER_SIZE: Literal[262144]
def magic(major, minor): ...
def read_magic(fp): ...
def dtype_to_descr(dtype): ...
def descr_to_dtype(descr): ...
def header_data_from_array_1_0(array): ...
def write_array_header_1_0(fp, d): ...
def write_array_header_2_0(fp, d): ...
# The reader entry points also accept `max_header_size` (keyword-only for
# read_array and open_memmap), mirroring the implementation signatures.
def read_array_header_1_0(fp, max_header_size=...): ...
def read_array_header_2_0(fp, max_header_size=...): ...
def write_array(fp, array, version=..., allow_pickle=..., pickle_kwargs=...): ...
def read_array(fp, allow_pickle=..., pickle_kwargs=..., *, max_header_size=...): ...
def open_memmap(filename, mode=..., dtype=..., shape=..., fortran_order=..., version=..., *, max_header_size=...): ...
.\numpy\numpy\lib\introspect.py
"""
Introspection helper functions.
"""
import re
__all__ = ['opt_func_info']
def opt_func_info(func_name=None, signature=None):
"""
Returns a dictionary containing the currently supported CPU dispatched
features for all optimized functions.
Parameters
----------
func_name : str (optional)
Regular expression to filter by function name.
signature : str (optional)
Regular expression to filter by data type.
Returns
-------
dict
A dictionary where keys are optimized function names and values are
nested dictionaries indicating supported targets based on data types.
Examples
--------
Retrieve dispatch information for functions named 'add' or 'sub' and
data types 'float64' or 'float32':
>>> dict = np.lib.introspect.opt_func_info(
... func_name="add|abs", signature="float64|complex64"
... )
>>> import json
>>> print(json.dumps(dict, indent=2))
{
"absolute": {
"dd": {
"current": "SSE41",
"available": "SSE41 baseline(SSE SSE2 SSE3)"
},
"Ff": {
"current": "FMA3__AVX2",
"available": "AVX512F FMA3__AVX2 baseline(SSE SSE2 SSE3)"
},
"Dd": {
"current": "FMA3__AVX2",
"available": "AVX512F FMA3__AVX2 baseline(SSE SSE2 SSE3)"
}
},
"add": {
"ddd": {
"current": "FMA3__AVX2",
"available": "FMA3__AVX2 baseline(SSE SSE2 SSE3)"
},
"FFF": {
"current": "FMA3__AVX2",
"available": "FMA3__AVX2 baseline(SSE SSE2 SSE3)"
}
}
}
"""
from numpy._core._multiarray_umath import (
__cpu_targets_info__ as targets, dtype
)
if func_name is not None:
func_pattern = re.compile(func_name)
matching_funcs = {
k: v for k, v in targets.items()
if func_pattern.search(k)
}
else:
matching_funcs = targets
if signature is not None:
sig_pattern = re.compile(signature)
matching_sigs = {}
for k, v in matching_funcs.items():
matching_chars = {}
for chars, targets in v.items():
if any([
sig_pattern.search(c) or
sig_pattern.search(dtype(c).name)
for c in chars
]):
matching_chars[chars] = targets
if matching_chars:
matching_sigs[k] = matching_chars
else:
matching_sigs = matching_funcs
return matching_sigs
.\numpy\numpy\lib\mixins.py
"""
Mixin classes for custom array types that don't inherit from ndarray.
"""
from numpy._core import umath as um
__all__ = ['NDArrayOperatorsMixin']
def _disables_array_ufunc(obj):
"""True when __array_ufunc__ is set to None."""
try:
return obj.__array_ufunc__ is None
except AttributeError:
return False
def _binary_method(ufunc, name):
"""Implement a forward binary method with a ufunc, e.g., __add__."""
def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(self, other)
func.__name__ = '__{}__'.format(name)
return func
def _reflected_binary_method(ufunc, name):
"""Implement a reflected binary method with a ufunc, e.g., __radd__."""
def func(self, other):
if _disables_array_ufunc(other):
return NotImplemented
return ufunc(other, self)
func.__name__ = '__r{}__'.format(name)
return func
def _inplace_binary_method(ufunc, name):
"""Implement an in-place binary method with a ufunc, e.g., __iadd__."""
def func(self, other):
return ufunc(self, other, out=(self,))
func.__name__ = '__i{}__'.format(name)
return func
def _numeric_methods(ufunc, name):
    """Return the (forward, reflected, in-place) methods built from a ufunc."""
    forward = _binary_method(ufunc, name)
    reflected = _reflected_binary_method(ufunc, name)
    inplace = _inplace_binary_method(ufunc, name)
    return (forward, reflected, inplace)
def _unary_method(ufunc, name):
"""Implement a unary special method with a ufunc."""
def func(self):
return ufunc(self)
func.__name__ = '__{}__'.format(name)
return func
class NDArrayOperatorsMixin:
    """Mixin defining all operator special methods using __array_ufunc__.

    This class implements the special methods for almost all of Python's
    builtin operators defined in the `operator` module, including comparisons
    (``==``, ``>``, etc.) and arithmetic (``+``, ``*``, ``-``, etc.), by
    deferring to the ``__array_ufunc__`` method, which subclasses must
    implement.

    It is useful for writing classes that do not inherit from `numpy.ndarray`,
    but that should support arithmetic and numpy universal functions like
    arrays as described in `A Mechanism for Overriding Ufuncs
    <https://numpy.org/neps/nep-0013-ufunc-overrides.html>`_.

    As a trivial example, consider this implementation of an ``ArrayLike``
    class that simply wraps a NumPy array and ensures that the result of any
    arithmetic operation is also an ``ArrayLike`` object:

        >>> import numbers
        >>> class ArrayLike(np.lib.mixins.NDArrayOperatorsMixin):
        ...     def __init__(self, value):
        ...         self.value = np.asarray(value)
        ...
        ...     # Types this wrapper is willing to operate with.
        ...     _HANDLED_TYPES = (np.ndarray, numbers.Number)
        ...
        ...     def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        ...         out = kwargs.get('out', ())
        ...         for x in inputs + out:
        ...             # Only support operations with instances of
        ...             # _HANDLED_TYPES or other ArrayLike objects.
        ...             if not isinstance(
        ...                 x, self._HANDLED_TYPES + (ArrayLike,)
        ...             ):
        ...                 return NotImplemented
        ...
        ...         # Defer to the ufunc on the unwrapped values.
        ...         inputs = tuple(x.value if isinstance(x, ArrayLike) else x
        ...                        for x in inputs)
        ...         if out:
        ...             kwargs['out'] = tuple(
        ...                 x.value if isinstance(x, ArrayLike) else x
        ...                 for x in out)
        ...         result = getattr(ufunc, method)(*inputs, **kwargs)
        ...
        ...         if type(result) is tuple:
        ...             # multiple return values
        ...             return tuple(type(self)(x) for x in result)
        ...         elif method == 'at':
        ...             # no return value
        ...             return None
        ...         else:
        ...             # one return value
        ...             return type(self)(result)
        ...
        ...     def __repr__(self):
        ...         return '%s(%r)' % (type(self).__name__, self.value)

    """
    __slots__ = ()

    # comparisons don't have reflected and in-place versions
    __lt__ = _binary_method(um.less, 'lt')
    __le__ = _binary_method(um.less_equal, 'le')
    __eq__ = _binary_method(um.equal, 'eq')
    __ne__ = _binary_method(um.not_equal, 'ne')
    __gt__ = _binary_method(um.greater, 'gt')
    __ge__ = _binary_method(um.greater_equal, 'ge')

    # numeric methods
    __add__, __radd__, __iadd__ = _numeric_methods(um.add, 'add')
    __sub__, __rsub__, __isub__ = _numeric_methods(um.subtract, 'sub')
    __mul__, __rmul__, __imul__ = _numeric_methods(um.multiply, 'mul')
    __matmul__, __rmatmul__, __imatmul__ = _numeric_methods(
        um.matmul, 'matmul')
    __truediv__, __rtruediv__, __itruediv__ = _numeric_methods(
        um.true_divide, 'truediv')
    __floordiv__, __rfloordiv__, __ifloordiv__ = _numeric_methods(
        um.floor_divide, 'floordiv')
    __mod__, __rmod__, __imod__ = _numeric_methods(um.remainder, 'mod')
    # divmod has no in-place variant
    __divmod__ = _binary_method(um.divmod, 'divmod')
    __rdivmod__ = _reflected_binary_method(um.divmod, 'divmod')
    __pow__, __rpow__, __ipow__ = _numeric_methods(um.power, 'pow')
    __lshift__, __rlshift__, __ilshift__ = _numeric_methods(
        um.left_shift, 'lshift')
    __rshift__, __rrshift__, __irshift__ = _numeric_methods(
        um.right_shift, 'rshift')
    __and__, __rand__, __iand__ = _numeric_methods(um.bitwise_and, 'and')
    __xor__, __rxor__, __ixor__ = _numeric_methods(um.bitwise_xor, 'xor')
    __or__, __ror__, __ior__ = _numeric_methods(um.bitwise_or, 'or')

    # unary methods
    __neg__ = _unary_method(um.negative, 'neg')
    __pos__ = _unary_method(um.positive, 'pos')
    __abs__ = _unary_method(um.absolute, 'abs')
    __invert__ = _unary_method(um.invert, 'invert')
.\numpy\numpy\lib\mixins.pyi
from abc import ABCMeta, abstractmethod
from typing import Literal as L, Any
from numpy import ufunc
__all__: list[str]
class NDArrayOperatorsMixin(metaclass=ABCMeta):
    # Typing stub for the operator mixin. `__array_ufunc__` is declared
    # abstract; every operator below is typed loosely as taking and
    # returning `Any`.
    @abstractmethod
    def __array_ufunc__(
        self,
        ufunc: ufunc,
        method: L["__call__", "reduce", "reduceat", "accumulate", "outer", "at"],
        *inputs: Any,
        **kwargs: Any,
    ) -> Any: ...
    # Comparison operators
    def __lt__(self, other: Any) -> Any: ...
    def __le__(self, other: Any) -> Any: ...
    def __eq__(self, other: Any) -> Any: ...
    def __ne__(self, other: Any) -> Any: ...
    def __gt__(self, other: Any) -> Any: ...
    def __ge__(self, other: Any) -> Any: ...
    # Binary operators: forward, reflected (r-prefixed), in-place (i-prefixed)
    def __add__(self, other: Any) -> Any: ...
    def __radd__(self, other: Any) -> Any: ...
    def __iadd__(self, other: Any) -> Any: ...
    def __sub__(self, other: Any) -> Any: ...
    def __rsub__(self, other: Any) -> Any: ...
    def __isub__(self, other: Any) -> Any: ...
    def __mul__(self, other: Any) -> Any: ...
    def __rmul__(self, other: Any) -> Any: ...
    def __imul__(self, other: Any) -> Any: ...
    def __matmul__(self, other: Any) -> Any: ...
    def __rmatmul__(self, other: Any) -> Any: ...
    def __imatmul__(self, other: Any) -> Any: ...
    def __truediv__(self, other: Any) -> Any: ...
    def __rtruediv__(self, other: Any) -> Any: ...
    def __itruediv__(self, other: Any) -> Any: ...
    def __floordiv__(self, other: Any) -> Any: ...
    def __rfloordiv__(self, other: Any) -> Any: ...
    def __ifloordiv__(self, other: Any) -> Any: ...
    def __mod__(self, other: Any) -> Any: ...
    def __rmod__(self, other: Any) -> Any: ...
    def __imod__(self, other: Any) -> Any: ...
    # divmod is declared without an in-place form
    def __divmod__(self, other: Any) -> Any: ...
    def __rdivmod__(self, other: Any) -> Any: ...
    def __pow__(self, other: Any) -> Any: ...
    def __rpow__(self, other: Any) -> Any: ...
    def __ipow__(self, other: Any) -> Any: ...
    def __lshift__(self, other: Any) -> Any: ...
    def __rlshift__(self, other: Any) -> Any: ...
    def __ilshift__(self, other: Any) -> Any: ...
    def __rshift__(self, other: Any) -> Any: ...
    def __rrshift__(self, other: Any) -> Any: ...
    def __irshift__(self, other: Any) -> Any: ...
    def __and__(self, other: Any) -> Any: ...
    def __rand__(self, other: Any) -> Any: ...
    def __iand__(self, other: Any) -> Any: ...
    def __xor__(self, other: Any) -> Any: ...
    def __rxor__(self, other: Any) -> Any: ...
    def __ixor__(self, other: Any) -> Any: ...
    def __or__(self, other: Any) -> Any: ...
    def __ror__(self, other: Any) -> Any: ...
    def __ior__(self, other: Any) -> Any: ...
    # Unary operators
    def __neg__(self) -> Any: ...
    def __pos__(self) -> Any: ...
    def __abs__(self) -> Any: ...
    def __invert__(self) -> Any: ...
.\numpy\numpy\lib\npyio.py
from ._npyio_impl import (
__doc__, DataSource, NpzFile
)
.\numpy\numpy\lib\npyio.pyi
from numpy.lib._npyio_impl import (
DataSource as DataSource,
NpzFile as NpzFile,
)
.\numpy\numpy\lib\recfunctions.py
"""
Collection of utilities to manipulate structured arrays.
Most of these functions were initially implemented by John Hunter for
matplotlib. They have been rewritten and extended for convenience.
"""
import itertools
import numpy as np
import numpy.ma as ma
from numpy import ndarray
from numpy.ma import MaskedArray
from numpy.ma.mrecords import MaskedRecords
from numpy._core.overrides import array_function_dispatch
from numpy._core.records import recarray
from numpy.lib._iotools import _is_string_like
# Module-level alias for the private fill-value checker of ``numpy.ma.core``;
# it is called by `merge_arrays` below to obtain the padding value.
_check_fill_value = np.ma.core._check_fill_value
# Public names exported by ``from numpy.lib.recfunctions import *``.
__all__ = [
'append_fields', 'apply_along_fields', 'assign_fields_by_name',
'drop_fields', 'find_duplicates', 'flatten_descr',
'get_fieldstructure', 'get_names', 'get_names_flat',
'join_by', 'merge_arrays', 'rec_append_fields',
'rec_drop_fields', 'rec_join', 'recursive_fill_fields',
'rename_fields', 'repack_fields', 'require_fields',
'stack_arrays', 'structured_to_unstructured', 'unstructured_to_structured',
]
def _recursive_fill_fields_dispatcher(input, output):
return (input, output)
@array_function_dispatch(_recursive_fill_fields_dispatcher)
def recursive_fill_fields(input, output):
    """
    Copy the fields of `input` into the same-named fields of `output`.

    Nested structured fields are filled recursively.  Fields of `output`
    that cannot be read from `input` are left untouched.

    Parameters
    ----------
    input : ndarray
        Input array.
    output : ndarray
        Output array; it is modified in place.

    Returns
    -------
    output : ndarray
        The `output` argument itself, after filling.

    Notes
    -----
    * `output` should be at least the same size as `input`

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> a = np.array([(1, 10.), (2, 20.)], dtype=[('A', np.int64), ('B', np.float64)])
    >>> b = np.zeros((3,), dtype=a.dtype)
    >>> rfn.recursive_fill_fields(a, b)
    array([(1, 10.), (2, 20.), (0, 0.)], dtype=[('A', '<i8'), ('B', '<f8')])

    """
    for name in output.dtype.names:
        try:
            source = input[name]
        except ValueError:
            # `input` could not be indexed by this field name: skip it.
            continue
        target = output[name]
        if source.dtype.names is None:
            # Plain field: copy into the leading `len(source)` slots.
            target[:len(source)] = source
        else:
            # Structured field: descend one level.
            recursive_fill_fields(source, target)
    return output
def _get_fieldspec(dtype):
"""
Produce a list of name/dtype pairs corresponding to the dtype fields
Similar to dtype.descr, but the second item of each tuple is a dtype, not a
string. As a result, this handles subarray dtypes
Can be passed to the dtype constructor to reconstruct the dtype, noting that
this (deliberately) discards field offsets.
Examples
--------
>>> dt = np.dtype([(('a', 'A'), np.int64), ('b', np.double, 3)])
>>> dt.descr
[(('a', 'A'), '<i8'), ('b', '<f8', (3,))]
>>> _get_fieldspec(dt)
[(('a', 'A'), dtype('int64')), ('b', dtype(('<f8', (3,))))]
"""
if dtype.names is None:
return [('', dtype)]
else:
fields = ((name, dtype.fields[name]) for name in dtype.names)
return [
(name if len(f) == 2 else (f[2], name), f[0])
for name, f in fields
]
def get_names(adtype):
    """
    Return the field names of a structured datatype as a (nested) tuple.

    A field that is itself structured appears as a ``(name, subnames)``
    pair.  The input datatype must have fields, otherwise an error is raised.

    Parameters
    ----------
    adtype : dtype
        Input datatype

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> rfn.get_names(np.empty((1,), dtype=[('A', int)]).dtype)
    ('A',)
    >>> rfn.get_names(np.empty((1,), dtype=[('A',int), ('B', float)]).dtype)
    ('A', 'B')
    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
    >>> rfn.get_names(adtype)
    ('a', ('b', ('ba', 'bb')))
    """
    collected = []
    for name in adtype.names:
        sub = adtype[name]
        if sub.names is None:
            collected.append(name)
        else:
            collected.append((name, tuple(get_names(sub))))
    return tuple(collected)
def get_names_flat(adtype):
    """
    Return the field names of a structured datatype as a flat tuple.

    Every field name is listed, immediately followed by the (flattened)
    names of its nested fields, if any.  The input datatype must have
    fields, otherwise an error is raised.

    Parameters
    ----------
    adtype : dtype
        Input datatype

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> rfn.get_names_flat(np.empty((1,), dtype=[('A', int)]).dtype) is None
    False
    >>> rfn.get_names_flat(np.empty((1,), dtype=[('A',int), ('B', str)]).dtype)
    ('A', 'B')
    >>> adtype = np.dtype([('a', int), ('b', [('ba', int), ('bb', int)])])
    >>> rfn.get_names_flat(adtype)
    ('a', 'b', 'ba', 'bb')
    """
    flat = []
    for name in adtype.names:
        flat.append(name)
        sub = adtype[name]
        if sub.names is None:
            continue
        flat.extend(get_names_flat(sub))
    return tuple(flat)
def flatten_descr(ndtype):
    """
    Flatten a structured data-type description.

    Nested structures are replaced by their leaf fields; an unstructured
    dtype yields a single anonymous entry.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> ndtype = np.dtype([('a', '<i4'), ('b', [('ba', '<f8'), ('bb', '<i4')])])
    >>> rfn.flatten_descr(ndtype)
    (('a', dtype('int32')), ('ba', dtype('float64')), ('bb', dtype('int32')))

    """
    if ndtype.names is None:
        return (('', ndtype),)

    flat = []
    for name in ndtype.names:
        # Strict two-item unpacking, exactly as before: (dtype, offset).
        field_dtype, _offset = ndtype.fields[name]
        if field_dtype.names is None:
            flat.append((name, field_dtype))
        else:
            flat.extend(flatten_descr(field_dtype))
    return tuple(flat)
def _zip_dtype(seqarrays, flatten=False):
newdtype = []
if flatten:
for a in seqarrays:
newdtype.extend(flatten_descr(a.dtype))
else:
for a in seqarrays:
current = a.dtype
if current.names is not None and len(current.names) == 1:
newdtype.extend(_get_fieldspec(current))
else:
newdtype.append(('', current))
return np.dtype(newdtype)
def _zip_descr(seqarrays, flatten=False):
    """
    Return the ``descr`` of the dtype that `_zip_dtype` builds for `seqarrays`.

    Parameters
    ----------
    seqarrays : sequence of arrays
        Sequence of arrays
    flatten : {boolean}, optional
        Forwarded to `_zip_dtype`.
    """
    combined = _zip_dtype(seqarrays, flatten=flatten)
    return combined.descr
def get_fieldstructure(adtype, lastname=None, parents=None,):
    """
    Map every field name of `adtype` to the list of its parent field names.

    Parameters
    ----------
    adtype : np.dtype
        Input (structured) datatype.
    lastname : optional
        Name of the field whose sub-fields are being processed; used by
        the recursive calls.
    parents : dictionary, optional
        Dictionary being filled; used by the recursive calls.  A new one is
        created when omitted.

    Returns
    -------
    parents : dictionary
        ``{field name: [outermost parent, ..., direct parent]}``.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> ndtype =  np.dtype([('A', int),
    ...                     ('B', [('BA', int),
    ...                            ('BB', [('BBA', int), ('BBB', int)])])])
    >>> rfn.get_fieldstructure(ndtype)
    {'A': [], 'B': [], 'BA': ['B'], 'BB': ['B'], 'BBA': ['B', 'BB'], 'BBB': ['B', 'BB']}

    """
    if parents is None:
        parents = {}
    for name in adtype.names:
        sub = adtype[name]
        if sub.names is None:
            # Leaf field: its ancestry is that of `lastname` plus `lastname`.
            ancestry = list(parents.get(lastname, []) or [])
            if ancestry:
                ancestry.append(lastname)
            elif lastname:
                ancestry = [lastname, ]
            parents[name] = ancestry or []
        else:
            # Structured field: record it, then walk its own fields.
            parents[name] = [lastname, ] if lastname else []
            parents.update(get_fieldstructure(sub, name, parents))
    return parents
def _izip_fields_flat(iterable):
for element in iterable:
if isinstance(element, np.void):
yield from _izip_fields_flat(tuple(element))
else:
yield element
def _izip_fields(iterable):
for element in iterable:
if (hasattr(element, '__iter__') and
not isinstance(element, str)):
yield from _izip_fields(element)
elif isinstance(element, np.void) and len(tuple(element)) == 1:
yield from _izip_fields(element)
else:
yield element
def _izip_records(seqarrays, fill_value=None, flatten=True):
    """
    Returns an iterator of concatenated items from a sequence of arrays.

    The inputs are zipped with ``itertools.zip_longest`` and each resulting
    row is expanded into a single tuple.

    Parameters
    ----------
    seqarrays : sequence of arrays
        The sequence of arrays to iterate over.
    fill_value : any, optional
        The value to use for missing fields in records (when an input is
        shorter than the longest one).
    flatten : bool, optional
        Whether to flatten nested structures: selects `_izip_fields_flat`
        when true and `_izip_fields` otherwise.
    """
    expand = _izip_fields_flat if flatten else _izip_fields
    for row in itertools.zip_longest(*seqarrays, fillvalue=fill_value):
        yield tuple(expand(row))
def _fix_output(output, usemask=True, asrecarray=False):
if not isinstance(output, MaskedArray):
usemask = False
if usemask:
if asrecarray:
output = output.view(MaskedRecords)
else:
output = ma.filled(output)
if asrecarray:
output = output.view(recarray)
return output
def _fix_defaults(output, defaults=None):
names = output.dtype.names
(data, mask, fill_value) = (output.data, output.mask, output.fill_value)
for (k, v) in (defaults or {}).items():
if k in names:
fill_value[k] = v
data[k][mask[k]] = v
return output
def _merge_arrays_dispatcher(seqarrays, fill_value=None, flatten=None, usemask=None, asrecarray=None):
return seqarrays
@array_function_dispatch(_merge_arrays_dispatcher)
def merge_arrays(seqarrays, fill_value=-1, flatten=False, usemask=False, asrecarray=False):
    """
    Merge arrays field by field.

    Parameters
    ----------
    seqarrays : sequence of ndarrays
        Sequence of arrays
    fill_value : {float}, optional
        Filling value used to pad missing data on the shorter arrays.
    flatten : {False, True}, optional
        Whether to collapse nested fields.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Whether to return a recarray (MaskedRecords) or not.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> rfn.merge_arrays((np.array([1, 2]), np.array([10., 20., 30.])))
    array([( 1, 10.), ( 2, 20.), (-1, 30.)],
          dtype=[('f0', '<i8'), ('f1', '<f8')])

    Notes
    -----
    * Without a mask, the missing value will be filled with something,
      depending on what its corresponding type:

      * ``-1``      for integers
      * ``-1.0``    for floating point numbers
      * ``'-'``     for characters
      * ``'-1'``    for strings
      * ``True``    for boolean values
    * XXX: I just obtained these values empirically
    """
    # A one-element sequence is unwrapped to its sole array.
    if len(seqarrays) == 1:
        seqarrays = np.asanyarray(seqarrays[0])

    if isinstance(seqarrays, (ndarray, np.void)):
        # A single array (or record) was supplied.
        seqdtype = seqarrays.dtype
        if seqdtype.names is None:
            # Make sure the dtype has (anonymous) named fields.
            seqdtype = np.dtype([('', seqdtype)])
        if not flatten or _zip_dtype((seqarrays,), flatten=True) == seqdtype:
            # Nothing to merge or flatten: a view of the raveled input with
            # the requested array type is enough.
            seqarrays = seqarrays.ravel()
            if usemask:
                seqtype = MaskedRecords if asrecarray else MaskedArray
            else:
                seqtype = recarray if asrecarray else ndarray
            return seqarrays.view(dtype=seqdtype, type=seqtype)
        seqarrays = (seqarrays,)
    else:
        # Make sure every item of the sequence is an array.
        seqarrays = [np.asanyarray(_m) for _m in seqarrays]

    # Sizes of the inputs, their maximum, and the merged dtype.
    sizes = tuple(a.size for a in seqarrays)
    maxlength = max(sizes)
    newdtype = _zip_dtype(seqarrays, flatten=flatten)

    seqdata = []
    seqmask = []
    if usemask:
        # Masked output: pad both the data and the mask of every input.
        for arr, size in zip(seqarrays, sizes):
            n_pad = maxlength - size
            values = arr.ravel().__array__()
            flags = ma.getmaskarray(arr).ravel()
            if n_pad:
                pad_value = _check_fill_value(fill_value, arr.dtype)
                if isinstance(pad_value, (ndarray, np.void)):
                    if len(pad_value.dtype) == 1:
                        pad_value = pad_value.item()[0]
                        pad_flag = True
                    else:
                        pad_value = np.array(pad_value, dtype=arr.dtype, ndmin=1)
                        pad_flag = np.ones((1,), dtype=flags.dtype)
            else:
                pad_value = None
                pad_flag = True
            # Iterators padding the input up to `maxlength` items.
            seqdata.append(itertools.chain(values, [pad_value] * n_pad))
            seqmask.append(itertools.chain(flags, [pad_flag] * n_pad))
        records = tuple(_izip_records(seqdata, flatten=flatten))
        output = ma.array(np.fromiter(records, dtype=newdtype, count=maxlength),
                          mask=list(_izip_records(seqmask, flatten=flatten)))
        if asrecarray:
            output = output.view(MaskedRecords)
    else:
        # Unmasked output: same padding scheme, data only.
        for arr, size in zip(seqarrays, sizes):
            n_pad = maxlength - size
            values = arr.ravel().__array__()
            if n_pad:
                pad_value = _check_fill_value(fill_value, arr.dtype)
                if isinstance(pad_value, (ndarray, np.void)):
                    if len(pad_value.dtype) == 1:
                        pad_value = pad_value.item()[0]
                    else:
                        pad_value = np.array(pad_value, dtype=arr.dtype, ndmin=1)
                else:
                    # NOTE(review): a checked fill value that is neither an
                    # ndarray nor np.void is replaced by None here - confirm
                    # this branch is intended.
                    pad_value = None
            else:
                pad_value = None
            seqdata.append(itertools.chain(values, [pad_value] * n_pad))
        output = np.fromiter(tuple(_izip_records(seqdata, flatten=flatten)),
                             dtype=newdtype, count=maxlength)
        if asrecarray:
            output = output.view(recarray)
    return output
def _drop_fields_dispatcher(base, drop_names, usemask=None, asrecarray=None):
return (base,)
@array_function_dispatch(_drop_fields_dispatcher)
def drop_fields(base, drop_names, usemask=True, asrecarray=False):
    """
    Return a new array with fields in `drop_names` dropped.

    Nested fields are supported.

    .. versionchanged:: 1.18.0
        `drop_fields` returns an array with 0 fields if all fields are
        dropped, rather than returning ``None`` as it did previously.

    Parameters
    ----------
    base : array
        Input array
    drop_names : string or sequence
        String or sequence of strings corresponding to the names of the
        fields to drop.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Whether to return a recarray or a mrecarray (`asrecarray=True`) or
        a plain ndarray or masked array with flexible dtype. The default
        is False.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> a = np.array([(1, (2, 3.0)), (4, (5, 6.0))],
    ...   dtype=[('a', np.int64), ('b', [('ba', np.double), ('bb', np.int64)])])
    >>> rfn.drop_fields(a, 'a')
    array([((2., 3),), ((5., 6),)],
          dtype=[('b', [('ba', '<f8'), ('bb', '<i8')])])
    >>> rfn.drop_fields(a, 'ba')
    array([(1, (3,)), (4, (6,))], dtype=[('a', '<i8'), ('b', [('bb', '<i8')])])
    >>> rfn.drop_fields(a, ['ba', 'bb'])
    array([(1,), (4,)], dtype=[('a', '<i8')])
    """
    if _is_string_like(drop_names):
        drop_names = [drop_names]
    else:
        drop_names = set(drop_names)

    def _drop_descr(ndtype, drop_names):
        # Rebuild the dtype description without the dropped names; a nested
        # structure left with no field at all is omitted entirely.
        kept = []
        for name in ndtype.names:
            if name in drop_names:
                continue
            current = ndtype[name]
            if current.names is None:
                kept.append((name, current))
                continue
            sub_descr = _drop_descr(current, drop_names)
            if sub_descr:
                kept.append((name, sub_descr))
        return kept

    newdtype = _drop_descr(base.dtype, drop_names)
    output = recursive_fill_fields(base, np.empty(base.shape, dtype=newdtype))
    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
def _keep_fields(base, keep_names, usemask=True, asrecarray=False):
    """
    Return a new array keeping only the fields in `keep_names`,
    and preserving the order of those fields.

    Parameters
    ----------
    base : array
        Input array
    keep_names : string or sequence
        String or sequence of strings corresponding to the names of the
        fields to keep. Order of the names will be preserved.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Forwarded to `_fix_output` to request a record-array view of the
        result. The default is False.
    """
    kept_spec = [(name, base.dtype[name]) for name in keep_names]
    result = recursive_fill_fields(base, np.empty(base.shape, dtype=kept_spec))
    return _fix_output(result, usemask=usemask, asrecarray=asrecarray)
def _rec_drop_fields_dispatcher(base, drop_names):
return (base,)
@array_function_dispatch(_rec_drop_fields_dispatcher)
def rec_drop_fields(base, drop_names):
    """
    Returns a new numpy.recarray with fields in `drop_names` dropped.

    Thin wrapper around `drop_fields` called with ``usemask=False`` and
    ``asrecarray=True``.
    """
    result = drop_fields(base, drop_names, usemask=False, asrecarray=True)
    return result
def _rename_fields_dispatcher(base, namemapper):
return (base,)
@array_function_dispatch(_rename_fields_dispatcher)
def rename_fields(base, namemapper):
    """
    Rename the fields from a flexible-datatype ndarray or recarray.

    Nested fields are supported.

    Parameters
    ----------
    base : ndarray
        Input array whose fields must be modified.
    namemapper : dictionary
        Dictionary mapping old field names to their new version.

    Returns
    -------
    renamed : ndarray
        A view of `base` under the renamed dtype.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> a = np.array([(1, (2, [3.0, 30.])), (4, (5, [6.0, 60.]))],
    ...   dtype=[('a', int),('b', [('ba', float), ('bb', (float, 2))])])
    >>> rfn.rename_fields(a, {'a':'A', 'bb':'BB'})
    array([(1, (2., [ 3., 30.])), (4, (5., [ 6., 60.]))],
          dtype=[('A', '<i8'), ('b', [('ba', '<f8'), ('BB', '<f8', (2,))])])

    """
    def _recursive_rename_fields(ndtype, namemapper):
        # Rebuild the dtype description, swapping in the mapped names at
        # every nesting level; unmapped names are kept as they are.
        renamed = []
        for name in ndtype.names:
            newname = namemapper.get(name, name)
            current = ndtype[name]
            if current.names is None:
                renamed.append((newname, current))
            else:
                renamed.append(
                    (newname, _recursive_rename_fields(current, namemapper))
                )
        return renamed

    return base.view(_recursive_rename_fields(base.dtype, namemapper))
def _append_fields_dispatcher(base, names, data, dtypes=None,
fill_value=None, usemask=None, asrecarray=None):
yield base
yield from data
@array_function_dispatch(_append_fields_dispatcher)
def append_fields(base, names, data, dtypes=None,
                  fill_value=-1, usemask=True, asrecarray=False):
    """
    Add new fields to an existing array.

    The names of the fields are given with the `names` arguments,
    the corresponding values with the `data` arguments.
    If a single field is appended, `names`, `data` and `dtypes` do not have
    to be lists but just values.

    Parameters
    ----------
    base : array
        Input array to extend.
    names : string, sequence
        String or sequence of strings corresponding to the names
        of the new fields.
    data : array or sequence of arrays
        Array or sequence of arrays storing the fields to add to the base.
    dtypes : sequence of datatypes, optional
        Datatype or sequence of datatypes.
        If None, the datatypes are estimated from the `data`.
    fill_value : {float}, optional
        Filling value used to pad missing data on the shorter arrays.
    usemask : {False, True}, optional
        Whether to return a masked array or not.
    asrecarray : {False, True}, optional
        Whether to return a recarray (MaskedRecords) or not.

    Raises
    ------
    ValueError
        If `names` is a tuple or list whose length differs from that of
        `data`, or if the number of `dtypes` is neither one nor the number
        of `data` arrays.

    """
    # Validate the names; a single name is normalised to one-item lists.
    if isinstance(names, (tuple, list)):
        if len(names) != len(data):
            msg = "The number of arrays does not match the number of names"
            raise ValueError(msg)
    elif isinstance(names, str):
        names = [names, ]
        data = [data, ]

    # Turn every new column into a one-field structured array.
    if dtypes is None:
        # No dtype given: let NumPy infer it from the data.
        data = [np.array(a, copy=None, subok=True) for a in data]
        data = [a.view([(name, a.dtype)]) for (name, a) in zip(names, data)]
    else:
        if not isinstance(dtypes, (tuple, list)):
            dtypes = [dtypes, ]
        if len(data) != len(dtypes):
            if len(dtypes) == 1:
                # A single dtype is shared by all the new columns.
                dtypes = dtypes * len(data)
            else:
                msg = "The dtypes argument must be None, a dtype, or a list."
                raise ValueError(msg)
        data = [np.array(a, copy=None, subok=True, dtype=d).view([(n, d)])
                for (a, n, d) in zip(data, names, dtypes)]

    # Normalise the base, then merge the new columns into a single array.
    base = merge_arrays(base, usemask=usemask, fill_value=fill_value)
    if len(data) > 1:
        data = merge_arrays(data, flatten=True, usemask=usemask,
                            fill_value=fill_value)
    else:
        data = data.pop()

    # Fully masked output, long enough for the longer of the two inputs,
    # whose dtype is the concatenation of both field lists.
    output = ma.masked_all(
        max(len(base), len(data)),
        dtype=_get_fieldspec(base.dtype) + _get_fieldspec(data.dtype))
    output = recursive_fill_fields(base, output)
    output = recursive_fill_fields(data, output)

    # Convert to the array type requested by `usemask` / `asrecarray`.
    return _fix_output(output, usemask=usemask, asrecarray=asrecarray)
# 定义一个生成器函数,用于分发基本参数和数据,实现迭代功能
def _rec_append_fields_dispatcher(base, names, data, dtypes=None):
yield base
# 使用生成器将数据添加到生成器函数中
yield from data
@array_function_dispatch(_rec_append_fields_dispatcher)
def rec_append_fields(base, names, data, dtypes=None):
    """
    Add new fields to an existing array.

    The names of the fields are given with the `names` arguments,
    the corresponding values with the `data` arguments.
    If a single field is appended, `names`, `data` and `dtypes` do not have
    to be lists but just values.

    Parameters
    ----------
    base : array
        Input array to extend.
    names : string, sequence
        String or sequence of strings corresponding to the names
        of the new fields.
    data : array or sequence of arrays
        Array or sequence of arrays storing the fields to add to the base.
    dtypes : sequence of datatypes, optional
        Datatype or sequence of datatypes.
        If None, the datatypes are estimated from the `data`.

    Returns
    -------
    appended_array : np.recarray
        The extended array, returned as a record array.

    See Also
    --------
    append_fields

    """
    # Same as `append_fields`, with an unmasked record array as output.
    return append_fields(base, names, data=data, dtypes=dtypes,
                         usemask=False, asrecarray=True)
# 定义一个分发器函数,用于分派_repack_fields_dispatcher函数
def _repack_fields_dispatcher(a, align=None, recurse=None):
# 返回传入的参数a,作为迭代结果
return (a,)
@array_function_dispatch(_repack_fields_dispatcher)
def repack_fields(a, align=False, recurse=False):
    """
    Re-pack the fields of a structured array or dtype in memory.

    The memory layout of structured datatypes allows fields at arbitrary
    byte offsets. This means the fields can be separated by padding bytes,
    their offsets can be non-monotonically increasing, and they can overlap.

    This method removes any overlaps and reorders the fields in memory so they
    have increasing byte offsets, and adds or removes padding bytes depending
    on the `align` option, which behaves like the `align` option to
    `numpy.dtype`.

    If `align=False`, this method produces a "packed" memory layout in which
    each field starts at the byte the previous field ended, and any padding
    bytes are removed.

    If `align=True`, this methods produces an "aligned" memory layout in which
    each field's offset is a multiple of its alignment, and the total itemsize
    is a multiple of the largest alignment, by adding padding bytes as needed.

    Parameters
    ----------
    a : ndarray or dtype
       array or dtype for which to repack the fields.
    align : boolean
       If true, use an "aligned" memory layout, otherwise use a "packed" layout.
    recurse : boolean
       If True, also repack nested structures.

    Returns
    -------
    repacked : ndarray or dtype
       Copy of `a` with fields repacked, or `a` itself if no repacking was
       needed.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> def print_offsets(d):
    ...     print("offsets:", [d.fields[name][1] for name in d.names])
    ...     print("itemsize:", d.itemsize)
    ...
    >>> dt = np.dtype('u1, <i8, <f8', align=True)
    >>> print_offsets(dt)
    offsets: [0, 8, 16]
    itemsize: 24
    >>> packed_dt = rfn.repack_fields(dt)
    >>> print_offsets(packed_dt)
    offsets: [0, 1, 9]
    itemsize: 17

    """
    # Arrays: repack the dtype, then cast (without copying when the dtype
    # turns out to be unchanged).
    if not isinstance(a, np.dtype):
        dt = repack_fields(a.dtype, align=align, recurse=recurse)
        return a.astype(dt, copy=False)

    # Unstructured dtypes have nothing to repack.
    if a.names is None:
        return a

    fieldinfo = []
    for name in a.names:
        tup = a.fields[name]
        if recurse:
            fmt = repack_fields(tup[0], align=align, recurse=True)
        else:
            fmt = tup[0]

        # A three-item field entry carries a title: keep it as (title, name).
        if len(tup) == 3:
            name = (tup[2], name)

        fieldinfo.append((name, fmt))

    # Rebuilding from names and formats only lets NumPy assign fresh offsets.
    dt = np.dtype(fieldinfo, align=align)
    return np.dtype((a.type, dt))
# Create a NumPy dtype `dt` from the format string 'u1, <i8, <f8' (unsigned
# byte, little-endian 64-bit integer, little-endian double), requesting an
# aligned layout.
# NOTE(review): this statement and the bare `dt` below run at import time;
# they look like a docstring example that leaked into the module - confirm.
dt = np.dtype('u1, <i8, <f8', align=True)
# Bare expression: evaluates `dt` and discards the result.
dt
# Expected repr, truncated in this listing:
# dtype({'names': ['f0', 'f1', 'f2'], 'formats': ['u1', '<i8', '<f8'], \
def print_offsets(dt):
    """
    Print the byte offset of every field of `dt`, then its itemsize.

    Offsets are listed once per field, in ``dt.names`` order.  Iterating
    over ``dt.fields`` instead would list a titled field twice, because
    the mapping holds an entry for the title as well as for the name.

    Parameters
    ----------
    dt : np.dtype
        Datatype to describe.  An unstructured dtype has no fields and
        prints an empty offset list.
    """
    names = dt.names or ()
    offsets = [dt.fields[name][1] for name in names]
    print(f"offsets: {offsets}")
    print(f"itemsize: {dt.itemsize}")
def repack_fields(a, align=False, recurse=False):
    """
    Re-pack the fields of a structured array or dtype in memory.

    Parameters
    ----------
    a : ndarray or dtype
        Array or dtype for which to repack the fields.
    align : bool, optional
        If true, build an "aligned" layout (as with ``np.dtype(...,
        align=True)``); otherwise build a "packed" layout without padding.
        The default is False.
    recurse : bool, optional
        If true, also repack nested structures.  The default is False.

    Returns
    -------
    repacked : ndarray or dtype
        For a dtype: the repacked dtype, or `a` itself when it has no
        fields.  For an array: ``a.astype(repacked_dtype, copy=False)``.
    """
    # Arrays: repack the dtype, then cast (without copying when the dtype
    # turns out to be unchanged).
    if not isinstance(a, np.dtype):
        dt = repack_fields(a.dtype, align=align, recurse=recurse)
        return a.astype(dt, copy=False)

    # Unstructured dtypes have nothing to repack.
    if a.names is None:
        return a

    fieldinfo = []
    for name in a.names:
        # Field entry: (dtype, offset) or (dtype, offset, title).
        tup = a.fields[name]
        fmt = repack_fields(tup[0], align=align, recurse=True) if recurse else tup[0]

        # Keep a title, if any, by using the (title, name) form.
        if len(tup) == 3:
            name = (tup[2], name)

        fieldinfo.append((name, fmt))

    # Rebuilding from names and formats only lets NumPy assign fresh offsets.
    dt = np.dtype(fieldinfo, align=align)
    return np.dtype((a.type, dt))
# 获取数据类型的所有标量字段的列表,包括嵌套字段,按照从左到右的顺序排列
def _get_fields_and_offsets(dt, offset=0):
"""
返回数据类型 "dt" 中所有标量字段的列表,包括嵌套字段,按照从左到右的顺序排列
Parameters:
dt : np.dtype
NumPy 数据类型对象,描述了字段的布局和大小
offset : int, optional
偏移量,用于计算字段的绝对偏移,默认为 0
Returns:
list
包含 (dtype, count, offset) 元组的列表,描述了所有标量字段
"""
# 计算元素数和子数组中的元素数,返回基本 dtype 和元素数
def count_elem(dt):
count = 1
while dt.shape != ():
for size in dt.shape:
count *= size
dt = dt.base
return dt, count
# 存储字段列表的列表
fields = []
# 遍历每个字段名
for name in dt.names:
# 获取字段元组 (dtype, offset)
field = dt.fields[name]
f_dt, f_offset = field[0], field[1]
# 计算字段的元素数和基本 dtype
f_dt, n = count_elem(f_dt)
# 如果字段没有字段名,则直接添加到 fields 列表中
if f_dt.names is None:
fields.append((np.dtype((f_dt, (n,))), n, f_offset + offset))
else:
# 递归调用 _get_fields_and_offsets 处理子字段
subfields = _get_fields_and_offsets(f_dt, f_offset + offset)
size = f_dt.itemsize
# 扩展 fields 列表,处理子数组的情况
for i in range(n):
if i == 0:
fields.extend(subfields)
else:
fields.extend([(d, c, o + i*size) for d, c, o in subfields])
return fields
# 计算字段之间的公共步幅,如果步幅不是常数则返回 None
def _common_stride(offsets, counts, itemsize):
"""
返回字段之间的步幅,如果步幅不是常数则返回 None。counts 中的值指定子数组的长度,
子数组被视为许多连续字段,始终为正步幅。
Parameters:
offsets : list
字段的偏移量列表
counts : list
子数组的长度列表
itemsize : int
数据类型对象的字节大小
Returns:
int or None
字段之间的步幅,如果步幅不是常数则返回 None
"""
if len(offsets) <= 1:
return itemsize
# 检查是否存在负步幅
negative = offsets[1] < offsets[0]
if negative:
# 反转列表,使得偏移量升序排列
it = zip(reversed(offsets), reversed(counts))
else:
it = zip(offsets, counts)
prev_offset = None
stride = None
# 遍历迭代器中的偏移量和计数
for offset, count in it:
# 如果计数不为1,表示子数组总是 C 连续的
if count != 1:
# 如果需要负步长,则返回 None,因为子数组不可能有负步长
if negative:
return None
# 如果步长未指定,则设为元素大小
if stride is None:
stride = itemsize
# 如果步长与元素大小不同,则返回 None
if stride != itemsize:
return None
# 计算子数组的结束偏移量
end_offset = offset + (count - 1) * itemsize
else:
# 如果计数为1,直接将结束偏移量设为当前偏移量
end_offset = offset
# 如果存在前一个偏移量,则计算新的步长
if prev_offset is not None:
new_stride = offset - prev_offset
# 如果步长未指定,则设为新计算的步长
if stride is None:
stride = new_stride
# 如果当前步长与新计算的步长不同,则返回 None
if stride != new_stride:
return None
# 更新前一个偏移量为当前的结束偏移量
prev_offset = end_offset
# 如果需要负步长,则返回负的当前步长
if negative:
return -stride
# 否则返回当前步长
return stride
# 定义一个私有函数 _structured_to_unstructured_dispatcher,返回元组 (arr,)
def _structured_to_unstructured_dispatcher(arr, dtype=None, copy=None,
casting=None):
return (arr,)
@array_function_dispatch(_structured_to_unstructured_dispatcher)
def structured_to_unstructured(arr, dtype=None, copy=False, casting='unsafe'):
    """
    Converts an n-D structured array into an (n+1)-D unstructured array.

    The new array will have a new last dimension equal in size to the
    number of field-elements of the input array. If not supplied, the output
    datatype is determined from the numpy type promotion rules applied to all
    the field datatypes.

    Nested fields, as well as each element of any subarray fields, all count
    as a single field-elements.

    Parameters
    ----------
    arr : ndarray
       Structured array or dtype to convert. Cannot contain object datatype.
    dtype : dtype, optional
       The dtype of the output unstructured array.
    copy : bool, optional
        If true, always return a copy. If false, a view is returned if
        possible, such as when the `dtype` and strides of the fields are
        suitable and the array subtype is one of `numpy.ndarray`,
        `numpy.recarray` or `numpy.memmap`.

        .. versionchanged:: 1.25.0
            A view can now be returned if the fields are separated by a
            uniform stride.

    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        See casting argument of `numpy.ndarray.astype`. Controls what kind of
        data casting may occur.

    Returns
    -------
    unstructured : ndarray
       Unstructured array with one more dimension.

    Raises
    ------
    ValueError
        If `arr` is not structured, or has no fields and `dtype` is None.
    NotImplementedError
        If `arr` has no fields and `dtype` is given.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> a = np.zeros(4, dtype=[('a', 'i4'), ('b', 'f4,u2'), ('c', 'f4', 2)])
    >>> a
    array([(0, (0., 0), [0., 0.]), (0, (0., 0), [0., 0.]),
           (0, (0., 0), [0., 0.]), (0, (0., 0), [0., 0.])],
          dtype=[('a', '<i4'), ('b', [('f0', '<f4'), ('f1', '<u2')]), ('c', '<f4', (2,))])
    >>> rfn.structured_to_unstructured(a)
    array([[0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.],
           [0., 0., 0., 0., 0.]])

    >>> b = np.array([(1, 2, 5), (4, 5, 7), (7, 8 ,11), (10, 11, 12)],
    ...              dtype=[('x', 'i4'), ('y', 'f4'), ('z', 'f8')])
    >>> np.mean(rfn.structured_to_unstructured(b[['x', 'z']]), axis=-1)
    array([ 3. ,  5.5,  9. , 11. ])

    """
    if arr.dtype.names is None:
        raise ValueError('arr must be a structured array')

    fields = _get_fields_and_offsets(arr.dtype)
    n_fields = len(fields)
    if n_fields == 0:
        if dtype is None:
            raise ValueError("arr has no fields. Unable to guess dtype")
        # too many bugs elsewhere for this to work now
        raise NotImplementedError("arr with no fields is not supported")

    dts, counts, offsets = zip(*fields)
    names = ['f{}'.format(n) for n in range(n_fields)]

    if dtype is not None:
        out_dtype = np.dtype(dtype)
    else:
        # Promote the base dtypes of all the fields.
        out_dtype = np.result_type(*[dt.base for dt in dts])

    # Use a series of views and casts to convert to an unstructured array:

    # first view using flattened fields (doesn't work for object arrays)
    # Note: dts may include a shape for subarrays
    flattened_fields = np.dtype({'names': names,
                                 'formats': dts,
                                 'offsets': offsets,
                                 'itemsize': arr.dtype.itemsize})
    arr = arr.view(flattened_fields)

    # Only a few array types may be converted by manipulating the strides,
    # because it is known not to work with, for example, np.matrix nor
    # np.ma.MaskedArray.
    can_view = type(arr) in (np.ndarray, np.recarray, np.memmap)
    if (not copy) and can_view and all(dt.base == out_dtype for dt in dts):
        # All elements have the right dtype already; if they share a common
        # stride, a view can be returned.
        common_stride = _common_stride(offsets, counts, out_dtype.itemsize)
        if common_stride is not None:
            wrap = arr.__array_wrap__

            new_shape = arr.shape + (sum(counts), out_dtype.itemsize)
            new_strides = arr.strides + (abs(common_stride), 1)

            as_bytes = arr[..., np.newaxis].view(np.uint8)  # view as bytes
            as_bytes = as_bytes[..., min(offsets):]  # drop leading unused data
            arr = np.lib.stride_tricks.as_strided(as_bytes,
                                                  new_shape,
                                                  new_strides,
                                                  subok=True)

            # cast and drop the last dimension again
            arr = arr.view(out_dtype)[..., 0]

            if common_stride < 0:
                arr = arr[..., ::-1]  # reverse, if the stride was negative
            if type(arr) is not type(wrap.__self__):
                # Some types (e.g. recarray) turn into an ndarray along the
                # way, so the result is wrapped again in order to match the
                # behavior with copy=True.
                arr = wrap(arr)
            return arr

    # next cast to a packed format with all fields converted to new dtype
    packed_fields = np.dtype({'names': names,
                              'formats': [(out_dtype, dt.shape) for dt in dts]})
    arr = arr.astype(packed_fields, copy=copy, casting=casting)

    # finally it is safe to view the packed fields as the unstructured type
    return arr.view((out_dtype, (sum(counts),)))
# 定义一个分派器函数,返回传入的数组作为元组的第一个元素
def _unstructured_to_structured_dispatcher(arr, dtype=None, names=None,
align=None, copy=None, casting=None):
return (arr,)
@array_function_dispatch(_unstructured_to_structured_dispatcher)
def unstructured_to_structured(arr, dtype=None, names=None, align=False,
                               copy=False, casting='unsafe'):
    """
    Converts an n-D unstructured array into an (n-1)-D structured array.

    The last dimension of the input array is converted into a structure, with
    number of field-elements equal to the size of the last dimension of the
    input array. By default all output fields have the input array's dtype, but
    an output structured dtype with an equal number of fields-elements can be
    supplied instead.

    Nested fields, as well as each element of any subarray fields, all count
    towards the number of field-elements.

    Parameters
    ----------
    arr : ndarray
       Unstructured array or dtype to convert.
    dtype : dtype, optional
       The structured dtype of the output array
    names : list of strings, optional
       If dtype is not supplied, this specifies the field names for the output
       dtype, in order. The field dtypes will be the same as the input array.
    align : boolean, optional
       Whether to create an aligned memory layout.
    copy : bool, optional
        See copy argument to `numpy.ndarray.astype`. If true, always return a
        copy. If false, and `dtype` requirements are satisfied, a view is
        returned.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        See casting argument of `numpy.ndarray.astype`. Controls what kind of
        data casting may occur.

    Returns
    -------
    structured : ndarray
       Structured array with fewer dimensions.

    Raises
    ------
    ValueError
        If `arr` is zero-dimensional, if both `dtype` and `names` are given,
        if the last dimension does not match the number of field-elements
        of `dtype`, or if `align` is true but `dtype` is not aligned.
    NotImplementedError
        If the last axis of `arr` has size 0.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> dt = np.dtype([('a', 'i4'), ('b', 'f4,u2'), ('c', 'f4', 2)])
    >>> a = np.arange(20).reshape((4,5))
    >>> a
    array([[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14],
           [15, 16, 17, 18, 19]])
    >>> rfn.unstructured_to_structured(a, dt)
    array([( 0, ( 1.,  2), [ 3.,  4.]), ( 5, ( 6.,  7), [ 8.,  9.]),
           (10, (11., 12), [13., 14.]), (15, (16., 17), [18., 19.])],
          dtype=[('a', '<i4'), ('b', [('f0', '<f4'), ('f1', '<u2')]), ('c', '<f4', (2,))])

    """
    if arr.shape == ():
        raise ValueError('arr must have at least one dimension')
    n_elem = arr.shape[-1]
    if n_elem == 0:
        # too many bugs elsewhere for this to work now
        raise NotImplementedError("last axis with size 0 is not supported")

    if dtype is not None:
        if names is not None:
            raise ValueError("don't supply both dtype and names")
        # `dtype` may be given as arguments of np.dtype: construct it.
        dtype = np.dtype(dtype)
        # Sanity check of the requested dtype.
        fields = _get_fields_and_offsets(dtype)
        if len(fields) == 0:
            dts, counts, offsets = [], [], []
        else:
            dts, counts, offsets = zip(*fields)

        if n_elem != sum(counts):
            raise ValueError('The length of the last dimension of arr must '
                             'be equal to the number of fields in dtype')
        out_dtype = dtype
        if align and not out_dtype.isalignedstruct:
            raise ValueError("align was True but dtype is not aligned")
    else:
        # No dtype given: every field gets the dtype of `arr`.
        if names is None:
            names = ['f{}'.format(n) for n in range(n_elem)]
        out_dtype = np.dtype([(n, arr.dtype) for n in names], align=align)
        fields = _get_fields_and_offsets(out_dtype)
        dts, counts, offsets = zip(*fields)

    # Generic names ('f0', 'f1', ...) for the flattened field-elements.
    names = ['f{}'.format(n) for n in range(len(fields))]

    # Use a series of views and casts to convert to a structured array:

    # first view as a packed structured array of one dtype
    packed_fields = np.dtype({'names': names,
                              'formats': [(arr.dtype, dt.shape) for dt in dts]})
    arr = np.ascontiguousarray(arr).view(packed_fields)

    # next cast to an unpacked but flattened format with varied dtypes
    flattened_fields = np.dtype({'names': names,
                                 'formats': dts,
                                 'offsets': offsets,
                                 'itemsize': out_dtype.itemsize})
    arr = arr.astype(flattened_fields, copy=copy, casting=casting)

    # finally view as the final nested dtype and remove the last axis
    return arr.view(out_dtype)[..., 0]
# 定义函数 _apply_along_fields_dispatcher,接收函数和数组作为参数,返回元组 (arr,)
def _apply_along_fields_dispatcher(func, arr):
return (arr,)
@array_function_dispatch(_apply_along_fields_dispatcher)
def apply_along_fields(func, arr):
    """
    Apply function 'func' as a reduction across fields of a structured array.

    This is similar to `numpy.apply_along_axis`, but treats the fields of a
    structured array as an extra axis. The fields are all first cast to a
    common type following the type-promotion rules from `numpy.result_type`
    applied to the field's dtypes.

    Parameters
    ----------
    func : function
       Function to apply on the "field" dimension. This function must
       support an `axis` argument, like `numpy.mean`, `numpy.sum`, etc.
    arr : ndarray
       Structured array for which to apply func.

    Returns
    -------
    out : ndarray
       Result of the reduction operation

    Raises
    ------
    ValueError
        If `arr` is not a structured array.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> b = np.array([(1, 2, 5), (4, 5, 7), (7, 8 ,11), (10, 11, 12)],
    ...              dtype=[('x', 'i4'), ('y', 'f4'), ('z', 'f8')])
    >>> rfn.apply_along_fields(np.mean, b)
    array([ 2.66666667,  5.33333333,  8.66666667, 11.        ])
    >>> rfn.apply_along_fields(np.mean, b[['x', 'z']])
    array([ 3. ,  5.5,  9. , 11. ])

    """
    if arr.dtype.names is None:
        raise ValueError('arr must be a structured array')

    # The fields become the last axis of the unstructured array, so `func`
    # reduces over that axis.  (Going through np.apply_along_axis would
    # avoid the `axis` requirement on `func`, but is very slow.)
    return func(structured_to_unstructured(arr), axis=-1)
def _assign_fields_by_name_dispatcher(dst, src, zero_unassigned=None):
    """Dispatcher paired with `assign_fields_by_name`.

    Reports ``dst`` and ``src`` as a 2-tuple; ``zero_unassigned`` is ignored.
    """
    return (dst, src)
@array_function_dispatch(_assign_fields_by_name_dispatcher)
def assign_fields_by_name(dst, src, zero_unassigned=True):
    """
    Copy values between structured arrays, matching fields by name.

    Normally in numpy >= 1.14, assigning one structured array to another
    copies fields "by position": the first field of the src goes to the
    first field of the dst, and so on, whatever the fields are called.
    This function copies "by field name" instead, so each field of the dst
    is filled from the identically named field of the src. Nested
    structures are handled recursively. This is how structure assignment
    worked in numpy >= 1.6 to <= 1.13.

    Parameters
    ----------
    dst : ndarray
    src : ndarray
        The source and destination arrays during assignment. `dst` is
        modified in place.
    zero_unassigned : bool, optional
        If True, fields in the dst for which there was no matching
        field in the src are filled with the value 0 (zero). This
        was the behavior of numpy <= 1.13. If False, those fields
        are not modified.

    Returns
    -------
    None

    """
    dst_names = dst.dtype.names
    if dst_names is None:
        # Unstructured destination: nothing to match, assign directly.
        dst[...] = src
        return

    for field in dst_names:
        if field in src.dtype.names:
            # Same name on both sides: recurse so nested fields are also
            # matched by name.
            assign_fields_by_name(dst[field], src[field], zero_unassigned)
        elif zero_unassigned:
            # No counterpart in the source.
            dst[field] = 0
def _require_fields_dispatcher(array, required_dtype):
    """Dispatcher paired with `require_fields`.

    Only ``array`` is reported, wrapped in a one-element tuple;
    ``required_dtype`` is ignored.
    """
    relevant_args = (array,)
    return relevant_args
@array_function_dispatch(_require_fields_dispatcher)
def require_fields(array, required_dtype):
    """
    Cast a structured array to a new dtype, assigning fields by name.

    Values are copied from the old array to the new one by field name, so
    each field of the output holds the value of the identically named
    field of the input. The result is a new ndarray that contains only the
    fields "required" by `required_dtype`.

    A field named in `required_dtype` that is absent from the input array
    is created and set to 0 in the output array.

    Parameters
    ----------
    array : ndarray
        Array to cast.
    required_dtype : dtype
        Datatype for the output array.

    Returns
    -------
    out : ndarray
        Array with the new dtype, with field values copied from the fields
        in the input array with the same name.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> a = np.ones(4, dtype=[('a', 'i4'), ('b', 'f8'), ('c', 'u1')])
    >>> rfn.require_fields(a, [('b', 'f4'), ('c', 'u1')])
    array([(1., 1), (1., 1), (1., 1), (1., 1)],
      dtype=[('b', '<f4'), ('c', 'u1')])
    >>> rfn.require_fields(a, [('b', 'f4'), ('newf', 'u1')])
    array([(1., 0), (1., 0), (1., 0), (1., 0)],
      dtype=[('b', '<f4'), ('newf', 'u1')])

    """
    # Uninitialised buffer of the target dtype, same shape as the input.
    result = np.empty(array.shape, dtype=required_dtype)
    # Fill it field by field, matching on names.
    assign_fields_by_name(result, array)
    return result
def _stack_arrays_dispatcher(arrays, defaults=None, usemask=None,
                             asrecarray=None, autoconvert=None):
    """Dispatcher paired with `stack_arrays`.

    Hands back ``arrays`` itself (it is not wrapped in a tuple); every
    other argument is ignored.
    """
    return arrays
@array_function_dispatch(_stack_arrays_dispatcher)
def stack_arrays(arrays, defaults=None, usemask=True, asrecarray=False,
                 autoconvert=False):
    """
    Superposes arrays fields by fields

    Parameters
    ----------
    arrays : array or sequence
        Sequence of input arrays.
    defaults : dictionary, optional
        Dictionary mapping field names to the corresponding default values.
    usemask : {True, False}, optional
        Whether to return a MaskedArray (or MaskedRecords is
        `asrecarray==True`) or a ndarray.
    asrecarray : {False, True}, optional
        Whether to return a recarray (or MaskedRecords if `usemask==True`)
        or just a flexible-type ndarray.
    autoconvert : {False, True}, optional
        Whether automatically cast the type of the field to the maximum.

    Returns
    -------
    out : ndarray or MaskedArray
        `arrays` itself when it is already an ndarray, its single element
        when it is a sequence of length one, and otherwise the stacked
        array (see `usemask` and `asrecarray` for its type).

    Raises
    ------
    TypeError
        If two inputs share a field name with different dtypes and
        `autoconvert` is False.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> x = np.array([1, 2,])
    >>> rfn.stack_arrays(x) is x
    True
    >>> z = np.array([('A', 1), ('B', 2)], dtype=[('A', '|S3'), ('B', float)])
    >>> zz = np.array([('a', 10., 100.), ('b', 20., 200.), ('c', 30., 300.)],
    ...   dtype=[('A', '|S3'), ('B', np.double), ('C', np.double)])
    >>> test = rfn.stack_arrays((z,zz))
    >>> test
    masked_array(data=[(b'A', 1.0, --), (b'B', 2.0, --), (b'a', 10.0, 100.0),
                       (b'b', 20.0, 200.0), (b'c', 30.0, 300.0)],
                 mask=[(False, False,  True), (False, False,  True),
                       (False, False, False), (False, False, False),
                       (False, False, False)],
           fill_value=(b'N/A', 1e+20, 1e+20),
                dtype=[('A', 'S3'), ('B', '<f8'), ('C', '<f8')])

    """
    # Nothing to stack: a lone ndarray or a one-element sequence is handed
    # back untouched.
    if isinstance(arrays, ndarray):
        return arrays
    elif len(arrays) == 1:
        return arrays[0]

    # Work on flattened inputs and record their lengths, dtypes and names.
    seqarrays = [np.asanyarray(a).ravel() for a in arrays]
    nrecords = [len(a) for a in seqarrays]
    ndtype = [a.dtype for a in seqarrays]
    fldnames = [d.names for d in ndtype]

    # Build the output field list, starting from the first input and
    # appending any field the later inputs introduce.
    newdescr = _get_fieldspec(ndtype[0])
    names = [n for n, _ in newdescr]
    for dtype_n in ndtype[1:]:
        for fname, fdtype in _get_fieldspec(dtype_n):
            if fname not in names:
                newdescr.append((fname, fdtype))
                names.append(fname)
            else:
                nameidx = names.index(fname)
                _, cdtype = newdescr[nameidx]
                if autoconvert:
                    # Keep the larger of the two dtypes for a shared name.
                    newdescr[nameidx] = (fname, max(fdtype, cdtype))
                elif fdtype != cdtype:
                    raise TypeError("Incompatible type '%s' <> '%s'" %
                                    (cdtype, fdtype))

    if len(newdescr) == 1:
        # Only one field: a plain concatenation is enough.
        output = ma.concatenate(seqarrays)
    else:
        # Start fully masked so that fields an input does not provide stay
        # masked in its rows.
        output = ma.masked_all((np.sum(nrecords),), newdescr)
        # offset[k]:offset[k + 1] is the row range of the k-th input.
        offset = np.cumsum(np.r_[0, nrecords])
        seen = []
        for (a, fields, i, j) in zip(seqarrays, fldnames,
                                     offset[:-1], offset[1:]):
            if fields is None:
                # Unstructured input: stored under a generated 'f<N>' name.
                output['f%i' % len(seen)][i:j] = a
            else:
                for name in fields:
                    output[name][i:j] = a[name]
                    if name not in seen:
                        seen.append(name)

    return _fix_output(_fix_defaults(output, defaults),
                       usemask=usemask, asrecarray=asrecarray)
def _find_duplicates_dispatcher(
        a, key=None, ignoremask=None, return_index=None):
    """Dispatcher paired with `find_duplicates`.

    Only ``a`` is reported, wrapped in a one-element tuple; the remaining
    arguments are ignored.
    """
    relevant_args = (a,)
    return relevant_args
@array_function_dispatch(_find_duplicates_dispatcher)
def find_duplicates(a, key=None, ignoremask=True, return_index=False):
    """
    Find the duplicates in a structured array along a given key

    Parameters
    ----------
    a : array-like
        Input array
    key : {string, None}, optional
        Name of the fields along which to check the duplicates.
        If None, the search is performed by records
    ignoremask : {True, False}, optional
        Whether masked data should be discarded or considered as duplicates.
    return_index : {False, True}, optional
        Whether to return the indices of the duplicated values.

    Returns
    -------
    duplicates : array
        The duplicated entries of the flattened input, in sorted order.
    index : ndarray
        Indices of those entries in the flattened input. Only returned
        when `return_index` is True.

    Examples
    --------
    >>> from numpy.lib import recfunctions as rfn
    >>> ndtype = [('a', int)]
    >>> a = np.ma.array([1, 1, 1, 2, 2, 3, 3],
    ...         mask=[0, 0, 1, 0, 0, 0, 1]).view(ndtype)
    >>> rfn.find_duplicates(a, ignoremask=True, return_index=True)
    (masked_array(data=[(1,), (1,), (2,), (2,)],
                 mask=[(False,), (False,), (False,), (False,)],
           fill_value=(999999,),
                dtype=[('a', '<i8')]), array([0, 1, 3, 4]))
    """
    a = np.asanyarray(a).ravel()
    # `structure[key]` lists the field names to index through before `key`
    # itself can be reached.
    structure = get_fieldstructure(a.dtype)

    # Select the data to sort on: the whole records, or the `key` field.
    sortkey = a
    if key:
        for parent in structure[key]:
            sortkey = sortkey[parent]
        sortkey = sortkey[key]

    order = sortkey.argsort()
    ordered = sortkey[order]
    values = ordered.filled()

    # After sorting, equal entries are adjacent: compare each to the next.
    same_as_next = (values[:-1] == values[1:])
    if ignoremask:
        # A masked right-hand neighbour never counts as a match.
        same_as_next[ordered.recordmask[1:]] = False

    # Shift by one so the flag sits on the right element of each pair ...
    is_dup = np.concatenate(([False], same_as_next))
    # ... then also flag the left element, otherwise it would be missed.
    is_dup[:-1] = is_dup[:-1] + is_dup[1:]

    duplicates = a[order][is_dup]
    if return_index:
        return (duplicates, order[is_dup])
    return duplicates
def _join_by_dispatcher(
        key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
        defaults=None, usemask=None, asrecarray=None):
    """Dispatcher paired with `join_by`.

    Reports the two arrays being joined, ``(r1, r2)``; all other arguments
    are ignored.
    """
    relevant_args = (r1, r2)
    return relevant_args
@array_function_dispatch(_join_by_dispatcher)
def join_by(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
            defaults=None, usemask=True, asrecarray=False):
    """
    Join arrays `r1` and `r2` on key `key`.

    The key should be either a string or a sequence of string corresponding
    to the fields used to join the array.  An exception is raised if the
    `key` field cannot be found in the two input arrays.  Neither `r1` nor
    `r2` should have any duplicates along `key`: the presence of duplicates
    will make the output quite unreliable. Note that duplicates are not
    looked for by the algorithm.

    Parameters
    ----------
    key : {string, sequence}
        A string or a sequence of strings naming the fields used for the
        comparison.
    r1, r2 : arrays
        Structured arrays. Both are flattened before joining.
    jointype : {'inner', 'outer', 'leftouter'}, optional
        If 'inner', returns the elements common to both r1 and r2.
        If 'outer', returns the common elements as well as the elements of
        r1 not in r2 and the elements of r2 not in r1.
        If 'leftouter', returns the common elements and the elements of r1
        not in r2.
    r1postfix : string, optional
        String appended to the names of the fields of r1 that are also
        present in r2 but are not part of the key.
    r2postfix : string, optional
        String appended to the names of the fields of r2 that are also
        present in r1 but are not part of the key.
    defaults : {dictionary}, optional
        Dictionary mapping field names to the corresponding default values.
    usemask : {True, False}, optional
        Whether to return a MaskedArray (or MaskedRecords is
        `asrecarray==True`) or a ndarray.
    asrecarray : {False, True}, optional
        Whether to return a recarray (or MaskedRecords if `usemask==True`)
        or just a flexible-type ndarray.

    Raises
    ------
    ValueError
        If `jointype` is not one of the supported values, if `key` contains
        a duplicate name, if a key field is missing from `r1` or `r2`, or
        if `r1` and `r2` share non-key field names while both postfixes
        are empty.

    Notes
    -----
    * The output is sorted along the key.
    * A temporary array is formed by dropping the fields not in the key for
      the two arrays and concatenating the result. This array is then
      sorted, and the common entries selected. The output is constructed by
      filling the fields with the selected entries.

    """
    # Check jointype
    if jointype not in ('inner', 'outer', 'leftouter'):
        raise ValueError(
            "The 'jointype' argument should be in 'inner', "
            "'outer' or 'leftouter' (got '%s' instead)" % jointype
        )
    # If we have a single key, put it in a tuple
    if isinstance(key, str):
        key = (key,)

    # Check the keys
    if len(set(key)) != len(key):
        dup = next(x for n, x in enumerate(key) if x in key[n+1:])
        raise ValueError("duplicate join key %r" % dup)
    for name in key:
        if name not in r1.dtype.names:
            raise ValueError('r1 does not have key field %r' % name)
        if name not in r2.dtype.names:
            raise ValueError('r2 does not have key field %r' % name)

    # Make sure we work with ravelled arrays
    r1 = r1.ravel()
    r2 = r2.ravel()
    # Rows of the concatenated key array with index < nb1 come from r1,
    # the others from r2.
    nb1 = len(r1)
    (r1names, r2names) = (r1.dtype.names, r2.dtype.names)

    # Check the names for collision
    collisions = (set(r1names) & set(r2names)) - set(key)
    if collisions and not (r1postfix or r2postfix):
        msg = "r1 and r2 contain common names, r1postfix and r2postfix "
        msg += "can't both be empty"
        raise ValueError(msg)

    # Make temporary arrays of just the keys
    #  (use order of keys in `r1` for back-compatibility)
    key1 = [n for n in r1names if n in key]
    r1k = _keep_fields(r1, key1)
    r2k = _keep_fields(r2, key1)

    # Concatenate the two arrays for comparison
    aux = ma.concatenate((r1k, r2k))
    idx_sort = aux.argsort(order=key)
    aux = aux[idx_sort]

    # Get the common keys: after sorting, a key present in both inputs
    # shows up as two adjacent equal entries; flag both of them.
    flag_in = ma.concatenate(([False], aux[1:] == aux[:-1]))
    flag_in[:-1] = flag_in[1:] + flag_in[:-1]
    idx_in = idx_sort[flag_in]
    idx_1 = idx_in[(idx_in < nb1)]
    idx_2 = idx_in[(idx_in >= nb1)] - nb1
    (r1cmn, r2cmn) = (len(idx_1), len(idx_2))
    # r1spc / r2spc: number of entries specific to r1 / r2 that are kept.
    if jointype == 'inner':
        (r1spc, r2spc) = (0, 0)
    elif jointype == 'outer':
        idx_out = idx_sort[~flag_in]
        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
        idx_2 = np.concatenate((idx_2, idx_out[(idx_out >= nb1)] - nb1))
        (r1spc, r2spc) = (len(idx_1) - r1cmn, len(idx_2) - r2cmn)
    elif jointype == 'leftouter':
        idx_out = idx_sort[~flag_in]
        idx_1 = np.concatenate((idx_1, idx_out[(idx_out < nb1)]))
        (r1spc, r2spc) = (len(idx_1) - r1cmn, 0)
    # Select the entries from each input
    (s1, s2) = (r1[idx_1], r2[idx_2])

    # Build the new description of the output array .......
    # Start with the key fields
    ndtype = _get_fieldspec(r1k.dtype)

    # Add the fields from r1
    for fname, fdtype in _get_fieldspec(r1.dtype):
        if fname not in key:
            ndtype.append((fname, fdtype))

    # Add the fields from r2
    for fname, fdtype in _get_fieldspec(r2.dtype):
        # Have we seen the current name already ?
        # The list is rebuilt every time because `ndtype` changes below.
        names = [name for name, _ in ndtype]
        try:
            nameidx = names.index(fname)
        except ValueError:
            # ... we haven't: just add the description to the current list
            ndtype.append((fname, fdtype))
        else:
            # collision
            _, cdtype = ndtype[nameidx]
            if fname in key:
                # The current field is part of the key: take the largest dtype
                ndtype[nameidx] = (fname, max(fdtype, cdtype))
            else:
                # The current field is not part of the key: add the suffixes,
                # and place the new field adjacent to the old one
                ndtype[nameidx:nameidx + 1] = [
                    (fname + r1postfix, cdtype),
                    (fname + r2postfix, fdtype)
                ]
    # Rebuild a dtype from the new fields
    ndtype = np.dtype(ndtype)
    # Find the largest nb of common fields :
    # r1cmn and r2cmn should be equal, but...
    cmn = max(r1cmn, r2cmn)
    # Construct a fully masked array; layout is
    # [common entries | r1-only entries | r2-only entries].
    output = ma.masked_all((cmn + r1spc + r2spc,), dtype=ndtype)
    names = output.dtype.names
    for f in r1names:
        selected = s1[f]
        if f not in names or (f in r2names and not r2postfix and f not in key):
            f += r1postfix
        current = output[f]
        current[:r1cmn] = selected[:r1cmn]
        if jointype in ('outer', 'leftouter'):
            current[cmn:cmn + r1spc] = selected[r1cmn:]
    for f in r2names:
        selected = s2[f]
        if f not in names or (f in r1names and not r1postfix and f not in key):
            f += r2postfix
        current = output[f]
        current[:r2cmn] = selected[:r2cmn]
        # Guard on r2spc: with r2spc == 0 the slice [-0:] would span the
        # whole array.
        if (jointype == 'outer') and r2spc:
            current[-r2spc:] = selected[r2cmn:]
    # Sort and finalize the output
    output.sort(order=key)
    kwargs = dict(usemask=usemask, asrecarray=asrecarray)
    return _fix_output(_fix_defaults(output, defaults), **kwargs)
def _rec_join_dispatcher(
        key, r1, r2, jointype=None, r1postfix=None, r2postfix=None,
        defaults=None):
    """Dispatcher paired with `rec_join`.

    Reports the two arrays being joined, ``(r1, r2)``; all other arguments
    are ignored.
    """
    relevant_args = (r1, r2)
    return relevant_args
@array_function_dispatch(_rec_join_dispatcher)
def rec_join(key, r1, r2, jointype='inner', r1postfix='1', r2postfix='2',
             defaults=None):
    """
    Join arrays `r1` and `r2` on keys.
    Alternative to join_by, that always returns a np.recarray.

    All arguments are forwarded to `join_by` unchanged, with
    ``usemask=False`` and ``asrecarray=True`` forced.

    See Also
    --------
    join_by : equivalent function
    """
    return join_by(key, r1, r2,
                   jointype=jointype,
                   r1postfix=r1postfix,
                   r2postfix=r2postfix,
                   defaults=defaults,
                   usemask=False,
                   asrecarray=True)