torch.matmul()

534 阅读1分钟

torch.matmul

torch.matmul是tensor的乘法,输入可以是高维的。
当输入是都是二维时,就是普通的矩阵乘法,和tensor.mm函数用法相同。

a = torch.tensor([[1,2], [3,4]])
b = torch.tensor([[2,2], [3,4]])

torch.matmul(a, b)
Out[34]: 
tensor([[ 8, 10],
        [18, 22]])
        
torch.matmul(a, b).shape
Out[35]: torch.Size([2, 2])

如果维度更高呢?前面的维度必须要相同,然后最里面的两个维度符合矩阵相乘的形状限制:i×jj×k。

a = torch.tensor([[[1,2], [3,4], [5,6]],[[7,8], [9,10], [11,12]]])

a
Out[37]: 
tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],
        [[ 7,  8],
         [ 9, 10],
         [11, 12]]])
         
a.shape
Out[38]: torch.Size([2, 3, 2])

b = torch.tensor([[[1,2], [3,4]],[[7,8], [9,10]]])
b
Out[40]: 
tensor([[[ 1,  2],
         [ 3,  4]],
        [[ 7,  8],
         [ 9, 10]]])
         
b.shape
Out[41]: torch.Size([2, 2, 2])

torch.matmul(a, b)
Out[42]: 
tensor([[[  7,  10],
         [ 15,  22],
         [ 23,  34]],
        [[121, 136],
         [153, 172],
         [185, 208]]])

# a 和 b 的最外面的维度都是 2,相同。
# 最里面两个维度分别是 3 × 2 和 2 × 2,那么乘完以后就是 3 × 2
torch.matmul(a, b).shape
Out[43]: torch.Size([2, 3, 2])