MaxViT: Multi-Axis Vision Transformer论文浅析
开启掘金成长之旅!这是我参与「掘金日新计划 · 2 月更文挑战」的第 13 天,点击查看活动详情
1、MaxViT主体结构与创新点
1.1 研究动机
卷积神经网络经历了从AlexNet到ResNet再到Vision Transformer,其在计算机视觉任务中的表现越来越好,通过注意力机制,Vision Transformer取得了非常好的效果。然而,在没有充分的预训练情况下,Vision Transformer通常不会取得很好的效果,并且由于注意力算子需要二次复杂度,因此在层次网络的早期或高分辨率阶段通过完全注意力进行全局交互的计算量很大。如何有效地结合全局和局部交互,在计算预算下平衡模型大小和可推广性仍然是一个具有挑战性的问题。
在该论文中,作者提出了一种新型的Transformer模块,称为多轴自注意力(multi-axis self-attention, Max-SA),它可以作为基本的架构组件,在单个块中执行局部和全局空间交互。与完全自注意力相比,Max-SA具有更大的灵活性和效率,即自然适应不同的输入长度,具有线性复杂度。此外,Max-SA仅具有线性复杂度,可以用作网络任何层的通用独立注意力模块,增加少量的计算量。
其主要创新点包含如下三点:
- MaxViT是一个通用的Transformer结构,在每一个块内都可以实现局部与全局之间的空间交互,同时可适应不同分辨率的输入大小。
- Max-SA通过分解空间轴得到窗口注意力(Block attention)与网格注意力(Grid attention),将传统计算方法的二次复杂度降到线性复杂度。
- MBConv作为自注意力计算的补充,利用其固有的归纳偏差来提升模型的泛化能力,避免陷入过拟合。
1.2 Max-SA主要结构
作者通过引入Max-SA模块,将传统的自注意机制分解为窗格注意力(Block attention)与网格注意力(Grid attention)两种稀疏形式,在不损失非局部性的情况下,将传统注意力机制的计算复杂度从二次复杂度降低到线性。并且Max-SA具有灵活性和可伸缩性,我们可以通过简单地将Max-SA与MBConv在分层体系结构中叠加,从而构建一个称为MaxViT的视觉Backbone,MaxViT主要结构如上图2所示。
class MaxViT(nn.Layer):
def __init__(self, args):
super().__init__()
self.conv_stem = nn.Sequential(nn.Conv2D(args['input_dim'], args['stem_dim'], 3,2,3//2),
nn.BatchNorm2D(args['stem_dim']),
nn.GELU(),
nn.Conv2D(args['stem_dim'], args['stem_dim'], 3,1,3//2),
nn.BatchNorm2D(args['stem_dim']),
nn.GELU())
in_dim = args['stem_dim']
self.max_blocks = nn.LayerList([])
for i,num_block in enumerate(args['stage_num_block']):
layers = nn.LayerList([])
out_dim = args['stage_dim']*(2**i)
num_head = args['num_heads']*(2**i)
for i in range(num_block):
pooling_size = args['pooling_size']if i == 0 else 1
layers.append(Max_Block(in_dim,out_dim,num_head,args['block_size'],
args['grid_size'],args['mbconv_ksize'],pooling_size,
args['mbconv_expand_rate'],args['se_rate'],args['mlp_ratio'],
args['qkv_bias'],args['qk_scale'], args['drop'], args['attn_drop'],
args['drop_path'],args['act_layer'] ,args['norm_layer']))
in_dim = out_dim
self.max_blocks.append(layers)
self.last_conv = nn.Sequential(nn.Conv2D(in_dim,in_dim,1,),
nn.BatchNorm2D(in_dim),
nn.GELU())
self.proj = nn.Linear(in_dim,args['num_classes'])
self.softmax = nn.Softmax(1)
def forward(self, x):
x = self.conv_stem(x)
for blocks in self.max_blocks:
for block in blocks:
x = block(x)
x = self.last_conv(x)
x = self.softmax(self.proj(x.mean([2, 3])))
return x
1.3 Multi-axis Attention 详解
与局部卷积相比,全局相互作用是自注意力机制的优势之一。然而,直接将注意力应用于整个空间在计算上是不可行的,因为注意力算子需要二次复杂度,为了解决全局自注意力机制导致的二次计算复杂度,作者通过分解空间轴得到局部(block attention)与全局(grid attention)两种稀疏形式,巧妙的解决了计算复杂度的问题。如上所示,Max-SA模块主要包含Block Attention与Grid Attention两个部分。
class Max_Block(nn.Layer):
def __init__(self, in_dim, out_dim , num_heads=8.,block_size=(7,7), grid_size=(7,7),
mbconv_ksize = 3,pooling_size = 1,mbconv_expand_rate=4,se_reduce_rate=0.25,
mlp_ratio=4,qkv_bias=False,qk_scale=None, drop=0., attn_drop=0.,drop_path=0.,
act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
super().__init__()
self.mbconv = MBConv(in_dim,out_dim,mbconv_ksize,pooling_size,mbconv_expand_rate,se_reduce_rate,drop)
self.block_attn = Window_Block(out_dim, block_size, num_heads, mlp_ratio, qkv_bias,qk_scale, drop,
attn_drop,drop_path, act_layer ,norm_layer)
self.grid_attn = Grid_Block(out_dim, grid_size, num_heads, mlp_ratio, qkv_bias,qk_scale, drop,
attn_drop,drop_path, act_layer ,norm_layer)
def forward(self, x):
x = self.mbconv(x)
x = self.block_attn(x)
x = self.grid_attn(x)
return x
- Block Attention
将输入特征图划分为不重叠的窗口, 最后在每一个窗口中执行自注意力计算。虽然避免了全局自注意力机制的复杂计算,但是局部注意模型已经被证明不适用于大规模的数据集。所以作者提出一种稀疏的全局自注意力机制,被称作grid attention(网格注意力机制)。
class Window_Block(nn.Layer):
def __init__(self, dim, block_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0.,
attn_drop=0.,drop_path=0., act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
super().__init__()
self.block_size = block_size
mlp_hidden_dim = int(dim * mlp_ratio)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.attn = Rel_Attention(dim, block_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self,x):
assert x.shape[2]%self.block_size[0] == 0 & x.shape[3]%self.block_size[1] == 0, 'image size should be divisible by block_size'
out = block(self.norm1(x),self.block_size)
out = self.attn(out)
x = x + self.drop_path(unblock(self.attn(out)))
out = self.mlp(self.norm2(x))
x = x + self.drop_path(out)
return x
- Grid Attention
不同于传统使用固定窗口大小来划分特征图的操作,grid attention 使用固定的大小的均匀网格将输人张量网格化, 可以有效平衡局部和全局之间的计算 (且仅具有线性复杂度)。
class Grid_Block(nn.Layer):
def __init__(self, dim, grid_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False,qk_scale=None, drop=0.,
attn_drop=0.,drop_path=0., act_layer=nn.GELU ,norm_layer=Channel_Layernorm):
super().__init__()
self.grid_size = grid_size
mlp_hidden_dim = int(dim * mlp_ratio)
self.norm1 = norm_layer(dim)
self.norm2 = norm_layer(dim)
self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.attn = Rel_Attention(dim, grid_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self,x):
assert x.shape[2]%self.grid_size[0] == 0 & x.shape[3]%self.grid_size[1] == 0, 'image size should be divisible by grid_size'
grid_size = (x.shape[2]//self.grid_size[0], x.shape[3]//self.grid_size[1])
out = block(self.norm1(x),grid_size)
out = out.transpose([0,4,5,3,1,2])
out = self.attn(out).transpose([0,4,5,3,1,2])
x = x + self.drop_path(unblock(out))
out = self.mlp(self.norm2(x))
x = x + self.drop_path(out)
return x
1.4 MBConv
为了获得更丰富的特征表示,首先使用逐点卷积进行通道升维,在升维后的投影空间中进行Depth-wise卷积,紧随其后的SE用于增强重要通道的表征,最后再次使用逐点卷积恢复维度。可用如下公式表示:
对于每个阶段的第一个MBConv块,下采样是通过应用stride=2的深度可分离卷积( Depthwise Conv3x3)来完成的,而残差连接分支也 应用pooling 和 channel 映射:
MBConv包含如下特点:
- 采用了Depthwise Convlution,因此相比于传统卷积,Depthwise Conv的参数能够大大减少;
- 采用了“倒瓶颈”的结构,也就是说在卷积过程中,特征经历了升维和降维两个步骤,并利用卷积固有的归纳偏置,在一定程度上提升模型的泛化能力与可训练性。
- 相比于ViT中的显式位置编码,在Multi-axis Attention则使用MBConv来代替,这是因为深度可分离卷积可被视为条件位置编码(CPE)。
class MBConv(nn.Layer):
def __init__(self,in_dim,out_dim,kernel_size=3,stride_size=1,expand_rate = 4,se_rate = 0.25,dropout = 0.):
super().__init__()
hidden_dim = int(expand_rate * out_dim)
self.bn = nn.BatchNorm2D(in_dim)
self.expand_conv = nn.Sequential(nn.Conv2D(in_dim, hidden_dim, 1),
nn.BatchNorm2D(hidden_dim),
nn.GELU())
self.dw_conv = nn.Sequential(nn.Conv2D(hidden_dim, hidden_dim, kernel_size, stride_size, kernel_size//2, groups=hidden_dim),
nn.BatchNorm2D(hidden_dim),
nn.GELU())
self.se = SE(hidden_dim,max(1,int(out_dim*se_rate)))
self.out_conv = nn.Sequential(nn.Conv2D(hidden_dim, out_dim, 1),
nn.BatchNorm2D(out_dim))
if stride_size > 1:
self.proj = nn.Sequential(nn.MaxPool2D(kernel_size, stride_size, kernel_size//2),
nn.Conv2D(in_dim, out_dim, 1))
else:
self.proj = nn.Identity()
def forward(self, x):
out = self.bn(x)
out = self.expand_conv(out)
out = self.dw_conv(out)
out = self.se(out)
out = self.out_conv(out)
return out + self.proj(x)
1.5 Multi-Axis attention与Axial attention区别
论文所提出的方法不同于 Axial attention。如图 3 所示, 在 Axial attention 中 首先使用列注意力(column-wise attention),然后使用行注意力( row-wise attention) 来计算全局 注意力, 。然而 Multi-Axis attention 则先采用局部注意力 (block attention), 再使用稀疏的全局注意力 (grid attention), 这样的 设计充分考虑了图像的 2D 结构。
2、整体网络结构复现
在论文中,作者基于Max-SA模块搭建了四种网络结构(MaxViT model family (T/S/B/L)),本项目对这四种结构均进行了复现,其网络结构列表如下:
3、网络模型结构输出
import paddle
from maxvit import MaxViT,tiny_args
print(MaxViT(tiny_args)(paddle.zeros([2,3,224,224])).shape)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:654: UserWarning: When training, we now always track global mean and variance.
"When training, we now always track global mean and variance.")
[2, 1000]
model = MaxViT(tiny_args)
paddle.summary(model,(1,3,224,224))
-----------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===================================================================================
Conv2D-301 [[1, 3, 224, 224]] [1, 64, 112, 112] 1,792
BatchNorm2D-95 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
GELU-117 [[1, 64, 112, 112]] [1, 64, 112, 112] 0
Conv2D-302 [[1, 64, 112, 112]] [1, 64, 112, 112] 36,928
BatchNorm2D-96 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
GELU-118 [[1, 64, 112, 112]] [1, 64, 112, 112] 0
BatchNorm2D-97 [[1, 64, 112, 112]] [1, 64, 112, 112] 256
Conv2D-303 [[1, 64, 112, 112]] [1, 256, 112, 112] 16,640
BatchNorm2D-98 [[1, 256, 112, 112]] [1, 256, 112, 112] 1,024
GELU-119 [[1, 256, 112, 112]] [1, 256, 112, 112] 0
Conv2D-304 [[1, 256, 112, 112]] [1, 256, 56, 56] 2,560
BatchNorm2D-99 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
GELU-120 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-305 [[1, 256, 1, 1]] [1, 16, 1, 1] 4,112
GELU-121 [[1, 16, 1, 1]] [1, 16, 1, 1] 0
Conv2D-306 [[1, 16, 1, 1]] [1, 256, 1, 1] 4,352
Sigmoid-23 [[1, 256, 1, 1]] [1, 256, 1, 1] 0
SE-23 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-307 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
BatchNorm2D-100 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
MaxPool2D-9 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
Conv2D-308 [[1, 64, 56, 56]] [1, 64, 56, 56] 4,160
MBConv-23 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
LayerNorm-89 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-89 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-311 [[64, 64, 7, 7]] [64, 192, 7, 7] 12,480
Dropout-135 [[64, 2, 49, 49]] [64, 2, 49, 49] 0
Conv2D-312 [[64, 64, 7, 7]] [64, 64, 7, 7] 4,160
Dropout-134 [[64, 64, 7, 7]] [64, 64, 7, 7] 0
Rel_Attention-45 [[1, 8, 8, 64, 7, 7]] [1, 8, 8, 64, 7, 7] 338
Identity-59 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-90 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-90 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-309 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,640
GELU-122 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Dropout-133 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-310 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
Mlp-45 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Window_Block-23 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-91 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-91 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-315 [[64, 64, 7, 7]] [64, 192, 7, 7] 12,480
Dropout-138 [[64, 2, 49, 49]] [64, 2, 49, 49] 0
Conv2D-316 [[64, 64, 7, 7]] [64, 64, 7, 7] 4,160
Dropout-137 [[64, 64, 7, 7]] [64, 64, 7, 7] 0
Rel_Attention-46 [[1, 8, 8, 64, 7, 7]] [1, 8, 8, 64, 7, 7] 338
Identity-60 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-92 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-92 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-313 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,640
GELU-123 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Dropout-136 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-314 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
Mlp-46 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Grid_Block-23 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Max_Block-23 [[1, 64, 112, 112]] [1, 64, 56, 56] 0
BatchNorm2D-101 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-317 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,640
BatchNorm2D-102 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
GELU-124 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-318 [[1, 256, 56, 56]] [1, 256, 56, 56] 2,560
BatchNorm2D-103 [[1, 256, 56, 56]] [1, 256, 56, 56] 1,024
GELU-125 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-319 [[1, 256, 1, 1]] [1, 16, 1, 1] 4,112
GELU-126 [[1, 16, 1, 1]] [1, 16, 1, 1] 0
Conv2D-320 [[1, 16, 1, 1]] [1, 256, 1, 1] 4,352
Sigmoid-24 [[1, 256, 1, 1]] [1, 256, 1, 1] 0
SE-24 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Conv2D-321 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
BatchNorm2D-104 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Identity-61 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
MBConv-24 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-93 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-93 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-324 [[64, 64, 7, 7]] [64, 192, 7, 7] 12,480
Dropout-141 [[64, 2, 49, 49]] [64, 2, 49, 49] 0
Conv2D-325 [[64, 64, 7, 7]] [64, 64, 7, 7] 4,160
Dropout-140 [[64, 64, 7, 7]] [64, 64, 7, 7] 0
Rel_Attention-47 [[1, 8, 8, 64, 7, 7]] [1, 8, 8, 64, 7, 7] 338
Identity-62 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-94 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-94 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-322 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,640
GELU-127 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Dropout-139 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-323 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
Mlp-47 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Window_Block-24 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-95 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-95 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-328 [[64, 64, 7, 7]] [64, 192, 7, 7] 12,480
Dropout-144 [[64, 2, 49, 49]] [64, 2, 49, 49] 0
Conv2D-329 [[64, 64, 7, 7]] [64, 64, 7, 7] 4,160
Dropout-143 [[64, 64, 7, 7]] [64, 64, 7, 7] 0
Rel_Attention-48 [[1, 8, 8, 64, 7, 7]] [1, 8, 8, 64, 7, 7] 338
Identity-63 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
LayerNorm-96 [[1, 56, 56, 64]] [1, 56, 56, 64] 128
Channel_Layernorm-96 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-326 [[1, 64, 56, 56]] [1, 256, 56, 56] 16,640
GELU-128 [[1, 256, 56, 56]] [1, 256, 56, 56] 0
Dropout-142 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Conv2D-327 [[1, 256, 56, 56]] [1, 64, 56, 56] 16,448
Mlp-48 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Grid_Block-24 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
Max_Block-24 [[1, 64, 56, 56]] [1, 64, 56, 56] 0
BatchNorm2D-105 [[1, 64, 56, 56]] [1, 64, 56, 56] 256
Conv2D-330 [[1, 64, 56, 56]] [1, 512, 56, 56] 33,280
BatchNorm2D-106 [[1, 512, 56, 56]] [1, 512, 56, 56] 2,048
GELU-129 [[1, 512, 56, 56]] [1, 512, 56, 56] 0
Conv2D-331 [[1, 512, 56, 56]] [1, 512, 28, 28] 5,120
BatchNorm2D-107 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
GELU-130 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-332 [[1, 512, 1, 1]] [1, 32, 1, 1] 16,416
GELU-131 [[1, 32, 1, 1]] [1, 32, 1, 1] 0
Conv2D-333 [[1, 32, 1, 1]] [1, 512, 1, 1] 16,896
Sigmoid-25 [[1, 512, 1, 1]] [1, 512, 1, 1] 0
SE-25 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-334 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
BatchNorm2D-108 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
MaxPool2D-10 [[1, 64, 56, 56]] [1, 64, 28, 28] 0
Conv2D-335 [[1, 64, 28, 28]] [1, 128, 28, 28] 8,320
MBConv-25 [[1, 64, 56, 56]] [1, 128, 28, 28] 0
LayerNorm-97 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-97 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-338 [[16, 128, 7, 7]] [16, 384, 7, 7] 49,536
Dropout-147 [[16, 4, 49, 49]] [16, 4, 49, 49] 0
Conv2D-339 [[16, 128, 7, 7]] [16, 128, 7, 7] 16,512
Dropout-146 [[16, 128, 7, 7]] [16, 128, 7, 7] 0
Rel_Attention-49 [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7] 676
Identity-64 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-98 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-98 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-336 [[1, 128, 28, 28]] [1, 512, 28, 28] 66,048
GELU-132 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Dropout-145 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-337 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
Mlp-49 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Window_Block-25 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-99 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-99 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-342 [[16, 128, 7, 7]] [16, 384, 7, 7] 49,536
Dropout-150 [[16, 4, 49, 49]] [16, 4, 49, 49] 0
Conv2D-343 [[16, 128, 7, 7]] [16, 128, 7, 7] 16,512
Dropout-149 [[16, 128, 7, 7]] [16, 128, 7, 7] 0
Rel_Attention-50 [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7] 676
Identity-65 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-100 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-100 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-340 [[1, 128, 28, 28]] [1, 512, 28, 28] 66,048
GELU-133 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Dropout-148 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-341 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
Mlp-50 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Grid_Block-25 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Max_Block-25 [[1, 64, 56, 56]] [1, 128, 28, 28] 0
BatchNorm2D-109 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-344 [[1, 128, 28, 28]] [1, 512, 28, 28] 66,048
BatchNorm2D-110 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
GELU-134 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-345 [[1, 512, 28, 28]] [1, 512, 28, 28] 5,120
BatchNorm2D-111 [[1, 512, 28, 28]] [1, 512, 28, 28] 2,048
GELU-135 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-346 [[1, 512, 1, 1]] [1, 32, 1, 1] 16,416
GELU-136 [[1, 32, 1, 1]] [1, 32, 1, 1] 0
Conv2D-347 [[1, 32, 1, 1]] [1, 512, 1, 1] 16,896
Sigmoid-26 [[1, 512, 1, 1]] [1, 512, 1, 1] 0
SE-26 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Conv2D-348 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
BatchNorm2D-112 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Identity-66 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
MBConv-26 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-101 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-101 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-351 [[16, 128, 7, 7]] [16, 384, 7, 7] 49,536
Dropout-153 [[16, 4, 49, 49]] [16, 4, 49, 49] 0
Conv2D-352 [[16, 128, 7, 7]] [16, 128, 7, 7] 16,512
Dropout-152 [[16, 128, 7, 7]] [16, 128, 7, 7] 0
Rel_Attention-51 [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7] 676
Identity-67 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-102 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-102 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-349 [[1, 128, 28, 28]] [1, 512, 28, 28] 66,048
GELU-137 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Dropout-151 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-350 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
Mlp-51 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Window_Block-26 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-103 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-103 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-355 [[16, 128, 7, 7]] [16, 384, 7, 7] 49,536
Dropout-156 [[16, 4, 49, 49]] [16, 4, 49, 49] 0
Conv2D-356 [[16, 128, 7, 7]] [16, 128, 7, 7] 16,512
Dropout-155 [[16, 128, 7, 7]] [16, 128, 7, 7] 0
Rel_Attention-52 [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7] 676
Identity-68 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
LayerNorm-104 [[1, 28, 28, 128]] [1, 28, 28, 128] 256
Channel_Layernorm-104 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-353 [[1, 128, 28, 28]] [1, 512, 28, 28] 66,048
GELU-138 [[1, 512, 28, 28]] [1, 512, 28, 28] 0
Dropout-154 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Conv2D-354 [[1, 512, 28, 28]] [1, 128, 28, 28] 65,664
Mlp-52 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Grid_Block-26 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
Max_Block-26 [[1, 128, 28, 28]] [1, 128, 28, 28] 0
BatchNorm2D-113 [[1, 128, 28, 28]] [1, 128, 28, 28] 512
Conv2D-357 [[1, 128, 28, 28]] [1, 1024, 28, 28] 132,096
BatchNorm2D-114 [[1, 1024, 28, 28]] [1, 1024, 28, 28] 4,096
GELU-139 [[1, 1024, 28, 28]] [1, 1024, 28, 28] 0
Conv2D-358 [[1, 1024, 28, 28]] [1, 1024, 14, 14] 10,240
BatchNorm2D-115 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-140 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-359 [[1, 1024, 1, 1]] [1, 64, 1, 1] 65,600
GELU-141 [[1, 64, 1, 1]] [1, 64, 1, 1] 0
Conv2D-360 [[1, 64, 1, 1]] [1, 1024, 1, 1] 66,560
Sigmoid-27 [[1, 1024, 1, 1]] [1, 1024, 1, 1] 0
SE-27 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-361 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
BatchNorm2D-116 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
MaxPool2D-11 [[1, 128, 28, 28]] [1, 128, 14, 14] 0
Conv2D-362 [[1, 128, 14, 14]] [1, 256, 14, 14] 33,024
MBConv-27 [[1, 128, 28, 28]] [1, 256, 14, 14] 0
LayerNorm-105 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-105 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-365 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-159 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-366 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-158 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-53 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-69 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-106 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-106 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-363 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-142 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-157 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-364 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-53 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Window_Block-27 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-107 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-107 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-369 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-162 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-370 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-161 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-54 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-70 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-108 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-108 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-367 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-143 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-160 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-368 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-54 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Grid_Block-27 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Max_Block-27 [[1, 128, 28, 28]] [1, 256, 14, 14] 0
BatchNorm2D-117 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-371 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
BatchNorm2D-118 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-144 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-372 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 10,240
BatchNorm2D-119 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-145 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-373 [[1, 1024, 1, 1]] [1, 64, 1, 1] 65,600
GELU-146 [[1, 64, 1, 1]] [1, 64, 1, 1] 0
Conv2D-374 [[1, 64, 1, 1]] [1, 1024, 1, 1] 66,560
Sigmoid-28 [[1, 1024, 1, 1]] [1, 1024, 1, 1] 0
SE-28 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-375 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
BatchNorm2D-120 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Identity-71 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
MBConv-28 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-109 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-109 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-378 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-165 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-379 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-164 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-55 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-72 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-110 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-110 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-376 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-147 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-163 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-377 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-55 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Window_Block-28 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-111 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-111 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-382 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-168 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-383 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-167 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-56 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-73 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-112 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-112 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-380 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-148 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-166 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-381 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-56 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Grid_Block-28 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Max_Block-28 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
BatchNorm2D-121 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-384 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
BatchNorm2D-122 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-149 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-385 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 10,240
BatchNorm2D-123 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-150 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-386 [[1, 1024, 1, 1]] [1, 64, 1, 1] 65,600
GELU-151 [[1, 64, 1, 1]] [1, 64, 1, 1] 0
Conv2D-387 [[1, 64, 1, 1]] [1, 1024, 1, 1] 66,560
Sigmoid-29 [[1, 1024, 1, 1]] [1, 1024, 1, 1] 0
SE-29 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-388 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
BatchNorm2D-124 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Identity-74 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
MBConv-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-113 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-113 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-391 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-171 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-392 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-170 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-57 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-75 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-114 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-114 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-389 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-152 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-169 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-390 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-57 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Window_Block-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-115 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-115 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-395 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-174 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-396 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-173 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-58 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-76 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-116 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-116 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-393 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-153 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-172 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-394 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-58 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Grid_Block-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Max_Block-29 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
BatchNorm2D-125 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-397 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
BatchNorm2D-126 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-154 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-398 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 10,240
BatchNorm2D-127 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-155 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-399 [[1, 1024, 1, 1]] [1, 64, 1, 1] 65,600
GELU-156 [[1, 64, 1, 1]] [1, 64, 1, 1] 0
Conv2D-400 [[1, 64, 1, 1]] [1, 1024, 1, 1] 66,560
Sigmoid-30 [[1, 1024, 1, 1]] [1, 1024, 1, 1] 0
SE-30 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-401 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
BatchNorm2D-128 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Identity-77 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
MBConv-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-117 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-117 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-404 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-177 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-405 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-176 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-59 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-78 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-118 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-118 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-402 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-157 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-175 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-403 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-59 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Window_Block-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-119 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-119 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-408 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-180 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-409 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-179 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-60 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-79 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-120 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-120 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-406 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-158 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-178 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-407 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-60 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Grid_Block-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Max_Block-30 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
BatchNorm2D-129 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-410 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
BatchNorm2D-130 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-159 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-411 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 10,240
BatchNorm2D-131 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 4,096
GELU-160 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-412 [[1, 1024, 1, 1]] [1, 64, 1, 1] 65,600
GELU-161 [[1, 64, 1, 1]] [1, 64, 1, 1] 0
Conv2D-413 [[1, 64, 1, 1]] [1, 1024, 1, 1] 66,560
Sigmoid-31 [[1, 1024, 1, 1]] [1, 1024, 1, 1] 0
SE-31 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Conv2D-414 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
BatchNorm2D-132 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Identity-80 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
MBConv-31 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-121 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-121 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-417 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-183 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-418 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-182 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-61 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-81 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-122 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-122 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-415 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-162 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-181 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-416 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-61 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Window_Block-31 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-123 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-123 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-421 [[4, 256, 7, 7]] [4, 768, 7, 7] 197,376
Dropout-186 [[4, 8, 49, 49]] [4, 8, 49, 49] 0
Conv2D-422 [[4, 256, 7, 7]] [4, 256, 7, 7] 65,792
Dropout-185 [[4, 256, 7, 7]] [4, 256, 7, 7] 0
Rel_Attention-62 [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7] 1,352
Identity-82 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
LayerNorm-124 [[1, 14, 14, 256]] [1, 14, 14, 256] 512
Channel_Layernorm-124 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-419 [[1, 256, 14, 14]] [1, 1024, 14, 14] 263,168
GELU-163 [[1, 1024, 14, 14]] [1, 1024, 14, 14] 0
Dropout-184 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Conv2D-420 [[1, 1024, 14, 14]] [1, 256, 14, 14] 262,400
Mlp-62 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Grid_Block-31 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
Max_Block-31 [[1, 256, 14, 14]] [1, 256, 14, 14] 0
BatchNorm2D-133 [[1, 256, 14, 14]] [1, 256, 14, 14] 1,024
Conv2D-423 [[1, 256, 14, 14]] [1, 2048, 14, 14] 526,336
BatchNorm2D-134 [[1, 2048, 14, 14]] [1, 2048, 14, 14] 8,192
GELU-164 [[1, 2048, 14, 14]] [1, 2048, 14, 14] 0
Conv2D-424 [[1, 2048, 14, 14]] [1, 2048, 7, 7] 20,480
BatchNorm2D-135 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
GELU-165 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-425 [[1, 2048, 1, 1]] [1, 128, 1, 1] 262,272
GELU-166 [[1, 128, 1, 1]] [1, 128, 1, 1] 0
Conv2D-426 [[1, 128, 1, 1]] [1, 2048, 1, 1] 264,192
Sigmoid-32 [[1, 2048, 1, 1]] [1, 2048, 1, 1] 0
SE-32 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-427 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
BatchNorm2D-136 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
MaxPool2D-12 [[1, 256, 14, 14]] [1, 256, 7, 7] 0
Conv2D-428 [[1, 256, 7, 7]] [1, 512, 7, 7] 131,584
MBConv-32 [[1, 256, 14, 14]] [1, 512, 7, 7] 0
LayerNorm-125 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-125 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-431 [[1, 512, 7, 7]] [1, 1536, 7, 7] 787,968
Dropout-189 [[1, 16, 49, 49]] [1, 16, 49, 49] 0
Conv2D-432 [[1, 512, 7, 7]] [1, 512, 7, 7] 262,656
Dropout-188 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Rel_Attention-63 [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7] 2,704
Identity-83 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-126 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-126 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-429 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,050,624
GELU-167 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Dropout-187 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-430 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
Mlp-63 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Window_Block-32 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-127 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-127 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-435 [[1, 512, 7, 7]] [1, 1536, 7, 7] 787,968
Dropout-192 [[1, 16, 49, 49]] [1, 16, 49, 49] 0
Conv2D-436 [[1, 512, 7, 7]] [1, 512, 7, 7] 262,656
Dropout-191 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Rel_Attention-64 [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7] 2,704
Identity-84 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-128 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-128 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-433 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,050,624
GELU-168 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Dropout-190 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-434 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
Mlp-64 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Grid_Block-32 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Max_Block-32 [[1, 256, 14, 14]] [1, 512, 7, 7] 0
BatchNorm2D-137 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Conv2D-437 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,050,624
BatchNorm2D-138 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
GELU-169 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-438 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 20,480
BatchNorm2D-139 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 8,192
GELU-170 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-439 [[1, 2048, 1, 1]] [1, 128, 1, 1] 262,272
GELU-171 [[1, 128, 1, 1]] [1, 128, 1, 1] 0
Conv2D-440 [[1, 128, 1, 1]] [1, 2048, 1, 1] 264,192
Sigmoid-33 [[1, 2048, 1, 1]] [1, 2048, 1, 1] 0
SE-33 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Conv2D-441 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
BatchNorm2D-140 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
Identity-85 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
MBConv-33 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-129 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-129 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-444 [[1, 512, 7, 7]] [1, 1536, 7, 7] 787,968
Dropout-195 [[1, 16, 49, 49]] [1, 16, 49, 49] 0
Conv2D-445 [[1, 512, 7, 7]] [1, 512, 7, 7] 262,656
Dropout-194 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Rel_Attention-65 [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7] 2,704
Identity-86 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-130 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-130 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-442 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,050,624
GELU-172 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Dropout-193 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-443 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
Mlp-65 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Window_Block-33 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-131 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-131 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-448 [[1, 512, 7, 7]] [1, 1536, 7, 7] 787,968
Dropout-198 [[1, 16, 49, 49]] [1, 16, 49, 49] 0
Conv2D-449 [[1, 512, 7, 7]] [1, 512, 7, 7] 262,656
Dropout-197 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Rel_Attention-66 [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7] 2,704
Identity-87 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
LayerNorm-132 [[1, 7, 7, 512]] [1, 7, 7, 512] 1,024
Channel_Layernorm-132 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-446 [[1, 512, 7, 7]] [1, 2048, 7, 7] 1,050,624
GELU-173 [[1, 2048, 7, 7]] [1, 2048, 7, 7] 0
Dropout-196 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-447 [[1, 2048, 7, 7]] [1, 512, 7, 7] 1,049,088
Mlp-66 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Grid_Block-33 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Max_Block-33 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Conv2D-450 [[1, 512, 7, 7]] [1, 512, 7, 7] 262,656
BatchNorm2D-141 [[1, 512, 7, 7]] [1, 512, 7, 7] 2,048
GELU-174 [[1, 512, 7, 7]] [1, 512, 7, 7] 0
Linear-3 [[1, 512]] [1, 1000] 513,000
Softmax-3 [[1, 1000]] [1, 1000] 0
===================================================================================
Total params: 31,001,840
Trainable params: 30,893,552
Non-trainable params: 108,288
-----------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 721.36
Params size (MB): 118.26
Estimated Total Size (MB): 840.20
-----------------------------------------------------------------------------------
{'total_params': 31001840, 'trainable_params': 30893552}
4、总结
为解决传统自注意力机制在图像大小方面缺乏的可扩展性,论文提出了一种高效的、可扩展的多轴注意力模型,该模型由Block局部注意和Grid全局注意力两部分组成。本文还提出了一个新的架构,通过有效地混合提出的注意力模型与MBConv卷积,并相应地提出了一个简单的分层视觉骨干,称为MaxViT,通过简单地在多个阶段重复基本构建块。MaxViT允许任意分辨率的输入,实现全局-局部空间交互,且只具有线性复杂度。