pytorch中的各种乘法

2,228 阅读4分钟

小知识,大挑战!本文正在参与“程序员必备小知识”创作活动。

那么多相乘,讲实话我到现在也没仔细梳理过,所以现在搞一下子。


首先声明一个向量和一个二维矩阵

import torch
vec = torch.arange(4)
mtx = torch.arange(12).reshape(4,3)
print(vec, mtx,sep='\n')

输出结果:

>> 
tensor([0, 1, 2, 3])
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])

按位置*

这个*在pytorch中是按位置相乘,存在广播机制。

import torch
vec = torch.arange(4)
mtx = torch.arange(12).reshape(4,3)
print(vec*vec)
print(mtx*mtx)
>>
tensor([0, 1, 4, 9])
tensor([[  0,   1,   4],
        [  9,  16,  25],
        [ 36,  49,  64],
        [ 81, 100, 121]])

但是需要注意的一点是虽然众多地方提到向量默认是列向量,但是在pytorch中一维的张量没有这种说法。 就算你用3×4的张量乘4×1的张量,得出的结果本应该是3×1的张量,但是因为是一维张量,也会变成默认的3(也不是1×3)。

可以执行的状态下print(mtx*vec)print(vec*mtx)的结果是完全一样的。

但是上面的例子中如果执行print(mtx*vec)或者print(vec*mtx)会报错。因为默认情况下,一维的张量和矩阵执行*操作的时候,一维张量中元素的个数必须和二维矩阵列数相同,否则广播功能失效。

当然也可以使用reshap()为其增加一个维度。但是增加维度之后要遵守一些维度规则。

import torch
vec = torch.arange(4).reshape(4,1) # 增加维度
mtx = torch.arange(12).reshape(4,3)
print(vec*mtx)
print(mtx*vec)
>>
tensor([[ 0,  0,  0],
        [ 3,  4,  5],
        [12, 14, 16],
        [27, 30, 33]])
tensor([[ 0,  0,  0],
        [ 3,  4,  5],
        [12, 14, 16],
        [27, 30, 33]])

比如上边矩阵是4×3的。

第二行代码你可以使用

  • vec = torch.arange(4).reshape(4,1)
  • vec = torch.arange(3).reshape(1,3)

就是说必须在行或者列上保持元素个数的一致。

数乘torch.mul

torch.mul(input, value, out=None)

用标量值value乘以输入input的每个元素,并返回一个新的结果张量。 就是张量的数乘运算。

import torch
vec = torch.arange(4)
mtx = torch.arange(12).reshape(3,4)
print(torch.mul(vec,2))
print(torch.mul(mtx,2))
>>
tensor([0, 2, 4, 6])
tensor([[ 0,  2,  4,  6],
        [ 8, 10, 12, 14],
        [16, 18, 20, 22]])

矩阵向量相乘torch.mv

torch.mv(mat, vec, out=None) → Tensor

对矩阵mat和向量vec进行相乘。 如果mat 是一个n×m张量,vec 是一个m元 1维张量,将会输出一个n 元 1维张量。

必须前边是矩阵后边是向量,维度要符合矩阵乘法。出来的是一维张量。

import torch
vec = torch.arange(4)
mtx = torch.arange(12).reshape(3,4)
print(torch.mv(mtx,vec))
>>
tensor([14, 38, 62])

矩阵乘法torch.mm

torch.mm(mat1, mat2, out=None) → Tensor

对矩阵mat1mat2进行相乘。 如果mat1 是一个n×m张量,mat2 是一个 m×p张量,将会输出一个 n×p张量out

就是我们线代中学的矩阵乘法,维度必须对应正确。

import torch
mtx = torch.arange(12)
m1 = mtx.reshape(3,4)
m2 = mtx.reshape(4,3)
print(torch.mm(m1, m2))
>>
tensor([[ 42,  48,  54],
        [114, 136, 158],
        [186, 224, 262]])

点乘积torch.dot

torch.dot(tensor1, tensor2) → float

计算两个张量的点乘积(内积),两个张量都为一维向量。

import torch
vec = torch.arange(4)
print(torch.dot(vec, vec))
>>
tensor(14)

黑科技@

还存在一个黑科技@,也是严格按照第一个参数的列数要等于第二个参数的行数。

import torch
vec = torch.arange(4)
mtx = torch.arange(12)
m1 = mtx.reshape(4,3)
m2 = mtx.reshape(3,4)

print(vec @ vec)
print(vec @ m1)
print(m2 @ vec)
print(m1 @ m2)
>>
tensor(14)
tensor([42, 48, 54])
tensor([14, 38, 62])
tensor([[ 20,  23,  26,  29],
        [ 56,  68,  80,  92],
        [ 92, 113, 134, 155],
        [128, 158, 188, 218]])

上边的结果可能不够直观,那看看下边:

import torch
vec = torch.arange(4)
mtx = torch.arange(12)
m1 = mtx.reshape(4,3)
m2 = mtx.reshape(3,4)

print(vec @ vec==torch.dot(vec,vec))
print(vec @ m1) # 本句直接使用torch.mv()无法执行。
print(m2 @ vec==torch.mv(m2,vec))
print(m1 @ m2==torch.mm(m1,m2))

使用一个@就可以替代上边的那三个函数。

  • 对一维张量执行@操作就是dot
  • 对一维和二维张量执行操作就是mv
  • 对二维张量执行@操作就是mm
>>
tensor(True)
tensor([42, 48, 54])
tensor([True, True, True])
tensor([[True, True, True, True],
        [True, True, True, True],
        [True, True, True, True],
        [True, True, True, True]])

第二个无法替换的怎么办?为了满足强迫症,可以这样:

import torch
vec = torch.arange(4)
mtx = torch.arange(12).reshape(4,3)

print(vec @ mtx) # 本句直接使用torch.mv()无法执行。
print(torch.mm(vec.reshape(1,4),mtx))

print(vec @ mtx==torch.mm(vec.reshape(1,4),mtx))
>>
tensor([42, 48, 54])
tensor([[42, 48, 54]])
tensor([[True, True, True]])

再加一个torch.matmul

vec = torch.arange(3)
mtx = torch.arange(12).reshape(3,4)
print(torch.matmul(vec,mtx))
print(torch.matmul(vec,vec))
print(torch.matmul(mtx.T,mtx))
print(torch.matmul(mtx.T,vec))
>>
tensor([20, 23, 26, 29])
tensor(5)
tensor([[ 80,  92, 104, 116],
        [ 92, 107, 122, 137],
        [104, 122, 140, 158],
        [116, 137, 158, 179]])
tensor([20, 23, 26, 29])

@看起来差不多,也是可以:

  • 对一维张量执行操作就是dot
  • 对一维和二维张量执行操作就是mv
  • 对二维张量执行操作就是mm

但是他们的区别在于matmul不知局限于一二维,可以进行高维张量的乘法。