SwinIR:使用Swin Transformer进行图像恢复

471 阅读5分钟

持续创作,加速成长!这是我参与「掘金日新计划 · 10 月更文挑战」的第20天,点击查看活动详情

引言

基于CNN的图像恢复方法

  • 专注于精细的架构设计,如残差学习和密集连接
  • 图像和卷积核之间的交互与内容无关,使用相同的卷积核来恢复不同的图像区域可能效果不太好
  • 卷积只能进行局部处理,而无法建模长距离依赖

Vision Transformer

方法:将输入图像分割成固定大小的patch(如48×48),并对每个patch进行独立处理

  • 修复后的图像可能会在每个patch周围引入边界伪影
  • 每个patch的边界像素会丢失信息
  • 可以通过patch重叠来缓解,但会带来额外计算负担

Swin Transformer:集成了CNN和Transformer的优势

  • CNN:局部注意力机制使其具有处理大尺寸图像的优势
  • Transformer:移位窗口(滑动窗口机制)建模长距离依赖

SwinIR

模块组成

  1. 浅层特征提取模块:卷积层
    浅层特征直接传输给重构模块,以保持低频信息(长距离跳跃连接)
  2. 深层特征提取模块
  • 组成:多个残差Swin Transformer块(RSTB)+ 一个卷积层
  • RSTB:利用多个Swin Transformer层进行局部注意和跨窗口交互
  • 卷积层:增强特性,并使用残差连接为特性聚合提供快捷方式
  1. 高质量图像重建模块:融合浅特征和深特征

优点

  • 图像内容和注意力权重之间基于内容交互,可以解释为空间变化的卷积(即卷积核依据图像内容不同而变化)
  • 移动窗口机制实现长距离依赖建模
  • 参数更少,且性能更好

相关工作

Vision Transformer

  • 应用:探索不同区域之间的全局交互来学习关注重要的图像区域
  • 应用于图像恢复
    IPT:可以解决很多恢复问题,但依赖于大量参数、数据集和多任务学习
    VSR-Transformer:采用自注意力机制,可以更好融合视频SR的特征,但特征提取仍依靠CNN完成
    两者均为patch级注意力机制,可能不利于图像恢复

方法

网络架构

image.png

对所有恢复任务使用相同的特征提取模块,不同的重建模块

浅层特征提取

3*3 卷积

  • 卷积擅长前期视觉处理,优化更稳定,结果更好
  • 将输入图像空间映射到高维特征空间的简单方法

深层特征提取

k个RSTB和1个3*3卷积

  • 卷积:将其归纳偏置(空间局部性)引入到基于transformer的网络中,为后期浅、深特征的聚合奠基
  • 归纳偏置:从现实生活中观察到的现象中归纳出一定的规则 ,然后对模型做一定的约束
  • 常见网络的归纳偏置:
  1. 深度神经网络:层次化处理信息更有效
  2. 卷积神经网络:信息具有空间局部性 ,可用滑动卷积共享权重的方式减少参数
  3. 循环神经网络:考虑时序信息,强调顺序重要性
  4. 图网络:中心节点与邻居节点的相似性会更好引导信息流动

图像重建(以SR为例)

上采样:亚像素卷积层

  • 亚像素:将多通道feature上的单个像素组合成一个feature上的单位

image.png

  • 长距离跳跃连接:将低频信息直接传输给重构模块,有助于深度特征提取模块专注于高频信息,从而稳定训练

损失函数

  • 经典和轻量级的图像:原始L1像素损失
  • 真实世界图像:L1像素损失+GAN损失+感知损失

RSTB(残差Swin Transformer块)

image.png

L个STL(Swin Transformer层)+ 一个卷积层 + 残差连接

  • 卷积:增强平移等变性(图像中的对象平移后仍能被准确识别);Transformer提供基于空间变化的卷积,增加空间不变的普通卷积与其互补
  • 残差连接:从不同块到重建模块的基于身份的短连接,提供不同级别的特征聚合

STL(Swin Transformer层)

image.png

基于原始Transformer层的标准多头自注意力(MSA),差异在于局部注意力和移位窗口机制

MSA(multi-head self-attention)

  1. 分割:将输入分割成不重叠的M*M局部窗口
  2. 计算:计算每个窗口的标准自注意力

image.png

X:局部窗口特征,PQ、PK、PV:跨不同窗口共享的投影矩阵

image.png

B:可学习的相对位置编码
实现:多次并行执行Attention函数,将计算结果连接,形成MSA

MLP( multi-layer perceptron)

多层感知器:两个带GELU非线性激活的全连接层,用于进一步特征转换

LN(LayerNorm)

  • LN层标准化:取同一个样本的不同通道做归一化
  • BN批标准化:取不同样本同一通道的特征做归一化

交替使用常规窗口和滑动窗口

  • 问题:不同自注意力层的分区固定导致局部窗口之间没有连接交互

image.png

  • 滑动窗口分区:相邻自注意力层间分割窗口移位。第一个自注意力层中平分特征图,第二个自注意力层从(⌊ M/2 ⌋⌊ M/2 ⌋)处有规律地取代前一层的窗口
  • 优点
  1. 计算复杂度与输入图片大小线性相关:解决了计算复杂度与输入大小呈二次增长的问题
    标准Transformer:一个patch作为一个token,对于每一个token都需要计算其与全部其他token的关系,复杂度与token数量呈二次相关
    Swin Transformer:输入图像划分为窗口,只计算窗口内部各patch间的自注意力
  2. 启用跨窗口交互连接
  • 缺点:部分窗口大小小于M×M,需要用padding补齐,增加了计算量