张量形状不一致时的广播机制

380 阅读2分钟

        在进行张量运算时,会存在两个tensor形状不一致但进行运算的情况,该广播机制需要遵循以下规则:

  1. 每个张量至少为一维张量
  2. 从后往前比较张量的形状,当前维度的大小要么相等,要么其中一个等于一,要么其中一个不存在
  3. 在满足前两条的前提下,从后往前比较时,
    1. 若相等,则继续往前比较,否则进入下一个判断,
    2. 若其中一个等于1,则等于1的张量在当前比较的轴上,重复n次(n=另一个张量在当前轴的维度数量),重复的轴为当前轴
    3. 若其中一个不存在,则不存在的张量在当前比较轴上,重复n(n=另一个张量在当前轴的维度数量),重复的轴为当前轴

代码如下:

创建张量

import numpy as np
m = np.arange(32).reshape([4,4,-1])
t = np.arange(4).reshape([4,-1])
m.shape # (4, 4, 2)
t.shape # (4, 1)

第一次比较:

# 第一次比较,倒数第一个轴上维度不一致,且t在这个轴上为维度数量为1,则按照m在这个轴上的维度数量进行重复
t1 = t.repeat(2,axis=1)
t1.shape #  (4, 2)

第二次比较:

# 第二次比较,倒数第二个轴上维度一致,不需要重复
t2 = t1

第三次比较:

# 第三次,倒数第三个轴上维度不一致,且t2在该轴上没有数量,则将t升维
t3 = np.expand_dims(t2,axis=0).repeat(4,axis=0)
t3.shape # (4, 4, 2)

按照规则比较并重复后的运算结果:

m*t3
'''
array([[[ 0, 0], [ 2, 3], [ 8, 10], [18, 21]], 
       [[ 0, 0], [10, 11], [24, 26], [42, 45]], 
       [[ 0, 0], [18, 19], [40, 42], [66, 69]], 
       [[ 0, 0], [26, 27], [56, 58], [90, 93]]])
'''

直接运算的结果:

m*t
'''
array([[[ 0, 0], [ 2, 3], [ 8, 10], [18, 21]], 
       [[ 0, 0], [10, 11], [24, 26], [42, 45]], 
       [[ 0, 0], [18, 19], [40, 42], [66, 69]], 
       [[ 0, 0], [26, 27], [56, 58], [90, 93]]])
'''