PyTorch 多维 Tensor 乘法
可太难了
问题描述
有俩矩阵:x: torch.Size([1, 3])
和 y: torch.Size([3, 16, 256, 1024])
,看起来是直接用某种乘法就能搞定,但是我试过 torch.mm()
,torch.mul()
,torch.matmul()
,*
全都不行。
解决办法
后来才知道 torch.matmul
乘的机制很奇怪,我直接写怎么做吧:
mat1 = x.float()
mat2 = y.permute(2, 3, 0, 1).float()
res = torch.matmul(mat1, mat2).permute(2, 3, 0, 1)
才行