一种通用的图像分割模型(论文复现)

236 阅读6分钟

一种通用的图像分割模型(论文复现)

本文所涉及所有资源均在传知代码平台可获取

概述

图像分割研究像素分组问题,对像素进行分组的不同语义产生了不同类型的分割任务,例如全景分割、实例分割或语义分割。虽然这些任务中只有语义不同,但目前的研究侧重于为每个任务设计专门的架构。Mask2Former是一个能够处理图像多种分割任务(全景分割、实例分割、语义分割)的新框架。它的关键组件是掩码注意力机制,通过约束预测掩码区域内的交叉注意来提取局部特征。Mask2Former将研究工作减少了至少三倍,且在四个流行的数据集上大大优于最好的专业架构

模型结构

在这里插入图片描述

Mask2Former的结构和MaskFormer类似,由一个主干网络,一个像素解码器,一个Transformer解码器组成。Mask2Former提出了一个新的Transformer解码器,该解码器使用掩码注意力机制代替传统的交叉注意力机制。为了处理尺寸较小的物体,Mask2Former每次将来自于像素解码器的多尺度特征的一个尺度馈送到Transformer解码器层。除此之外Mask2Former交换了自注意力和交叉注意力(掩码注意力)的顺序,使查询特征可学习,并去除dropout层结构式计算更有效

掩码分类准备

掩码分类架构通过预测N个二进制掩码,以及N个相应的类别标签,将像素分成N个块。掩码分类通过将不同的语义(类别或实例)分配给不同的片段来解决任何分割任务。然而,为每个片段找到好的语义表示具有挑战性,例如Mask RCNN使用边界框作为表示,这限制了它在语义分割中的应用。受DETR的启发,图像中的每个片段可以表示为C维特征向量(对象查询),由Transformer解码器处理,该解码器使用集合预测目标进行训练

一个简单的元架构由三个组件组成:

  • 一个主干网络:从图像中提取低分辨率特征。
  • 一个像素解码器:从主干的输出逐步对低分辨率特征进行上采样,以生成高分辨率逐像素嵌入。
  • 一个Transformer解码器,利用对象查询和图像特征进行交互,以丰富对象查询中包含的语义信息。
  • 二值掩码预测:从逐像素嵌入的对象查询解码出最终的二进制掩码预测
  • 带有掩码机制的Transformer解码器

Transformer解码器的关键组件包括一个掩码注意算子,它通过将每个查询的交叉注意力限制在其预测掩码的前景区域,而不是关注完整的特征图来提取局部特征。为了处理小物体,Mask2Former提出了一种有效地多尺度策略来利用高分辨率特征。它以循环的方式将像素解码器特征金字塔的连续特征映射馈送到连续的Transformer解码器层。Mask2Former的改进如下:

  • 掩码注意力机制

最近的研究表明,基于Transformer的模型收敛缓慢是由于交叉注意力层中关注全局上下文信息,因此交叉注意力需要许多训练轮才能学会关注局部对象区域。Mask2Former假设局部特征足以更新查询特征,且全局上下文信息可以通过自我注意力来收集。为此,Mask2Former提出了掩码注意,这是一种交叉注意的变体,它只关注每个查询预测掩码的前景区域。Mask2Former的掩码注意力机制如下计算

在这里插入图片描述

高分辨率特征   高分辨率提高了模型的性能,特别是小物体的准确率,但是这对计算要求很高。因此,Mask2Former提出了一种有效的多尺度策略,在控制计算量增加的同时引入了高分辨率特征。它不总是使用高分辨率特征图,而是利用由低分辨率到高分辨率特征组成的特征金字塔,一次将多尺度特征的一个分辨率馈送到一个Transformer解码器层。   具体来说,Mask2Former使用像素解码器产生的特征金字塔,分辨率为原始图像的1/32,1/16,1/8。对于每个分辨率,添加一个正弦位置嵌入epos∈RHlWl×Cepo**s​∈RHlW**l​×C,以及一个可学习的尺度级嵌入elvl∈R1×Celv**l​∈RC。在Transformer解码器层依次使用从最低分辨率到最高分辨率的那些图像,并且重复这个 3 层 Transformer 解码器 L 次。最终的Transformer解码器为3L层。

优化改进   一个标准的Transformer解码器层由三个模块(自我注意模块,交叉注意和前馈网络)组成,按照顺序处理查询特征。查询特征x0x0​在送入Transformer解码器之前被初始化为零,并与可学习的位置嵌入相关联。dropout应用于残差连接和注意力图。为了优化Transformer的解码器设计,Mask2Former进行了以下三个改进。

  • Mask2Former切换了自注意力和交叉注意力的顺序。第一层自注意层的查询特征与图像无关,不具有来自图像的信息,因此应用自注意不太可能丰富信息。
  • Mask2Former使查询特征x0x0也是可学习的(仍然保留可学习的查询位置嵌入),并且可学习的查询特征在被用于Transformer解码器中预测掩码M0M0之前直接被监督。
  • Mask2Former发现dropout是不必要的,而且通常会降低性能,完全消除了解码器的dropout。

实验

在这里插入图片描述

演示效果

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

核心逻辑

像素解码器

    def forward_features(self, features):
        srcs = []
        pos = []
        # Reverse feature maps into top-down order (from low to high resolution)
        # 将其通道维数全部转变为256
        for idx, f in enumerate(self.transformer_in_features[::-1]):
            x = features[f].float()  # deformable detr does not support half precision
            srcs.append(self.input_proj[idx](x))
            pos.append(self.pe_layer(x)) # 存放有关像素分辨率的sine位置编码不可学习
        # y: [1,43008,256] 将不同大小的特征图进行拼接
        y, spatial_shapes, level_start_index = self.transformer(srcs, pos)
        bs = y.shape[0]

        split_size_or_sections = [None] * self.transformer_num_feature_levels
        
        for i in range(self.transformer_num_feature_levels):
            if i < self.transformer_num_feature_levels - 1:
                split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i]
            else:
                split_size_or_sections[i] = y.shape[1] - level_start_index[i]
        y = torch.split(y, split_size_or_sections, dim=1)

        out = []
        multi_scale_features = []
        num_cur_levels = 0
        for i, z in enumerate(y):
            # z:[1,2048,256]->[1,256,2048]->[1,256,32,64]
            out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1]))
        # append `out` with extra FPN levels
        # 

文章代码资源点击附件获取