PyTorch | 广播机制

128 阅读1分钟

什么是广播机制

两个形状不同的Tensor按元素运算时,可能会触发广播机制:先适当复制元素使这两个Tensor形状相同后再按元素运算。

广播条件

  1. 两个Tensor都至少有一个维度
# 这种情况不行,x不满足条件
x = torch.zeros(0)
y = torch.randn(2,2)
x + y
  1. 从右往左看两个Tensor的每一个维度,对应的两个维度满足以下条件之一:

    a. 这两个维度大小相等 b. 某个维度一个Tensor有,另一个Tensor没有 c. 某个维度一个Tensor有,另一个Tensor大小为1

例如:

x = torch.randn(5,3,4,1)
y = torch.randn(  3,1,1)
x + y

从右往左看,两个Tensor在维度3上大小都是1,满足条件a;在维度2上大小分别是4和1,满足条件c;在维度1上大小相等,满足条件a;在维度0上满足条件b,因此可以广播。两个Tensor维度从右往左看,如果出现两个Tensor在某个维度上大小不相等且两个维度大小都不为1,那么这两个张量一定不能广播。

怎样广播

先将条件b转化为条件c,即在缺失维度的位置上新增一个维度,大小为1。再将大小为1的维度广播的和对应维度一样大。

# 初始
x = torch.randn(5,3,4,1)
y = torch.randn(  3,1,1)
# 第一步
x = torch.randn(5,3,4,1)
y = torch.randn(1,3,1,1)
# 第二步
x = torch.randn(5,3,4,1)
y = torch.randn(5,3,1,1)