Spatially Adaptive Residual Networks for Efficient Image and Video Deblurring

453 阅读7分钟

paper

本文是印度马德拉斯理工学院的研究员提出的一种基于空间自适应残差网络的图像/视频去模糊方法。

严重模糊图像复原要求网络具有极大感受野,现有网络往往采用加深网络层数、加大卷积核尺寸或者多尺度方式提升感受野,然而这些方法会早知模型大小的提升以及推理耗时提升。作者提出一种组合形变卷积与自注意力机制的去模糊网络,进一步,集成时序递归模块可以将其扩展到视频去模糊。该网络可以模拟空间可变模糊移除而无需多尺度与大卷积核。最后作者通过实验定性与定量进行分析:在速度、精度以及模型大小方面均取得了SOTA性能。

Abstract

​ 针对已有去模糊方法存在的两个局限性:(1) 空间不变卷积核,对于动态场景去模糊而言并非最优方案,严重限制了去模糊精度;(2) 通过网络深度与卷积核尺寸提升扩大感受野,这会导致模型变大、推理耗时增加。

​ 为此,基于形变卷积与自注意力机制,作者提出一种高效果的端到端的去模糊框架。它与其他SOTA方法的性能对比见下图。

​ 该方法的优点包含以下几点:

  • 全卷积且参数高效,仅需一次性前向过程;
  • 可轻易集成其他架构与损失函数;
  • 网络估计的变换是动态的,因而可以自适应处理测试图像。

Method

​ 上图给出了作者所提出的SARN网络架构示意图,其中编码子网络将输入图像逐渐变换为分辨率更小、通道更多的特征图,在此基础上继续执行空间注意力模块与形变残差模块,最后送入到解码模块中,通过一系列的残差模块与反卷积对其进行重建。注:上图中n=32。

Deformable Residual Module

​ 传统的CNN在固定的网格上进行采样,这限制了其模拟未知几何变换的能力。STN将空间学习引入到CNN中,然而这种变换比较耗时且为全局图像变换,并不适合于局部图像几何变换。作者采用形变卷积,它以一种有效的方法学习局部几何变换。形变卷积首先学习稠密偏移图进行特征重采样,然后再进行卷积操作,该过程见上图中的从Input FeatureOutput Feature的过程。作者在形变卷积基础上引入参考模块,称之为形变残差模块。更多关于形变卷积的介绍与分析建议参考原文Deformable Convolutional Networks

Self-Attention Module

​ 近期的去模糊方法着重于多尺度处理,这种处理方式可以获取不同尺度的运动模糊,提升网络的感受野。尽管这种“自粗而精”的处理策略可以处理不同程度的模糊,但是它无法从全局角度利用模糊区域之间的相关性,而这对于复原任务也很重要。为此,作者提出采用:在不同空间分辨率利用注意力机制学习非局部关联性。

​ 用于模拟长范围依赖关系的注意力机制已在多个领域(跨语言与视觉应用)取得了成功。作者采用非局部注意力进行不同场景区域之间的关联性学习并用于提升图像复原质量。

​ 上图给出了作者所提出的SAM模块示意图。它有如下两点优势:

  • 它克服了感受野有限的局限性;
  • 它隐含的提供了一种可以传播相对信息的通路。

上述优势使得它适合于处理去模糊,这是因为:因模糊导致的场景-边缘之间往往是相关的。

​ 以上图为例,给定输入特征A\in R^{C \times H \times W},首先,将其送入两个1\times1卷积得到两个新的特征B和C,其中\{B, C\} \in R^{\hat{C} \times H \times W};然后,将其进行reshape为R^{\hat{C} \times N};其次,对B和C进行矩阵乘操作并执行softmax得到空间注意力特征S \in R^{N \times N}(s_{ji}可以度量i位置与j位置的影响关系),计算方式如下公式所示。最后,将特征A经由另一个1\times1卷积得到特征D\in R^{C\times H \times W},并reshape为R^{C \times N},并将其S进行矩阵乘操作得到增强版特征,将其与特征A相加得到最终的特征E\in R^{C\times H \times W}

s_{ji} = \frac{\mathcal{exp}(B_i \cdot C_j)}{\sum_{i=1}^N \mathcal{exp}(B_i \cdot C_j)} \notag

​ 经由上述操作得到的特征E包含所有位置特征的加权组合以及原始特征。因此它具有全局上下文信息,并按照空间注意力进行上下文信息选择性集成,促使相似特征增强,不相关特征削弱。

​ 作者还发现:将SAM至于DRM之前可以取得更好的性能。猜测原因为:早期的特征增强有助于提升网络的非局部性。

Video Deblurring

​ 图像去模糊一种很自然的扩展是视频去模糊,作者采用LSTM进行前后帧特征集成,该过程可以描述为:

\begin{split}
f^i &= Net_E(B^i, I^{i-1})  \\
h^i, g^i &= ConvLSTM(h^{i-1}, f^i; \theta_{LSTM})  \\
I^i &= Net_D(g^i; \theta_D)
\end{split}

在视频去模糊中,它以5帧作为输入,输出中间帧的去模糊效果图。

Experiments

​ 在训练过程中,相关参数配置如下:

  • 对于图像去模糊任务,训练数据为GoPro,优化器为Adam,学习率为0.0001,BatchSize=4,训练迭代次数为1百万.
  • 对于视频去模糊任务,优化器Adam,学习率0.0001,BatchSize=4,迭代次数3百万。

​ 下图给出在GoPro数据集上相关去模糊方法的性能与视觉效果对比。更多实验结果与分析建议参考原文,这里不再赘述。

​ 下面给出了在视频去模糊任务上的性能与视觉效果对比。更多实验结果与分析建议参考原文,这里不再赘述。

Concolusion

​ 作者结合形变卷积、自注意力机制提出一种有效的图像/视频去模糊方法。其中形变卷积残差模块可以解决局部模糊的局部信息偏移问题;而自注意力机制则可以对不同模糊区域建立关联性,从而提升特征性能。自注意力机制与形变卷积均可提升网络的感受野,同时具有高效性。最后作者通过实验验证了所提方法的SOTA性能。

参考代码

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmdet.ops import DeformConvPack

# GPU 
os.environ["CUDA_VISIBLE_DEVICES"]="0"

# DeformConv copy from mmdetection.
class DeformResModule(nn.Module):
    def __init__(self, inc, ksize):
        super(DeformResModule, self).__init__()
        pad = (ksize-1)//2
        self.dconv = DeformConvPack(inc,inc,ksize,1,padding=pad)
    def forward(self, x):
        res = self.dconv(x)
        return res + x

class ResBlock(nn.Module):
    def __init__(self, inc, ksize):
        super(ResBlock, self).__init__()
        padding = (ksize-1)//2
        self.conv1 = nn.Conv2d(inc, inc, ksize, 1, padding)
        self.conv2 = nn.Conv2d(inc, inc, ksize, 1, padding)
    def forward(self, x):
        res = self.conv2(F.relu(self.conv1(x)))
        return res + x
        
class SAM(nn.Module):
    def __init__(self, inc):
        super(SAM, self).__init__()
        self.convb = nn.Conv2d(inc, inc, 1)
        self.convc = nn.Conv2d(inc, inc, 1)
        self.convd = nn.Conv2d(inc, inc, 1)
        
    def forward(self, x):
        N, C, H, W = x.size()
        featB = self.convb(x)                         #N,C,H,W
        featC = self.convc(x)                         #N,C,H,W
        featD = self.convd(x)                         #N,C,H,W
        
        featB = featB.reshape(N, C, -1)               #N,C, HW
        featC = featC.reshape(N, C, -1)               #N,C, HW
        featC = featC.permute(0, 2, 1)                #N,HW,C
        
        featD = featD.reshape(N, C, -1)               #N,C, HW
        featD = featD.permute(0, 2, 1)                #N,HW,C
        
        featBC = torch.matmul(featC, featB)           #N,HW,HW
        featBC = featBC.softmax(-1)                   #N,HW,HW
        
        fusion = torch.matmul(featBC, featD)          #N,HW,C
        fusion = fusion.permute(0, 2, 1).contiguous() #N,C, HW
        fusion = fusion.reshape(N, C, H, W)           #N,C,H,W
        
        return x + fusion

class Net(nn.Module):
    def __init__(self, inc, outc, midc):
        super(Net, self).__init__()
        mid2 = midc*2
        mid4 = midc*4
        self.ecode1 = nn.Sequential(nn.Conv2d(inc,midc,3,1,1),
                                    nn.ReLU(),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3))
        self.ecode2 = nn.Sequential(nn.Conv2d(midc,mid2,3,2,1),
                                    nn.ReLU(),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.ecode3 = nn.Sequential(nn.Conv2d(mid2,mid4,3,2,1),
                                    nn.ReLU(),
                                    SAM(mid4),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3),
                                    DeformResModule(mid4, 3))
        
        self.dcode2 = nn.Sequential(ResBlock(mid2, 3),
                                    ResBlock(mid2, 3),
                                    ResBlock(mid2, 3))
        self.dcode1 = nn.Sequential(ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    ResBlock(midc, 3),
                                    nn.Conv2d(midc, 3, 3, 1, 1))
                
        self.upsample1 = nn.ConvTranspose2d(mid4, mid2, 4, 2, 1)
        self.upsample2 = nn.ConvTranspose2d(mid2, midc, 4, 2, 1)
        
        self.feat1 = nn.Conv2d(midc, midc, 3, 1, 1)
        self.feat2 = nn.Conv2d(midc*2, midc*2, 3, 1, 1)
        
    def forward(self, x):
        encoder1 = self.ecode1(x)
        encoder2 = self.ecode2(encoder1)
        encoder3 = self.ecode3(encoder2)
        decoder3 = self.upsample1(encoder3)
        decoder2 = self.dcode2(decoder3 + self.feat2(encoder2))
        decoder1 = self.upsample2(decoder2)
        output   = self.dcode1(decoder1 + self.feat1(encoder1))
        
        return output
             
        
def main():
    model = Net(3, 3, 32).cuda().eval()
    
    inputs = torch.randn(4, 3, 128, 128).cuda()
    with torch.no_grad():
        output = model(inputs)
    print(output.size())
    
    
if __name__ == "__main__":
    main()