在进行张量运算时,会存在两个tensor形状不一致但进行运算的情况,该广播机制需要遵循以下规则:
- 每个张量至少为一维张量
- 从后往前比较张量的形状,当前维度的大小要么相等,要么其中一个等于一,要么其中一个不存在
- 在满足前两条的前提下,从后往前比较时,
- 若相等,则继续往前比较,否则进入下一个判断,
- 若其中一个等于1,则等于1的张量在当前比较的轴上,重复n次(n=另一个张量在当前轴的维度数量),重复的轴为当前轴
- 若其中一个不存在,则不存在的张量在当前比较轴上,重复n(n=另一个张量在当前轴的维度数量),重复的轴为当前轴
代码如下:
创建张量
import numpy as np
m = np.arange(32).reshape([4,4,-1])
t = np.arange(4).reshape([4,-1])
m.shape # (4, 4, 2)
t.shape # (4, 1)
第一次比较:
# 第一次比较,倒数第一个轴上维度不一致,且t在这个轴上为维度数量为1,则按照m在这个轴上的维度数量进行重复
t1 = t.repeat(2,axis=1)
t1.shape # (4, 2)
第二次比较:
# 第二次比较,倒数第二个轴上维度一致,不需要重复
t2 = t1
第三次比较:
# 第三次,倒数第三个轴上维度不一致,且t2在该轴上没有数量,则将t升维
t3 = np.expand_dims(t2,axis=0).repeat(4,axis=0)
t3.shape # (4, 4, 2)
按照规则比较并重复后的运算结果:
m*t3
'''
array([[[ 0, 0], [ 2, 3], [ 8, 10], [18, 21]],
[[ 0, 0], [10, 11], [24, 26], [42, 45]],
[[ 0, 0], [18, 19], [40, 42], [66, 69]],
[[ 0, 0], [26, 27], [56, 58], [90, 93]]])
'''
直接运算的结果:
m*t
'''
array([[[ 0, 0], [ 2, 3], [ 8, 10], [18, 21]],
[[ 0, 0], [10, 11], [24, 26], [42, 45]],
[[ 0, 0], [18, 19], [40, 42], [66, 69]],
[[ 0, 0], [26, 27], [56, 58], [90, 93]]])
'''