什么是广播机制
两个形状不同的Tensor按元素运算时,可能会触发广播机制:先适当复制元素使这两个Tensor形状相同后再按元素运算。
广播条件
- 两个Tensor都至少有一个维度
# 这种情况不行,x不满足条件
x = torch.zeros(0)
y = torch.randn(2,2)
x + y
-
从右往左看两个Tensor的每一个维度,对应的两个维度满足以下条件之一:
a. 这两个维度大小相等 b. 某个维度一个Tensor有,另一个Tensor没有 c. 某个维度一个Tensor有,另一个Tensor大小为1
例如:
x = torch.randn(5,3,4,1)
y = torch.randn( 3,1,1)
x + y
从右往左看,两个Tensor在维度3上大小都是1,满足条件a;在维度2上大小分别是4和1,满足条件c;在维度1上大小相等,满足条件a;在维度0上满足条件b,因此可以广播。两个Tensor维度从右往左看,如果出现两个Tensor在某个维度上大小不相等且两个维度大小都不为1,那么这两个张量一定不能广播。
怎样广播
先将条件b转化为条件c,即在缺失维度的位置上新增一个维度,大小为1。再将大小为1的维度广播的和对应维度一样大。
# 初始
x = torch.randn(5,3,4,1)
y = torch.randn( 3,1,1)
# 第一步
x = torch.randn(5,3,4,1)
y = torch.randn(1,3,1,1)
# 第二步
x = torch.randn(5,3,4,1)
y = torch.randn(5,3,1,1)