torch.matmul
torch.matmul是tensor的乘法,输入可以是高维的。
当输入是都是二维时,就是普通的矩阵乘法,和tensor.mm函数用法相同。
a = torch.tensor([[1,2], [3,4]])
b = torch.tensor([[2,2], [3,4]])
torch.matmul(a, b)
Out[34]:
tensor([[ 8, 10],
[18, 22]])
torch.matmul(a, b).shape
Out[35]: torch.Size([2, 2])
如果维度更高呢?前面的维度必须要相同,然后最里面的两个维度符合矩阵相乘的形状限制:i×j,j×k。
a = torch.tensor([[[1,2], [3,4], [5,6]],[[7,8], [9,10], [11,12]]])
a
Out[37]:
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[ 7, 8],
[ 9, 10],
[11, 12]]])
a.shape
Out[38]: torch.Size([2, 3, 2])
b = torch.tensor([[[1,2], [3,4]],[[7,8], [9,10]]])
b
Out[40]:
tensor([[[ 1, 2],
[ 3, 4]],
[[ 7, 8],
[ 9, 10]]])
b.shape
Out[41]: torch.Size([2, 2, 2])
torch.matmul(a, b)
Out[42]:
tensor([[[ 7, 10],
[ 15, 22],
[ 23, 34]],
[[121, 136],
[153, 172],
[185, 208]]])
# a 和 b 的最外面的维度都是 2,相同。
# 最里面两个维度分别是 3 × 2 和 2 × 2,那么乘完以后就是 3 × 2
torch.matmul(a, b).shape
Out[43]: torch.Size([2, 3, 2])