学习NumPy的roll函数

753 阅读2分钟

NumPy的滚动函数用于沿指定的轴滚动输入数组中的元素。滚动指的是移动给定数组中元素的位置的处理。

如果一个元素从第一个位置移到最后一个位置,它就会被移回第一个位置。

让我们来探索NumPy中的滚动函数。

函数语法

该函数的语法如下所示。

numpy.roll(a, shift, axis=None)

参数如图所示。

  1. a - 定义了输入数组。
  2. shift - 指数组中的元素被移位的位数。
  3. axis - 指的是指定元素被移位的轴。

函数的返回值

该函数返回一个数组,其中指定轴上的元素按shift参数指定的系数进行了移位。

注意:输出数组的形状与输入数组相同。

例子1

请看下面的示例代码。

import numpy as np
arr = np.array([[1,2,3], [4,5,6]])
print(f"original: {arr}")
print(f"shifted: {np.roll(arr, shift=1, axis=0)}")

上面的代码显示了如何使用roll函数将一个二维数组中的元素沿0轴移动1倍。

结果输出如图所示。

original: [[1 2 3]
[4 5 6]]
shifted: [[4 5 6]
[1 2 3]]

例2

考虑另一个例子,沿轴1执行同样的操作。

arr = np.array([[1,2,3], [4,5,6]])
print(f"original: {arr}")
print(f"shifted: {np.roll(arr, shift=1, axis=1)}")

在这种情况下,roll函数沿轴1进行移位操作并返回。

original: [[1 2 3]
[4 5 6]]
shifted: [[3 1 2]
[6 4 5]]

例3

下面的代码说明了如何使用roll函数将数组中的元素移到5位。

arr = np.array([[1,2,3], [4,5,6]])
print(f"original: {arr}")
print(f"shifted: {np.roll(arr, shift=5, axis=0)}")

这里,我们将移位参数设置为5,轴设置为0,得到的数组如图所示。

original: [[1 2 3]
[4 5 6]]
shifted: [[4 5 6]
[1 2 3]]

例5

你也可以把移位值指定为一个元组。在这种情况下,轴必须是一个相同大小的元组。

以下面的代码为例。

arr = np.arange(10).reshape(2,5)
print(f"original: {arr}")
print(f"shifted: {np.roll(arr, (2,1), axis=(1,0))}")

上面的代码应该返回。

original: [[0 1 2 3 4]
[5 6 7 8 9]]
shifted: [[8 9 5 6 7]
[3 4 0 1 2]]

结尾

在这篇文章中,我们讨论了NumPy的roll函数,它是什么,它的参数,以及返回值。我们还用各种例子演示了如何使用该函数。