举例说明Python的__matmul__()方法

196 阅读3分钟

语法

object.__matmul__(self, other)

调用Python__matmul__() 方法是为了实现矩阵乘法操作@ 。例如,为了评估表达式x @ y ,Python试图调用x.__matmul__(y)

我们称这种方法为*"* Dunder Method",即*"Double UnderscoreMethod"(也叫"Magic Method")*。

@ 操作符从3.5开始被引入Python的核心语法,这要感谢 PEP 465。它的唯一目标是解决矩阵乘法的问题。它甚至带有一个漂亮的助记符--@ 就是*代表mATrices。

不寻常的是,@ 被添加到核心 Python 语言中,而它只在某些库中使用。幸运的是,我们唯一一次使用@ 是用于装饰器函数。所以你不太可能会感到困惑。

例子

在下面的例子中,你创建了一个自定义类Data 并覆盖了__matmul__() 方法,这样就创建了一个新的Data 对象,其值是两个操作数ab 的矩阵乘法,类型为Data

class Data:
        
    def __matmul__(self, other):
        return '... my result of matmul...'


a = Data()
b = Data()
c = a @ b

print(c)
# ... my result of matmul...

如果你没有定义__matmul__() 方法,Python 会提出一个TypeError

如何解决TypeError:@的操作数类型不受支持

考虑下面的代码片段,你试图在没有定义dunder方法的情况下将两个自定义对象相乘__matmul__()

class Data:
    pass


a = Data()
b = Data()
c = a @ b

print(c)

在我的电脑上运行这个导致以下错误信息。

Traceback (most recent call last):
  File "C:\Users\xcent\Desktop\code.py", line 7, in <module>
    c = a @ b
TypeError: unsupported operand type(s) for @: 'Data' and 'Data'

这个错误的原因是,__matmul__() dunder方法从未被定义过,而且默认情况下没有为自定义对象定义该方法。所以,为了解决TypeError: unsupported operand type(s) for @ ,你需要在你的类定义中提供__matmul__(self, other) 方法,如前所示。

class Data:
        
    def __matmul__(self, other):
        return '... my result of matmul...'

NumPy的矩阵乘法@

np.matmul() vs np.dot() vs @ Matrix Multiplication Operators

要在两个NumPy数组之间进行矩阵乘法,请查看@ 操作符。

# Python >= 3.5
# 2x2 arrays where each value is 1.0
>>> A = np.ones((2, 2))
>>> B = np.ones((2, 2))

>>> A @ B
array([[2., 2.],
      [2., 2.]]) 

Python的__mul__与__rmul__对比

假设,你想对两个对象进行矩阵乘法运算xy

print(x @ y)

Python 首先尝试调用左边对象的__matmul__() 方法x.__matmul__(y) 。但是这可能因为两个原因而失败。

  1. 方法x.__matmul__() 首先没有实现,或者
  2. 方法x.__matmul__() 已经实现,但返回一个NotImplemented 值,表明数据类型不兼容。

如果失败了,Python 试图通过调用y.__rmatmul__() 来解决这个问题,用于反向矩阵乘法的右侧运算符y 。如果这个方法被实现,Python 知道它没有遇到非交换性操作的潜在问题。如果它只是执行y.__matmul__(x) 而不是x.__matmul__(y) ,那么如果乘法是非交换性的,就会引起错误。这就是为什么需要y.__rmatmul__(x) ,表明矩阵乘法毕竟是可能的。

因此,x.__matmul__(y)x.__rmatmul__(y) 之间的区别是,前者计算x @ y 而后者计算y @ x - 两者都调用定义在对象x 上的各自矩阵乘法。