注意力机制——SENet原理详解及源码解析

1,785 阅读3分钟

本文正在参加「金石计划 . 瓜分6万现金大奖」

🍊作者简介:秃头小苏,致力于用最通俗的语言描述问题

🍊专栏推荐:深度学习网络原理与实战

🍊近期目标:写好专栏的每一篇文章

🍊支持小苏:点赞👍🏼、收藏⭐、留言📩

 

SENet原理详解

​  先来简单说说我们为什么需要使用注意力机制,这是因为我们希望网络可以专注于一些更加重要的东西,这对物体的识别定位都大有益处。enmmm,是不是够简单呢。🍦🍦🍦如果你是第一次学习注意力机制,我觉得你会充满疑惑,怎么让网络注意到一些更加重要的东西呢?那么带着疑问,和我一起来看看SENet的原理,等我介绍完后看你能否理解喔。🌿🌿🌿

​  话不多说,我们直接来看SENet的关键结构,如下图所示:

image-20221204203939637

​  我们来介绍一下上图的网络,首先是输入X,其维度为H×W×C{\rm{H' \times W' \times C'}},经过一系列卷积等维度变化操作后得到特征图U,其维度H×W×CH×W×C【注:其实从特征图U开始向后才是真正的SENet的结构,这一步转换只是一些特征图维度变化】 当我们得到U后,会先将U经过全局平均池化的操作,即将U的维度由H×W×CH×W×C变成1×1×C1×1×C,此步骤对应着上图中的Fsq(){F_{sq}}( \cdot )。接着会执行步骤Fex(,W){F_{ex}}( \cdot ,W),此步骤包含两个全连接层已经两个激活函数,为方便大家理解,做此过程的图如下:

image-20221204210751044

​  从上图我们可以看出,在第一次全连接层后我们使用Relu激活函数,此时得到的输出维度为1×1×C1 \times 1 \times {\rm{C''}},通常情况下C{\rm{C''}}设置为CC14\frac{1}{4}。第二个全连接层后使用Sigmoid函数,将每层数值归一化到0-1之间,以此表示每个通道的权重,第二个全连接的输出也为1×1×C1×1×C。得到了最后1×1×C1×1×C的输出后,我们将U和刚刚得到的1×1×C1×1×C输出相乘,得到最终的特征图X~{{\rm{\tilde X}}},最终特征图X~{{\rm{\tilde X}}}的维度和U一致,为H×W×CH×W×C

​  介绍到这里,大家是否明白了呢。如果你还没明白的话,再来看下图吧!!!首先下图左上角表示为两个通道的特征图,经平均池化后得到左下角的图;再次经过两次全连接层和激活函数后,转化成了右下角的图,最后用右下角的0.5、0.6分别乘原始的特质图,则得到最终的右上角的图。可以发现经过SENet特征图输入前后尺寸没有变化,其值发生变化。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-XhC6a19q-1645335379587)(C:\Users\WSJ\AppData\Roaming\Typora\typora-user-images\image-20220219210356512.png)]

SENet代码详解

​ 理解了上文所述的SENet原理,那么编写SENet的代码就非常简单了,如下:

def SENet(input):
    #全局平均池化
    x = nn.AdaptiveAvgPool2d((1,1))(input)
    x = x.view(1, -1)
    #第一个全连接层
    x = nn.Linear(2, 1)(x)
    x = nn.functional.relu(x)
    #第二个全连接层
    x = nn.Linear(1, 2)(x)
    x = nn.functional.sigmoid(x)

    return x


if __name__ == '__main__':
    input = torch.ones(1, 2 ,2 ,2)
    output = SENet(input)
    # 将SENet的输出维度进行变化,以便后面的乘机操作
    output = output.view(input.shape[0], input.shape[1],1, 1)
    SE_output = input*output
    
    print(input)
    print(input.shape)
    print(output)
    print(output.shape)
    print(SE_output)

我们可以来看一下上述代码的输出,如下:

input:

image-20221205113701020

output:

image-20221205113732249

SE_output:

image-20221205113755663

  你可以结合理论部分,再对照这些输出看看是否一致喔。【注意:大家需要注意在最后一步相乘操作前需要先View一下输出output的尺寸,不然乘的结果不一样哦,这涉及到一些pytorch乘法的操作,这部分我也调试了很久,大家可以动手试试看。】

   

如若文章对你有所帮助,那就🛴🛴🛴

         一键三连 (1).gif