YOLO26改进 - 注意力机制 _ GC Block(GlobalContext Block)全局上下文块:三重变换捕获全局依赖,提升复杂场景鲁棒性

2 阅读5分钟

前言

本文介绍了全局上下文块(GC Block)及其在YOLO26中的集成应用。GC Block是GCNet的核心组件,结合了NLNet和SENet的优势,通过上下文建模、特征变换和特征融合三个模块,高效捕获特征图中的全局依赖关系,在提高模型性能的同时降低计算成本。我们将GC Block引入YOLO26,在检测头部分的不同尺度特征图上应用该模块。实验表明,改进后的YOLO26在目标检测任务中表现良好,展现了GC Block在实际应用中的价值。

文章目录: YOLO26改进大全:卷积层、轻量化、注意力机制、损失函数、Backbone、SPPF、Neck、检测头全方位优化汇总

专栏链接: YOLO26改进专栏

@[TOC]

介绍

image-20240718164344690

摘要

非局部网络(NLNet)通过聚合特定查询位置的全局上下文,为捕获长程依赖性提供了开创性方法。然而,经过严格的实证分析,我们发现非局部网络在同一图像的不同查询位置所建模的全局上下文呈现出高度相似性。基于这一发现,本文构建了一个基于查询无关公式的简化网络架构,该架构在保持NLNet准确性的同时显著降低了计算复杂度。进一步研究表明,该简化设计在结构上与挤压-激励网络(SENet)具有相似性,因此我们将二者统一到一个包含三个步骤的通用框架中,用于全局上下文建模。在此通用框架基础上,我们设计了一个更为优越的实例化模块,称为全局上下文(GC)块,该模块具有轻量化特性且能够有效建模全局上下文信息。得益于其轻量化设计,我们可以将该模块应用于骨干网络的多个层级,从而构建全局上下文网络(GCNet)。实验结果表明,GCNet在多种识别任务的主要基准测试中均表现出优于简化NLNet和SENet的性能。相关代码与配置文件已发布于:github.com/xvjiarui/GC…

文章链接

论文地址:论文地址

代码地址:代码地址

参考代码代码地址

基本原理

GC Block 详细介绍

全局上下文块(Global Context Block, GC Block)是Global Context Network(GCNet)的核心组件,设计用来高效捕获特征图中的全局依赖关系。它结合了非局部网络(NLNet)和挤压-激励网络(SENet)的优势,具体结构如下:

1. 上下文建模模块(Context Modeling Module)

这个模块的主要目的是聚合所有位置的特征形成全局上下文特征,具体步骤如下:

  • 输入特征图:假设输入特征图为 XRC×H×WX \in \mathbb{R}^{C \times H \times W},其中 CC 表示通道数,HHWW 分别表示特征图的高度和宽度。
  • 空间维度压缩:通过全局平均池化操作将空间维度压缩为单个向量,得到全局上下文特征 zRCz \in \mathbb{R}^Cz=1H×Wi=1Hj=1WXijz = \frac{1}{H \times W} \sum_{i=1}^H \sum_{j=1}^W X_{ij}
  • 注意力权重计算:使用一个全连接层或1x1卷积层,将全局上下文特征变换为注意力权重 WzRCW_z \in \mathbb{R}^{C}Wz=σ(W1z)W_z = \sigma(W_1 z) 其中,W1W_1 是可学习的权重矩阵,σ\sigma 是激活函数(通常为softmax)。
2. 特征变换模块(Feature Transform Module)

这个模块用于捕获特征图中通道之间的依赖关系:

  • 瓶颈变换:使用两层1x1卷积和ReLU激活函数,进行瓶颈变换以减少计算复杂度: y=W2(δ(W1X))y = W_2 (\delta(W_1 X)) 其中,W1W_1W2W_2 是可学习的权重矩阵,δ\delta 是ReLU激活函数。
3. 特征融合模块(Feature Fusion Module)

这个模块的目的是将全局上下文特征融合到每个查询位置的特征中:

  • 特征融合:通过加法操作,将全局上下文特征 zz 融合到每个位置的特征 XijX_{ij} 中,得到增强的特征图 YRC×H×WY \in \mathbb{R}^{C \times H \times W}Yij=Xij+zY_{ij} = X_{ij} + z

GC Block 流程总结

  1. 输入特征图XRC×H×WX \in \mathbb{R}^{C \times H \times W}
  2. 全局上下文建模:通过全局平均池化和全连接层计算全局上下文特征 zz
  3. 特征变换:使用瓶颈变换模块对特征图进行变换。
  4. 特征融合:将全局上下文特征 zz 融合到每个位置的特征 XijX_{ij} 中。

总结

GC Block 通过上下文建模、特征变换和特征融合三个模块,高效地捕获并利用图像中的全局上下文信息。这种设计不仅显著提高了模型在各种视觉识别任务中的性能,还保持了较低的计算成本和内存消耗。因此,GC Block 在实际应用中具有很高的实用价值。

核心代码

import torch
from mmcv.cnn import constant_init, kaiming_init
from torch import nn


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


class ContextBlock(nn.Module):

    def __init__(self,
                 inplanes,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.inplanes, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out

实验

脚本

import warnings
warnings.filterwarnings('ignore')
from ultralytics import YOLO
 
if __name__ == '__main__':
#     修改为自己的配置文件地址
    model = YOLO('./ultralytics/cfg/models/26/yolo26-GlobalContext.yaml')
#     修改为自己的数据集地址
    model.train(data='./ultralytics/cfg/datasets/coco8.yaml',
                cache=False,
                imgsz=640,
                epochs=10,
                single_cls=False,  # 是否是单类别检测
                batch=8,
                close_mosaic=10,
                workers=0,
                optimizer='MuSGD',  
                # optimizer='SGD',
                amp=False,
                project='runs/train',
                name='yolo26-GlobalContext',
                )
    
 

结果

image-20260127224956941