A Brief Analysis and Code Reproduction of MaxViT: Multi-Axis Vision Transformer


1. Overall Structure and Innovations of MaxViT

1.1 Motivation

Computer vision backbones have evolved from AlexNet to ResNet and, more recently, to the Vision Transformer, with steadily improving performance on vision tasks; through its attention mechanism, the Vision Transformer achieves very strong results. However, without sufficient pre-training, Vision Transformers often underperform, and because the attention operator has quadratic complexity, performing global interaction through full attention in the early, high-resolution stages of a hierarchical network is computationally prohibitive. How to effectively combine global and local interactions while balancing model size and generalizability under a given computational budget remains a challenging problem.

In this paper, the authors propose a new Transformer module, called multi-axis self-attention (Max-SA), which can serve as a basic architectural component that performs both local and global spatial interaction within a single block. Compared with full self-attention, Max-SA is more flexible and efficient: it naturally adapts to different input lengths with only linear complexity, so it can be used as a general standalone attention module at any layer of a network while adding little computation.
Its main innovations are the following three points:

• MaxViT is a general-purpose Transformer architecture in which every block performs both local and global spatial interaction, and it adapts to inputs of different resolutions.
• Max-SA decomposes the spatial axes into window attention (block attention) and grid attention, reducing the quadratic complexity of conventional attention to linear (see the quick cost comparison after this list).
• MBConv complements the self-attention computation, exploiting the inherent inductive bias of convolution to improve generalization and avoid overfitting.
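
To make the linear-complexity claim concrete, here is a rough cost comparison (a sketch that counts only the attention score matrix and ignores constants and the channel dimension): full self-attention over N = H*W tokens scales as N^2, while block/grid attention over fixed P x P partitions scales as N * P^2, i.e. linearly in N.

# Rough cost of the attention score matrix on a 56x56 feature map.
H = W = 56
P = 7                                # window / grid side length
N = H * W                            # number of tokens
full_attention = N ** 2              # global attention: quadratic in N
multi_axis = (N // P ** 2) * P ** 4  # N/P^2 partitions, each costing P^4 -> N * P^2
print(full_attention, multi_axis)    # 9834496 vs 153664, a 64x reduction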

1.2 The Structure of Max-SA


By introducing the Max-SA module, the authors decompose conventional self-attention into two sparse forms, block attention and grid attention, reducing the computational complexity from quadratic to linear without sacrificing non-locality. Max-SA is also flexible and scalable: simply stacking Max-SA with MBConv in a hierarchical architecture yields a vision backbone called MaxViT, whose overall structure is shown in Figure 2 of the paper.

class MaxViT(nn.Layer):
    def __init__(self, args):
        super().__init__()
        # Stem: two 3x3 convs, the first with stride 2 for downsampling.
        self.conv_stem = nn.Sequential(nn.Conv2D(args['input_dim'], args['stem_dim'], 3, 2, 3 // 2),
                                       nn.BatchNorm2D(args['stem_dim']),
                                       nn.GELU(),
                                       nn.Conv2D(args['stem_dim'], args['stem_dim'], 3, 1, 3 // 2),
                                       nn.BatchNorm2D(args['stem_dim']),
                                       nn.GELU())
        in_dim = args['stem_dim']
        self.max_blocks = nn.LayerList([])
        for i, num_block in enumerate(args['stage_num_block']):
            layers = nn.LayerList([])
            # Channel width and head count double at each stage.
            out_dim = args['stage_dim'] * (2 ** i)
            num_head = args['num_heads'] * (2 ** i)
            for j in range(num_block):
                # Only the first block of each stage downsamples.
                pooling_size = args['pooling_size'] if j == 0 else 1
                layers.append(Max_Block(in_dim, out_dim, num_head, args['block_size'],
                                        args['grid_size'], args['mbconv_ksize'], pooling_size,
                                        args['mbconv_expand_rate'], args['se_rate'], args['mlp_ratio'],
                                        args['qkv_bias'], args['qk_scale'], args['drop'], args['attn_drop'],
                                        args['drop_path'], args['act_layer'], args['norm_layer']))
                in_dim = out_dim
            self.max_blocks.append(layers)
        self.last_conv = nn.Sequential(nn.Conv2D(in_dim, in_dim, 1),
                                       nn.BatchNorm2D(in_dim),
                                       nn.GELU())
        self.proj = nn.Linear(in_dim, args['num_classes'])
        self.softmax = nn.Softmax(1)

    def forward(self, x):
        x = self.conv_stem(x)
        for blocks in self.max_blocks:
            for max_block in blocks:
                x = max_block(x)
        x = self.last_conv(x)
        # Global average pooling over the spatial dims, then the classifier head.
        x = self.softmax(self.proj(x.mean([2, 3])))
        return x

1.3 Multi-Axis Attention in Detail


Compared with local convolution, global interaction is one of the key advantages of self-attention. However, directly applying attention over the entire spatial extent is computationally infeasible, since the attention operator has quadratic complexity. To overcome this, the authors decompose the spatial axes into a local form (block attention) and a global form (grid attention), two sparse attention patterns that neatly sidestep the quadratic cost. As illustrated in the paper, the Max-SA module consists of these two parts, block attention and grid attention.

class Max_Block(nn.Layer):
    def __init__(self, in_dim, out_dim, num_heads=8, block_size=(7,7), grid_size=(7,7),
                 mbconv_ksize=3, pooling_size=1, mbconv_expand_rate=4, se_reduce_rate=0.25,
                 mlp_ratio=4, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=Channel_Layernorm):
        super().__init__()
        # MBConv (optionally downsampling), then block attention, then grid attention.
        self.mbconv = MBConv(in_dim, out_dim, mbconv_ksize, pooling_size, mbconv_expand_rate, se_reduce_rate, drop)
        self.block_attn = Window_Block(out_dim, block_size, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                                       attn_drop, drop_path, act_layer, norm_layer)
        self.grid_attn = Grid_Block(out_dim, grid_size, num_heads, mlp_ratio, qkv_bias, qk_scale, drop,
                                    attn_drop, drop_path, act_layer, norm_layer)

    def forward(self, x):
        x = self.mbconv(x)
        x = self.block_attn(x)
        x = self.grid_attn(x)
        return x

• Block Attention

Block attention partitions the input feature map into non-overlapping windows and performs self-attention within each window. This avoids the heavy computation of global self-attention, but purely local attention models have been shown to scale poorly to large datasets, so the authors pair block attention with a sparse form of global self-attention called grid attention.

class Window_Block(nn.Layer):
    def __init__(self, dim, block_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=Channel_Layernorm):
        super().__init__()
        self.block_size = block_size
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.attn = Rel_Attention(dim, block_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        assert x.shape[2] % self.block_size[0] == 0 and x.shape[3] % self.block_size[1] == 0, \
            'image size should be divisible by block_size'
        out = block(self.norm1(x), self.block_size)   # partition into non-overlapping windows
        out = self.attn(out)                          # self-attention within each window
        x = x + self.drop_path(unblock(out))          # merge windows and add the residual
        out = self.mlp(self.norm2(x))
        x = x + self.drop_path(out)
        return x
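
The block and unblock helpers used above come from the accompanying maxvit module and are not shown in this post. A minimal sketch consistent with the shapes in the summary printed later (e.g. Rel_Attention receiving [1, 8, 8, 64, 7, 7] for a 56x56 feature map) might look like the following; treat it as an assumption about the real implementation:

def block(x, block_size):
    # Partition an NCHW tensor into non-overlapping windows:
    # (B, C, H, W) -> (B, H//p1, W//p2, C, p1, p2)
    B, C, H, W = x.shape
    p1, p2 = block_size
    x = x.reshape([B, C, H // p1, p1, W // p2, p2])
    return x.transpose([0, 2, 4, 1, 3, 5])

def unblock(x):
    # Inverse of block: (B, nH, nW, C, p1, p2) -> (B, C, nH*p1, nW*p2)
    B, nH, nW, C, p1, p2 = x.shape
    return x.transpose([0, 3, 1, 4, 2, 5]).reshape([B, C, nH * p1, nW * p2])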

• Grid Attention

Instead of partitioning the feature map with a fixed window size, grid attention divides the input tensor using a uniform grid of fixed size. This effectively balances local and global computation, again with only linear complexity.

class Grid_Block(nn.Layer):
    def __init__(self, dim, grid_size=(7,7), num_heads=8, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0.,
                 attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=Channel_Layernorm):
        super().__init__()
        self.grid_size = grid_size
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm1 = norm_layer(dim)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        self.attn = Rel_Attention(dim, grid_size, num_heads, qkv_bias, qk_scale, attn_drop, drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        assert x.shape[2] % self.grid_size[0] == 0 and x.shape[3] % self.grid_size[1] == 0, \
            'image size should be divisible by grid_size'
        # Partition into large windows so that the fixed grid_size x grid_size
        # tokens inside each attention group are strided across the whole image.
        window_size = (x.shape[2] // self.grid_size[0], x.shape[3] // self.grid_size[1])
        out = block(self.norm1(x), window_size)
        # Swap the window axes with the in-window axes so attention runs over the grid.
        out = out.transpose([0, 4, 5, 3, 1, 2])
        out = self.attn(out).transpose([0, 4, 5, 3, 1, 2])
        x = x + self.drop_path(unblock(out))
        out = self.mlp(self.norm2(x))
        x = x + self.drop_path(out)
        return x
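
One detail worth verifying: the permutation [0, 4, 5, 3, 1, 2] applied before and after the attention is its own inverse, so the second transpose restores the original axis layout before unblock. A standalone check:

import paddle

x = paddle.arange(2 * 3 * 4 * 5 * 6 * 7, dtype='float32').reshape([2, 3, 4, 5, 6, 7])
y = x.transpose([0, 4, 5, 3, 1, 2]).transpose([0, 4, 5, 3, 1, 2])
print(bool((x == y).all()))  # True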

1.4 MBConv

To obtain richer feature representations, MBConv first applies a pointwise convolution to expand the channel dimension, performs a depthwise convolution in the expanded space, follows it with squeeze-and-excitation (SE) to emphasize informative channels, and finally uses another pointwise convolution to restore the channel dimension. Matching the implementation below, this can be written as: x ← x + Proj(SE(DWConv(Conv(Norm(x)))))

For the first MBConv block of each stage, downsampling is performed by giving the depthwise convolution (Depthwise Conv3x3) a stride of 2, and the residual branch correspondingly applies pooling and a channel projection: x ← Proj(Pool2D(x)) + Proj(SE(DWConv(Conv(Norm(x)), stride=2)))

MBConv has the following characteristics:

• It uses depthwise convolution, so its parameter count is far smaller than that of a standard convolution;
• It uses an "inverted bottleneck" structure: the features are first expanded and then reduced in channel dimension, and the inherent inductive bias of convolution improves the model's generalization and trainability to some extent;
• Instead of the explicit positional encoding used in ViT, Multi-Axis Attention relies on MBConv, since depthwise convolution can be viewed as a form of conditional positional encoding (CPE).

class MBConv(nn.Layer):
    def __init__(self, in_dim, out_dim, kernel_size=3, stride_size=1, expand_rate=4, se_rate=0.25, dropout=0.):
        super().__init__()
        hidden_dim = int(expand_rate * out_dim)
        self.bn = nn.BatchNorm2D(in_dim)
        # Pointwise conv expands the channel dimension.
        self.expand_conv = nn.Sequential(nn.Conv2D(in_dim, hidden_dim, 1),
                                         nn.BatchNorm2D(hidden_dim),
                                         nn.GELU())
        # Depthwise conv in the expanded space (stride 2 when downsampling).
        self.dw_conv = nn.Sequential(nn.Conv2D(hidden_dim, hidden_dim, kernel_size, stride_size,
                                               kernel_size // 2, groups=hidden_dim),
                                     nn.BatchNorm2D(hidden_dim),
                                     nn.GELU())
        # Squeeze-and-excitation re-weights important channels.
        self.se = SE(hidden_dim, max(1, int(out_dim * se_rate)))
        # Pointwise conv restores the output channel dimension.
        self.out_conv = nn.Sequential(nn.Conv2D(hidden_dim, out_dim, 1),
                                      nn.BatchNorm2D(out_dim))
        # When downsampling, the residual branch also pools and projects channels.
        if stride_size > 1:
            self.proj = nn.Sequential(nn.MaxPool2D(kernel_size, stride_size, kernel_size // 2),
                                      nn.Conv2D(in_dim, out_dim, 1))
        else:
            self.proj = nn.Identity()

    def forward(self, x):
        out = self.bn(x)
        out = self.expand_conv(out)
        out = self.dw_conv(out)
        out = self.se(out)
        out = self.out_conv(out)
        return out + self.proj(x)
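
As a quick shape check (a sketch, assuming the class above is importable), a stride-2 MBConv halves the spatial resolution while the residual branch pools and projects to match, consistent with the MBConv rows in the summary printed below:

import paddle

mb = MBConv(in_dim=64, out_dim=128, kernel_size=3, stride_size=2)
print(mb(paddle.zeros([1, 64, 56, 56])).shape)  # [1, 128, 28, 28]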

1.5 Multi-Axis Attention vs. Axial Attention

The proposed method differs from axial attention. As shown in Figure 3 of the paper, axial attention computes global attention by first applying column-wise attention and then row-wise attention, whereas Multi-Axis Attention applies local attention (block attention) first, followed by sparse global attention (grid attention); this design fully respects the 2D structure of images.

2. Reproducing the Full Network Architecture

In the paper, the authors build four network variants on top of the Max-SA module (the MaxViT model family: T/S/B/L). This project reproduces all four; their configurations follow the tables in the paper.
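
As a concrete example, a configuration dict for the tiny variant consistent with the summary in the next section might look as follows. The keys come from the MaxViT constructor above; the values are inferred from the printed shapes (2/2/5/2 blocks per stage, heads doubling from 2, stride-2 downsampling) and should be treated as assumptions rather than the exact tiny_args shipped with the project:

# Hypothetical tiny config; values inferred from the paddle.summary output below.
tiny_args = {
    'input_dim': 3, 'stem_dim': 64, 'stage_dim': 64, 'num_classes': 1000,
    'stage_num_block': [2, 2, 5, 2],        # MaxViT-T: blocks per stage
    'num_heads': 2,                         # doubled at each stage: 2, 4, 8, 16
    'block_size': (7, 7), 'grid_size': (7, 7),
    'mbconv_ksize': 3, 'pooling_size': 2,   # stride-2 MBConv opens each stage
    'mbconv_expand_rate': 4, 'se_rate': 0.25, 'mlp_ratio': 4,
    'qkv_bias': True, 'qk_scale': None,
    'drop': 0., 'attn_drop': 0., 'drop_path': 0.,
    'act_layer': nn.GELU, 'norm_layer': Channel_Layernorm,
}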

3. Printing the Model Structure

import paddle
from maxvit import MaxViT, tiny_args

print(MaxViT(tiny_args)(paddle.zeros([2, 3, 224, 224])).shape)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/norm.py:654: UserWarning: When training, we now always track global mean and variance.
  "When training, we now always track global mean and variance.")


[2, 1000]
model = MaxViT(tiny_args)

paddle.summary(model, (1, 3, 224, 224))
-----------------------------------------------------------------------------------
    Layer (type)           Input Shape           Output Shape         Param #    
===================================================================================
     Conv2D-301         [[1, 3, 224, 224]]    [1, 64, 112, 112]        1,792     
   BatchNorm2D-95      [[1, 64, 112, 112]]    [1, 64, 112, 112]         256      
      GELU-117         [[1, 64, 112, 112]]    [1, 64, 112, 112]          0       
     Conv2D-302        [[1, 64, 112, 112]]    [1, 64, 112, 112]       36,928     
   BatchNorm2D-96      [[1, 64, 112, 112]]    [1, 64, 112, 112]         256      
      GELU-118         [[1, 64, 112, 112]]    [1, 64, 112, 112]          0       
   BatchNorm2D-97      [[1, 64, 112, 112]]    [1, 64, 112, 112]         256      
     Conv2D-303        [[1, 64, 112, 112]]    [1, 256, 112, 112]      16,640     
   BatchNorm2D-98      [[1, 256, 112, 112]]   [1, 256, 112, 112]       1,024     
      GELU-119         [[1, 256, 112, 112]]   [1, 256, 112, 112]         0       
     Conv2D-304        [[1, 256, 112, 112]]    [1, 256, 56, 56]        2,560     
   BatchNorm2D-99       [[1, 256, 56, 56]]     [1, 256, 56, 56]        1,024     
      GELU-120          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-305          [[1, 256, 1, 1]]       [1, 16, 1, 1]          4,112     
      GELU-121           [[1, 16, 1, 1]]        [1, 16, 1, 1]            0       
     Conv2D-306          [[1, 16, 1, 1]]        [1, 256, 1, 1]         4,352     
     Sigmoid-23          [[1, 256, 1, 1]]       [1, 256, 1, 1]           0       
        SE-23           [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-307         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
   BatchNorm2D-100      [[1, 64, 56, 56]]      [1, 64, 56, 56]          256      
     MaxPool2D-9       [[1, 64, 112, 112]]     [1, 64, 56, 56]           0       
     Conv2D-308         [[1, 64, 56, 56]]      [1, 64, 56, 56]         4,160     
      MBConv-23        [[1, 64, 112, 112]]     [1, 64, 56, 56]           0       
    LayerNorm-89        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-89    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-311          [[64, 64, 7, 7]]      [64, 192, 7, 7]        12,480     
     Dropout-135        [[64, 2, 49, 49]]      [64, 2, 49, 49]           0       
     Conv2D-312          [[64, 64, 7, 7]]       [64, 64, 7, 7]         4,160     
     Dropout-134         [[64, 64, 7, 7]]       [64, 64, 7, 7]           0       
  Rel_Attention-45    [[1, 8, 8, 64, 7, 7]]  [1, 8, 8, 64, 7, 7]        338      
     Identity-59        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-90        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-90    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-309         [[1, 64, 56, 56]]      [1, 256, 56, 56]       16,640     
      GELU-122          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Dropout-133        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-310         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
       Mlp-45           [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
   Window_Block-23      [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-91        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-91    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-315          [[64, 64, 7, 7]]      [64, 192, 7, 7]        12,480     
     Dropout-138        [[64, 2, 49, 49]]      [64, 2, 49, 49]           0       
     Conv2D-316          [[64, 64, 7, 7]]       [64, 64, 7, 7]         4,160     
     Dropout-137         [[64, 64, 7, 7]]       [64, 64, 7, 7]           0       
  Rel_Attention-46    [[1, 8, 8, 64, 7, 7]]  [1, 8, 8, 64, 7, 7]        338      
     Identity-60        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-92        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-92    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-313         [[1, 64, 56, 56]]      [1, 256, 56, 56]       16,640     
      GELU-123          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Dropout-136        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-314         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
       Mlp-46           [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    Grid_Block-23       [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    Max_Block-23       [[1, 64, 112, 112]]     [1, 64, 56, 56]           0       
   BatchNorm2D-101      [[1, 64, 56, 56]]      [1, 64, 56, 56]          256      
     Conv2D-317         [[1, 64, 56, 56]]      [1, 256, 56, 56]       16,640     
   BatchNorm2D-102      [[1, 256, 56, 56]]     [1, 256, 56, 56]        1,024     
      GELU-124          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-318         [[1, 256, 56, 56]]     [1, 256, 56, 56]        2,560     
   BatchNorm2D-103      [[1, 256, 56, 56]]     [1, 256, 56, 56]        1,024     
      GELU-125          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-319          [[1, 256, 1, 1]]       [1, 16, 1, 1]          4,112     
      GELU-126           [[1, 16, 1, 1]]        [1, 16, 1, 1]            0       
     Conv2D-320          [[1, 16, 1, 1]]        [1, 256, 1, 1]         4,352     
     Sigmoid-24          [[1, 256, 1, 1]]       [1, 256, 1, 1]           0       
        SE-24           [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Conv2D-321         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
   BatchNorm2D-104      [[1, 64, 56, 56]]      [1, 64, 56, 56]          256      
     Identity-61        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
      MBConv-24         [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-93        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-93    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-324          [[64, 64, 7, 7]]      [64, 192, 7, 7]        12,480     
     Dropout-141        [[64, 2, 49, 49]]      [64, 2, 49, 49]           0       
     Conv2D-325          [[64, 64, 7, 7]]       [64, 64, 7, 7]         4,160     
     Dropout-140         [[64, 64, 7, 7]]       [64, 64, 7, 7]           0       
  Rel_Attention-47    [[1, 8, 8, 64, 7, 7]]  [1, 8, 8, 64, 7, 7]        338      
     Identity-62        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-94        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-94    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-322         [[1, 64, 56, 56]]      [1, 256, 56, 56]       16,640     
      GELU-127          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Dropout-139        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-323         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
       Mlp-47           [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
   Window_Block-24      [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-95        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-95    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-328          [[64, 64, 7, 7]]      [64, 192, 7, 7]        12,480     
     Dropout-144        [[64, 2, 49, 49]]      [64, 2, 49, 49]           0       
     Conv2D-329          [[64, 64, 7, 7]]       [64, 64, 7, 7]         4,160     
     Dropout-143         [[64, 64, 7, 7]]       [64, 64, 7, 7]           0       
  Rel_Attention-48    [[1, 8, 8, 64, 7, 7]]  [1, 8, 8, 64, 7, 7]        338      
     Identity-63        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    LayerNorm-96        [[1, 56, 56, 64]]      [1, 56, 56, 64]          128      
Channel_Layernorm-96    [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-326         [[1, 64, 56, 56]]      [1, 256, 56, 56]       16,640     
      GELU-128          [[1, 256, 56, 56]]     [1, 256, 56, 56]          0       
     Dropout-142        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
     Conv2D-327         [[1, 256, 56, 56]]     [1, 64, 56, 56]        16,448     
       Mlp-48           [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    Grid_Block-24       [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
    Max_Block-24        [[1, 64, 56, 56]]      [1, 64, 56, 56]           0       
   BatchNorm2D-105      [[1, 64, 56, 56]]      [1, 64, 56, 56]          256      
     Conv2D-330         [[1, 64, 56, 56]]      [1, 512, 56, 56]       33,280     
   BatchNorm2D-106      [[1, 512, 56, 56]]     [1, 512, 56, 56]        2,048     
      GELU-129          [[1, 512, 56, 56]]     [1, 512, 56, 56]          0       
     Conv2D-331         [[1, 512, 56, 56]]     [1, 512, 28, 28]        5,120     
   BatchNorm2D-107      [[1, 512, 28, 28]]     [1, 512, 28, 28]        2,048     
      GELU-130          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Conv2D-332          [[1, 512, 1, 1]]       [1, 32, 1, 1]         16,416     
      GELU-131           [[1, 32, 1, 1]]        [1, 32, 1, 1]            0       
     Conv2D-333          [[1, 32, 1, 1]]        [1, 512, 1, 1]        16,896     
     Sigmoid-25          [[1, 512, 1, 1]]       [1, 512, 1, 1]           0       
        SE-25           [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Conv2D-334         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
   BatchNorm2D-108      [[1, 128, 28, 28]]     [1, 128, 28, 28]         512      
    MaxPool2D-10        [[1, 64, 56, 56]]      [1, 64, 28, 28]           0       
     Conv2D-335         [[1, 64, 28, 28]]      [1, 128, 28, 28]        8,320     
      MBConv-25         [[1, 64, 56, 56]]      [1, 128, 28, 28]          0       
    LayerNorm-97        [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-97    [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-338         [[16, 128, 7, 7]]      [16, 384, 7, 7]        49,536     
     Dropout-147        [[16, 4, 49, 49]]      [16, 4, 49, 49]           0       
     Conv2D-339         [[16, 128, 7, 7]]      [16, 128, 7, 7]        16,512     
     Dropout-146        [[16, 128, 7, 7]]      [16, 128, 7, 7]           0       
  Rel_Attention-49    [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7]       676      
     Identity-64        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-98        [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-98    [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-336         [[1, 128, 28, 28]]     [1, 512, 28, 28]       66,048     
      GELU-132          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Dropout-145        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-337         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
       Mlp-49           [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
   Window_Block-25      [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-99        [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-99    [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-342         [[16, 128, 7, 7]]      [16, 384, 7, 7]        49,536     
     Dropout-150        [[16, 4, 49, 49]]      [16, 4, 49, 49]           0       
     Conv2D-343         [[16, 128, 7, 7]]      [16, 128, 7, 7]        16,512     
     Dropout-149        [[16, 128, 7, 7]]      [16, 128, 7, 7]           0       
  Rel_Attention-50    [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7]       676      
     Identity-65        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-100       [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-100   [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-340         [[1, 128, 28, 28]]     [1, 512, 28, 28]       66,048     
      GELU-133          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Dropout-148        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-341         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
       Mlp-50           [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    Grid_Block-25       [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    Max_Block-25        [[1, 64, 56, 56]]      [1, 128, 28, 28]          0       
   BatchNorm2D-109      [[1, 128, 28, 28]]     [1, 128, 28, 28]         512      
     Conv2D-344         [[1, 128, 28, 28]]     [1, 512, 28, 28]       66,048     
   BatchNorm2D-110      [[1, 512, 28, 28]]     [1, 512, 28, 28]        2,048     
      GELU-134          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Conv2D-345         [[1, 512, 28, 28]]     [1, 512, 28, 28]        5,120     
   BatchNorm2D-111      [[1, 512, 28, 28]]     [1, 512, 28, 28]        2,048     
      GELU-135          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Conv2D-346          [[1, 512, 1, 1]]       [1, 32, 1, 1]         16,416     
      GELU-136           [[1, 32, 1, 1]]        [1, 32, 1, 1]            0       
     Conv2D-347          [[1, 32, 1, 1]]        [1, 512, 1, 1]        16,896     
     Sigmoid-26          [[1, 512, 1, 1]]       [1, 512, 1, 1]           0       
        SE-26           [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Conv2D-348         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
   BatchNorm2D-112      [[1, 128, 28, 28]]     [1, 128, 28, 28]         512      
     Identity-66        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
      MBConv-26         [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-101       [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-101   [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-351         [[16, 128, 7, 7]]      [16, 384, 7, 7]        49,536     
     Dropout-153        [[16, 4, 49, 49]]      [16, 4, 49, 49]           0       
     Conv2D-352         [[16, 128, 7, 7]]      [16, 128, 7, 7]        16,512     
     Dropout-152        [[16, 128, 7, 7]]      [16, 128, 7, 7]           0       
  Rel_Attention-51    [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7]       676      
     Identity-67        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-102       [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-102   [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-349         [[1, 128, 28, 28]]     [1, 512, 28, 28]       66,048     
      GELU-137          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Dropout-151        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-350         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
       Mlp-51           [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
   Window_Block-26      [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-103       [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-103   [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-355         [[16, 128, 7, 7]]      [16, 384, 7, 7]        49,536     
     Dropout-156        [[16, 4, 49, 49]]      [16, 4, 49, 49]           0       
     Conv2D-356         [[16, 128, 7, 7]]      [16, 128, 7, 7]        16,512     
     Dropout-155        [[16, 128, 7, 7]]      [16, 128, 7, 7]           0       
  Rel_Attention-52    [[1, 4, 4, 128, 7, 7]] [1, 4, 4, 128, 7, 7]       676      
     Identity-68        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    LayerNorm-104       [[1, 28, 28, 128]]     [1, 28, 28, 128]         256      
Channel_Layernorm-104   [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-353         [[1, 128, 28, 28]]     [1, 512, 28, 28]       66,048     
      GELU-138          [[1, 512, 28, 28]]     [1, 512, 28, 28]          0       
     Dropout-154        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
     Conv2D-354         [[1, 512, 28, 28]]     [1, 128, 28, 28]       65,664     
       Mlp-52           [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    Grid_Block-26       [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
    Max_Block-26        [[1, 128, 28, 28]]     [1, 128, 28, 28]          0       
   BatchNorm2D-113      [[1, 128, 28, 28]]     [1, 128, 28, 28]         512      
     Conv2D-357         [[1, 128, 28, 28]]    [1, 1024, 28, 28]       132,096    
   BatchNorm2D-114     [[1, 1024, 28, 28]]    [1, 1024, 28, 28]        4,096     
      GELU-139         [[1, 1024, 28, 28]]    [1, 1024, 28, 28]          0       
     Conv2D-358        [[1, 1024, 28, 28]]    [1, 1024, 14, 14]       10,240     
   BatchNorm2D-115     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-140         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-359         [[1, 1024, 1, 1]]       [1, 64, 1, 1]         65,600     
      GELU-141           [[1, 64, 1, 1]]        [1, 64, 1, 1]            0       
     Conv2D-360          [[1, 64, 1, 1]]       [1, 1024, 1, 1]        66,560     
     Sigmoid-27         [[1, 1024, 1, 1]]      [1, 1024, 1, 1]           0       
        SE-27          [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-361        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
   BatchNorm2D-116      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
    MaxPool2D-11        [[1, 128, 28, 28]]     [1, 128, 14, 14]          0       
     Conv2D-362         [[1, 128, 14, 14]]     [1, 256, 14, 14]       33,024     
      MBConv-27         [[1, 128, 28, 28]]     [1, 256, 14, 14]          0       
    LayerNorm-105       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-105   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-365          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-159         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-366          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-158         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-53    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-69        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-106       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-106   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-363         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-142         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-157        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-364        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-53           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   Window_Block-27      [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-107       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-107   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-369          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-162         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-370          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-161         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-54    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-70        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-108       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-108   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-367         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-143         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-160        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-368        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-54           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Grid_Block-27       [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Max_Block-27        [[1, 128, 28, 28]]     [1, 256, 14, 14]          0       
   BatchNorm2D-117      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Conv2D-371         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
   BatchNorm2D-118     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-144         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-372        [[1, 1024, 14, 14]]    [1, 1024, 14, 14]       10,240     
   BatchNorm2D-119     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-145         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-373         [[1, 1024, 1, 1]]       [1, 64, 1, 1]         65,600     
      GELU-146           [[1, 64, 1, 1]]        [1, 64, 1, 1]            0       
     Conv2D-374          [[1, 64, 1, 1]]       [1, 1024, 1, 1]        66,560     
     Sigmoid-28         [[1, 1024, 1, 1]]      [1, 1024, 1, 1]           0       
        SE-28          [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-375        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
   BatchNorm2D-120      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Identity-71        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
      MBConv-28         [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-109       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-109   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-378          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-165         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-379          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-164         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-55    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-72        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-110       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-110   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-376         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-147         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-163        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-377        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-55           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   Window_Block-28      [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-111       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-111   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-382          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-168         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-383          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-167         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-56    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-73        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-112       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-112   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-380         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-148         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-166        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-381        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-56           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Grid_Block-28       [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Max_Block-28        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   BatchNorm2D-121      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Conv2D-384         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
   BatchNorm2D-122     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-149         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-385        [[1, 1024, 14, 14]]    [1, 1024, 14, 14]       10,240     
   BatchNorm2D-123     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-150         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-386         [[1, 1024, 1, 1]]       [1, 64, 1, 1]         65,600     
      GELU-151           [[1, 64, 1, 1]]        [1, 64, 1, 1]            0       
     Conv2D-387          [[1, 64, 1, 1]]       [1, 1024, 1, 1]        66,560     
     Sigmoid-29         [[1, 1024, 1, 1]]      [1, 1024, 1, 1]           0       
        SE-29          [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-388        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
   BatchNorm2D-124      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Identity-74        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
      MBConv-29         [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-113       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-113   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-391          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-171         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-392          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-170         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-57    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-75        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-114       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-114   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-389         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-152         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-169        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-390        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-57           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   Window_Block-29      [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-115       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-115   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-395          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-174         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-396          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-173         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-58    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-76        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-116       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-116   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-393         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-153         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-172        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-394        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-58           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Grid_Block-29       [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Max_Block-29        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   BatchNorm2D-125      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Conv2D-397         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
   BatchNorm2D-126     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-154         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-398        [[1, 1024, 14, 14]]    [1, 1024, 14, 14]       10,240     
   BatchNorm2D-127     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-155         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-399         [[1, 1024, 1, 1]]       [1, 64, 1, 1]         65,600     
      GELU-156           [[1, 64, 1, 1]]        [1, 64, 1, 1]            0       
     Conv2D-400          [[1, 64, 1, 1]]       [1, 1024, 1, 1]        66,560     
     Sigmoid-30         [[1, 1024, 1, 1]]      [1, 1024, 1, 1]           0       
        SE-30          [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-401        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
   BatchNorm2D-128      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Identity-77        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
      MBConv-30         [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-117       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-117   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-404          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-177         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-405          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-176         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-59    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-78        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-118       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-118   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-402         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-157         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-175        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-403        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-59           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   Window_Block-30      [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-119       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-119   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-408          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-180         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-409          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-179         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-60    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-79        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-120       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-120   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-406         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-158         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-178        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-407        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-60           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Grid_Block-30       [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Max_Block-30        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   BatchNorm2D-129      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Conv2D-410         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
   BatchNorm2D-130     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-159         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-411        [[1, 1024, 14, 14]]    [1, 1024, 14, 14]       10,240     
   BatchNorm2D-131     [[1, 1024, 14, 14]]    [1, 1024, 14, 14]        4,096     
      GELU-160         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-412         [[1, 1024, 1, 1]]       [1, 64, 1, 1]         65,600     
      GELU-161           [[1, 64, 1, 1]]        [1, 64, 1, 1]            0       
     Conv2D-413          [[1, 64, 1, 1]]       [1, 1024, 1, 1]        66,560     
     Sigmoid-31         [[1, 1024, 1, 1]]      [1, 1024, 1, 1]           0       
        SE-31          [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Conv2D-414        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
   BatchNorm2D-132      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Identity-80        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
      MBConv-31         [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-121       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-121   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-417          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-183         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-418          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-182         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-61    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-81        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-122       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-122   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-415         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-162         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-181        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-416        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-61           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   Window_Block-31      [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-123       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-123   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-421          [[4, 256, 7, 7]]       [4, 768, 7, 7]        197,376    
     Dropout-186         [[4, 8, 49, 49]]       [4, 8, 49, 49]           0       
     Conv2D-422          [[4, 256, 7, 7]]       [4, 256, 7, 7]        65,792     
     Dropout-185         [[4, 256, 7, 7]]       [4, 256, 7, 7]           0       
  Rel_Attention-62    [[1, 2, 2, 256, 7, 7]] [1, 2, 2, 256, 7, 7]      1,352     
     Identity-82        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    LayerNorm-124       [[1, 14, 14, 256]]     [1, 14, 14, 256]         512      
Channel_Layernorm-124   [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-419         [[1, 256, 14, 14]]    [1, 1024, 14, 14]       263,168    
      GELU-163         [[1, 1024, 14, 14]]    [1, 1024, 14, 14]          0       
     Dropout-184        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
     Conv2D-420        [[1, 1024, 14, 14]]     [1, 256, 14, 14]       262,400    
       Mlp-62           [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Grid_Block-31       [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
    Max_Block-31        [[1, 256, 14, 14]]     [1, 256, 14, 14]          0       
   BatchNorm2D-133      [[1, 256, 14, 14]]     [1, 256, 14, 14]        1,024     
     Conv2D-423         [[1, 256, 14, 14]]    [1, 2048, 14, 14]       526,336    
   BatchNorm2D-134     [[1, 2048, 14, 14]]    [1, 2048, 14, 14]        8,192     
      GELU-164         [[1, 2048, 14, 14]]    [1, 2048, 14, 14]          0       
     Conv2D-424        [[1, 2048, 14, 14]]     [1, 2048, 7, 7]        20,480     
   BatchNorm2D-135      [[1, 2048, 7, 7]]      [1, 2048, 7, 7]         8,192     
      GELU-165          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Conv2D-425         [[1, 2048, 1, 1]]       [1, 128, 1, 1]        262,272    
      GELU-166           [[1, 128, 1, 1]]       [1, 128, 1, 1]           0       
     Conv2D-426          [[1, 128, 1, 1]]      [1, 2048, 1, 1]        264,192    
     Sigmoid-32         [[1, 2048, 1, 1]]      [1, 2048, 1, 1]           0       
        SE-32           [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Conv2D-427         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
   BatchNorm2D-136       [[1, 512, 7, 7]]       [1, 512, 7, 7]         2,048     
    MaxPool2D-12        [[1, 256, 14, 14]]      [1, 256, 7, 7]           0       
     Conv2D-428          [[1, 256, 7, 7]]       [1, 512, 7, 7]        131,584    
      MBConv-32         [[1, 256, 14, 14]]      [1, 512, 7, 7]           0       
    LayerNorm-125        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-125    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-431          [[1, 512, 7, 7]]      [1, 1536, 7, 7]        787,968    
     Dropout-189        [[1, 16, 49, 49]]      [1, 16, 49, 49]           0       
     Conv2D-432          [[1, 512, 7, 7]]       [1, 512, 7, 7]        262,656    
     Dropout-188         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
  Rel_Attention-63    [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7]      2,704     
     Identity-83         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-126        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-126    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-429          [[1, 512, 7, 7]]      [1, 2048, 7, 7]       1,050,624   
      GELU-167          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Dropout-187         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-430         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
       Mlp-63            [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
   Window_Block-32       [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-127        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-127    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-435          [[1, 512, 7, 7]]      [1, 1536, 7, 7]        787,968    
     Dropout-192        [[1, 16, 49, 49]]      [1, 16, 49, 49]           0       
     Conv2D-436          [[1, 512, 7, 7]]       [1, 512, 7, 7]        262,656    
     Dropout-191         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
  Rel_Attention-64    [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7]      2,704     
     Identity-84         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-128        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-128    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-433          [[1, 512, 7, 7]]      [1, 2048, 7, 7]       1,050,624   
      GELU-168          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Dropout-190         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-434         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
       Mlp-64            [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    Grid_Block-32        [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    Max_Block-32        [[1, 256, 14, 14]]      [1, 512, 7, 7]           0       
   BatchNorm2D-137       [[1, 512, 7, 7]]       [1, 512, 7, 7]         2,048     
     Conv2D-437          [[1, 512, 7, 7]]      [1, 2048, 7, 7]       1,050,624   
   BatchNorm2D-138      [[1, 2048, 7, 7]]      [1, 2048, 7, 7]         8,192     
      GELU-169          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Conv2D-438         [[1, 2048, 7, 7]]      [1, 2048, 7, 7]        20,480     
   BatchNorm2D-139      [[1, 2048, 7, 7]]      [1, 2048, 7, 7]         8,192     
      GELU-170          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Conv2D-439         [[1, 2048, 1, 1]]       [1, 128, 1, 1]        262,272    
      GELU-171           [[1, 128, 1, 1]]       [1, 128, 1, 1]           0       
     Conv2D-440          [[1, 128, 1, 1]]      [1, 2048, 1, 1]        264,192    
     Sigmoid-33         [[1, 2048, 1, 1]]      [1, 2048, 1, 1]           0       
        SE-33           [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Conv2D-441         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
   BatchNorm2D-140       [[1, 512, 7, 7]]       [1, 512, 7, 7]         2,048     
     Identity-85         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
      MBConv-33          [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-129        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-129    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-444          [[1, 512, 7, 7]]      [1, 1536, 7, 7]        787,968    
     Dropout-195        [[1, 16, 49, 49]]      [1, 16, 49, 49]           0       
     Conv2D-445          [[1, 512, 7, 7]]       [1, 512, 7, 7]        262,656    
     Dropout-194         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
  Rel_Attention-65    [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7]      2,704     
     Identity-86         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-130        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-130    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-442          [[1, 512, 7, 7]]      [1, 2048, 7, 7]       1,050,624   
      GELU-172          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Dropout-193         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-443         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
       Mlp-65            [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
   Window_Block-33       [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-131        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-131    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-448          [[1, 512, 7, 7]]      [1, 1536, 7, 7]        787,968    
     Dropout-198        [[1, 16, 49, 49]]      [1, 16, 49, 49]           0       
     Conv2D-449          [[1, 512, 7, 7]]       [1, 512, 7, 7]        262,656    
     Dropout-197         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
  Rel_Attention-66    [[1, 1, 1, 512, 7, 7]] [1, 1, 1, 512, 7, 7]      2,704     
     Identity-87         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    LayerNorm-132        [[1, 7, 7, 512]]       [1, 7, 7, 512]         1,024     
Channel_Layernorm-132    [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-446          [[1, 512, 7, 7]]      [1, 2048, 7, 7]       1,050,624   
      GELU-173          [[1, 2048, 7, 7]]      [1, 2048, 7, 7]           0       
     Dropout-196         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-447         [[1, 2048, 7, 7]]       [1, 512, 7, 7]       1,049,088   
       Mlp-66            [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    Grid_Block-33        [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
    Max_Block-33         [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
     Conv2D-450          [[1, 512, 7, 7]]       [1, 512, 7, 7]        262,656    
   BatchNorm2D-141       [[1, 512, 7, 7]]       [1, 512, 7, 7]         2,048     
      GELU-174           [[1, 512, 7, 7]]       [1, 512, 7, 7]           0       
      Linear-3              [[1, 512]]            [1, 1000]           513,000    
      Softmax-3            [[1, 1000]]            [1, 1000]              0       
===================================================================================
Total params: 31,001,840
Trainable params: 30,893,552
Non-trainable params: 108,288
-----------------------------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 721.36
Params size (MB): 118.26
Estimated Total Size (MB): 840.20
-----------------------------------------------------------------------------------

{'total_params': 31001840, 'trainable_params': 30893552}

4. Summary

To address the poor scalability of conventional self-attention with respect to image size, the paper proposes an efficient and scalable multi-axis attention model consisting of two parts: local block attention and global grid attention. It also presents a new architecture that effectively interleaves the proposed attention with MBConv, yielding a simple hierarchical vision backbone, called MaxViT, built by repeating the basic building block across multiple stages. MaxViT accepts inputs of arbitrary resolution and realizes global-local spatial interaction with only linear complexity.

5. References

Explainer article: MaxViT: Multi-Axis Vision Transformer

Paper: MaxViT: Multi-Axis Vision Transformer