学习Numpy矩阵乘法

525 阅读1分钟

在这篇文章中,我们将看到如何使用NumPy进行矩阵乘法。numpy矩阵乘法的输入参数是两个类似数组的对象(也可以是numpy ndarray或python列表),它产生两个矩阵的乘积作为输出。在NumPy数组上执行矩阵乘法比在python列表上执行矩阵乘法更有效率。

让我们从导入NumPy开始,使用NumPy的矩阵乘法np.matmul执行一个简单的矩阵乘法。

$ python3

Python 3.8.5 (default, Mar 8 2021, 13:02:45)

[GCC 9.3.0] on linux2

输入 "help"、"copyright"、"credits "或 "license "获取更多信息。

>>> import numpy as np

>>> a = np.array([[1, 2, 3],

...           [4, 5, 6]])

>>> a.shape

(2, 3)

>>> b = np.array([[1, 1],

...           [1, 1],

...           [1, 1]])

>>> b.shape

(3, 2)

>>> c = np.matmul(a, b)

>>> c.shape

(2, 2)

>>> c

array([[ 6,  6],

           [15, 15]])

numpy中的矩阵乘法遵循的签名是(n, k) * (k, m) -> (n, m)。有时,我们需要对一个矩阵进行简单的标量乘法。为了执行标量乘法,可以使用操作符*。

>>> a = np.ones((2, 3))

>>> a

array([[1., 1., 1.],

           [1., 1., 1.]])

>>> a * 2

array([[2., 2., 2.],

           [2., 2., 2.]])

np.matmul也可以用来执行多维矩阵的乘法。在多维矩阵的情况下,输入矩阵的最后两个维度被考虑用于矩阵乘法。

>>> a = np.ones((8, 3, 2))

>>> a.shape

(8, 3, 2)

>>> b = np.ones((8, 2, 5))

>>> b.shape

(8, 2, 5)

>>> c = np.matmul(a, b)

>>> c.shape

(8, 3, 5)

当对np.matmul的输入进行操作时,numpy会比较它们的形状来检查两个数组之间的矩阵乘法是否兼容。理想情况下,矩阵1的最后一维应该与矩阵2的倒数第二维相同。如果它们不兼容,会产生一个值错误。

>>> a = np.ones((2, 3))

>>> a.shape

(2, 3)

>>> b = np.ones((1, 2))

>>> b.shape

(1, 2)

>>> np.matmul(a, b)

Traceback (most recent call last):

...

ValueError: matmul: Input operand 1 has a mismatch ...