## NumPy:核心数据结构详解

243 阅读4分钟

在机器学习和深度学习领域,NumPy 是一个不可或缺的工具。它为 Python 提供了强大的科学计算支持,尤其是其核心数据结构——多维数组(ndarray)。今天,我们就来深入了解一下 NumPy 的核心数据结构以及它的基本操作。

为什么学习 NumPy?

NumPy 是 Python 中用于科学计算的基础包,它提供了高效的多维数组对象以及针对数组的各种快速操作。无论是 PyTorch 还是 TensorFlow,这些主流的深度学习框架中的基本计算单元 Tensor,都与 NumPy 数组有着类似的计算逻辑。因此,掌握 NumPy 对于学习这些框架以及进行数据科学和科学计算都至关重要。

安装 NumPy

安装 NumPy 非常简单,可以通过以下命令进行安装:

bash复制

conda install numpy

或者使用 pip:

bash复制

pip install numpy

NumPy 数组

NumPy 数组是 NumPy 的核心,它是一个多维数组对象,称为 ndarray。与 Python 中的列表相比,NumPy 数组具有以下特点:

  1. 固定大小:NumPy 数组在创建时大小固定,不能动态改变。如果需要改变大小,需要创建一个新的数组。
  2. 数据类型一致:NumPy 数组中的所有元素必须具有相同的数据类型。
  3. 高效运算:NumPy 对数组运算进行了优化,速度更快,且占用内存更少。

创建数组

创建 NumPy 数组最简单的方法是将一个列表传入 np.array()np.asarray()。例如:

Python复制

import numpy as np

# 创建一维数组
arr_1_d = np.asarray([1])
print(arr_1_d)  # 输出:[1]

# 创建二维数组
arr_2_d = np.asarray([[1, 2], [3, 4]])
print(arr_2_d)  # 输出:
                # [[1 2]
                #  [3 4]]

数组的属性

NumPy 数组有多个重要的属性,包括维度、形状、大小和数据类型。

ndim

ndim 表示数组的维度数。例如:

Python复制

print(arr_1_d.ndim)  # 输出:1
print(arr_2_d.ndim)  # 输出:2

shape

shape 表示数组的形状,是一个整数元组。例如:

Python复制

print(arr_1_d.shape)  # 输出:(1,)
print(arr_2_d.shape)  # 输出:(2, 2)

size

size 表示数组中元素的总数,等于 shape 中各维度大小的乘积。例如:

Python复制

print(arr_2_d.size)  # 输出:4

dtype

dtype 表示数组中元素的数据类型。例如:

Python复制

print(arr_2_d.dtype)  # 输出:int64

数组的形状变换

可以使用 reshape() 方法改变数组的形状,但需要注意变换前后数组的元素总数必须一致。例如:

Python复制

arr_2_d = np.asarray([[1, 2], [3, 4]])
print(arr_2_d.reshape((4, 1)))  # 输出:
                                # [[1]
                                #  [2]
                                #  [3]
                                #  [4]]

其他创建数组的方法

NumPy 提供了多种方法来创建数组:

np.ones() 和 np.zeros()

np.ones() 创建全为 1 的数组,np.zeros() 创建全为 0 的数组。例如:

Python复制

print(np.ones(shape=(2, 3)))  # 输出:
                              # [[1. 1. 1.]
                              #  [1. 1. 1.]]
print(np.zeros(shape=(2, 3)))  # 输出:
                               # [[0. 0. 0.]
                               #  [0. 0. 0.]]

np.arange()

np.arange() 创建一个在指定区间内的数组。例如:

Python复制

print(np.arange(5))  # 输出:[0 1 2 3 4]
print(np.arange(2, 9, 3))  # 输出:[2 5 8]

np.linspace()

np.linspace() 创建一个等差数列。例如:

Python复制

print(np.linspace(start=2, stop=10, num=3))  # 输出:[ 2.  6. 10.]

数组的轴

轴是 NumPy 数组中的一个重要概念,它经常出现在聚合函数中。例如,np.sum()np.max() 等函数可以通过指定轴来实现不同的聚合操作。

对于一个二维数组,axis=0 表示沿着行的方向聚合,axis=1 表示沿着列的方向聚合。例如:

Python复制

interest_score = np.random.randint(10, size=(4, 3))
print(interest_score)  # 输出一个随机生成的 4x3 数组

# 沿着 axis=0 求和
print(np.sum(interest_score, axis=0))  # 输出每列的和

# 沿着 axis=1 求和
print(np.sum(interest_score, axis=1))  # 输出每行的和

对于更高维度的数组,axis 的含义可以类推。例如,对于一个形状为 (3, 2, 3) 的三维数组,axis=0 表示沿着第一个维度聚合,axis=1 表示沿着第二个维度聚合,axis=2 表示沿着第三个维度聚合。

小结

NumPy 是 Python 中用于科学计算的基础包,它的核心是多维数组对象 ndarray。通过掌握 NumPy 的数组创建、属性访问以及形状变换,我们可以更高效地进行数据处理。此外,理解数组的轴的概念对于使用聚合函数进行数据操作至关重要。