PyTorch 多维 Tensor 乘法

85 阅读1分钟

PyTorch 多维 Tensor 乘法

可太难了

问题描述

有俩矩阵:x: torch.Size([1, 3])y: torch.Size([3, 16, 256, 1024]),看起来是直接用某种乘法就能搞定,但是我试过 torch.mm()torch.mul()torch.matmul()* 全都不行。

解决办法

后来才知道 torch.matmul 乘的机制很奇怪,我直接写怎么做吧:

mat1 = x.float()
mat2 = y.permute(2, 3, 0, 1).float()
res = torch.matmul(mat1, mat2).permute(2, 3, 0, 1)

才行