在查看pytorch实现的Swin Transformer源码中,对create_mask函数有些不太理解,可能是因为自己python基础差的原因,导致有些代码看不懂。后面经过查看资料才弄懂了。在此记录下自己学习的过程。
我学习的是B站up主霹雳吧啦Wz视频中提供的代码,代码网址如下:pytorch_classification/swin_transformer。 在源码的第431-455行实现了对create_mask函数的创建。以下为代码的截图:
比如输入的feature map是99大小的,window_size是33大小的。 shift_size是[window_size/2]并取整,所以在这shift_size=3//2=1。
shift_size是窗口移动的大小,通过以下操作完成原始feature map的shift。
x = torch.roll(shifted_x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
执行完SW-MSA后,要将移动后的窗口还原回去。
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
代码434行和435行先创建Hp和Wp,Hp和Wp是mask的高和宽。必须确保mask的高和宽是window_size的整数倍。第437行把mask先创建为shape为[1, Hp, Wp, 1]的全为0的张量。438-440行就是在h方向进行切片操作:
438行是在h方向从0开始,切到-window_size(-3),注意是左闭右开,所以是从0切到-3上面一个。439行是在h方向从-window_size(-3)切片到-shift_size(-1),是左闭右开。所以是从-3切到-2。440行是从-shift_size(-1)切到最后。
w方向与h方向切片操作一模一样,因此,w和h切片后的feature map如下图所示:
代码444行到448行是对feature map按照划分的区域进行数字的填充(相同区域填充相同的数字)。填充后的feature map如图所示:
代码450行把填充后的feature map划分为33的窗口。如下图:
代码451行把上图feature map按照每个窗口展平,展平后shape为[num_windows,window_size[0]window_size[1]]=[9,9]。展平后的mask_windows如下:
452行,mask_windows.unsqueeze(1)是在第二维度新增一个维度,shape变为[9,1,9];mask_windows.unsqueeze(2)是在第三维度新增一个维度,shape变为[9,9,1]。
要想二者可以相减,必须二者要相同维度[9,9,9]才能进行运算。因此,需要使用广播机制,把mask_windows.unsqueeze(1)和mask_windows.unsqueeze(2)都拓展为[9,9,9]。
因画图比较繁琐,我只取[4 4 5 4 4 5 7 8 8]也就是对feature map进行窗口划分的最后一个区域进行演示:
这个33的区域包含4个区域,不同区域不进行自注意力机制(使不同部分输出的softmax()=0)
将[4 4 5 4 4 5 7 8 8]分别在第二维度复制9次,在第三维度复制9次,得下图:
二者相减所得的结果如下图:
第一个深度方向即为33区域第一个元素与其他元素的区域关系(相同区域为0,不同区域不为0);第二个深度方向即为33区域第二个元素与其他元素的区域关系(相同区域为0,不同区域不为0)...。要使不同区域间softmax()=0,那么应为一个很大的负数。
第454行,将所得结果为0的地方记为0,不为0的地方全部赋值为-100。然后将mask与计算的各相加,-100的作用为的是使不同区域的值经过softmax后值为0(屏蔽不同区域的自注意力机制)。原理如下图,图直接拿的是B站up主霹雳吧啦Wz讲解的ppt。