Resolution-robust Large Mask Inpainting with Fourier Convolutions这篇论文的代码:lama\
其模型的的生成器部分是否每一层都有用到快速傅里叶卷积(FFC)?
答:生成器部分每一层都有用到FFC。
- 这篇论文采用FFC实现感受野覆盖整张图片,而不是3x3或7x7的小块滤波器滑动提取特征。
- 这篇文章还提出使用基于具有高感受野的感知损失。损失函数促进了全局框架和形状的一致性。
Struceflow和PEN-NET的基础架构都是U-Net。
PEN-Net的注意力模块总结
文中介绍:首先,一旦紧凑的潜在特征被编码,金字塔上下文编码器通过提出的注意力转移网络(ATN)将填充区域从高层特征图填充到低层特征图(具有更丰富的细节),进一步提高了编码效率。
ATN的实现过程
首先定义斑块提取函数extract_patches(x, kernel=3, stride=1):
# extract patches
def extract_patches(x, kernel=3, stride=1):
if kernel != 1:
x = nn.ZeroPad2d(1)(x)
x = x.permute(0, 2, 3, 1)
# 手动切割图片块,只卷不积
all_patches = x.unfold(1, kernel, stride).unfold(2, kernel, stride)
return all_patches
该函数用于将高层次特征图和低层次特征图切割成4X4大小的块。
问题:既然切割成4X4的大小,为什么不直接raw_w = extract_patches(x1, kernel=4)呢? 答:切割的通道数要和yi(由高层特征卷积获得)一致,因此x1要先按3x3展开,而x2一直是3x3,图上标错了。
# extract patches from low-level features maps x1 with stride and rate
kernel = 2 * self.rate
raw_w = extract_patches(x1, kernel=self.ksize, stride=self.rate*self.stride)
raw_w = raw_w.contiguous().view(x1s[0], -1, x1s[1], kernel, kernel) # B*HW*C*K*K
# split tensors by batch dimension; tuple is returned
raw_w_groups = torch.split(raw_w, 1, dim=0)
接下来就是生成attention score:
for xi, wi, raw_wi, mi in zip(f_groups, w_groups, raw_w_groups, mm_groups):
# matching based on cosine-similarity
wi = wi[0]
escape_NaN = torch.FloatTensor([1e-4])
if torch.cuda.is_available():
escape_NaN = escape_NaN.cuda()
# normalize
wi_normed = wi / torch.max(torch.sqrt((wi*wi).sum([1,2,3],keepdim=True)), escape_NaN)
print("wi_normed:",wi_normed.size())
print("xi:", xi.size())
yi = F.conv2d(xi, wi_normed, stride=1, padding=padding)
print("yi_1:", yi.size())
yi = yi.contiguous().view(1, x2s[2]//self.stride*x2s[3]//self.stride, x2s[2], x2s[3])
# apply softmax to obtain
print("mi:", mi.size())
yi = yi * mi
# 为什么要乘以scale=10?
yi = F.softmax(yi*scale, dim=1) # 归一化,所有数值之和为1
# mi为什么要乘两次?可能使担心归一化后数值变化。
yi = yi * mi
# 压缩yi的元素
yi = yi.clamp(min=1e-8)
这边xi存储高层特征,wi存储高层特征的patch(3x3),raw_wi存储低层特征的patch(4x4),mi存储块状掩膜。yi就是attention score。最后就是反卷积yi,卷积核为底层特征patch。
# attending
wi_center = raw_wi[0]
print("yi:", yi.size())
print("wi_center:", wi_center.size())
# 为甚么反卷积后要每个值除以4,x1切成块作为反卷积的滤波器
# 反卷积就是将低分辨率的图像通过0填充成很大的图像再卷积成高分辨率的图像
yi = F.conv_transpose2d(yi, wi_center, stride=self.rate, padding=1) / 4.
y.append(yi)
最后用齿状的扩张卷积进行细化处理(其原理参考:https://blog.csdn.net/chaipp0607/article/details/99671483:
y = torch.cat(y, dim=0)
print('y:', y.size())
# y.contiguous().view(x1s)
y.contiguous().view(y.size())
# adjust after filling
if self.fuse:
tmp = []
for i in range(self.groups):
tmp.append(self.__getattr__('conv{}'.format(str(i).zfill(2)))(y))
y = torch.cat(tmp, dim=1)
return y