Pytorch 搭建 SearchTransfer
SearchTransfer源自论文Learning Texture Transformer Network for Image Super-Resolution
其主要思想类似于self-attention,但self-attention只是计算(B, HW, C)跟(B, C, HW)的batch-wise matrix multiplication,这篇文章不是将输入直接view成(B, HW, C),而是用卷积滑动窗口的方式展开成(B, 每个block的像素个数, num_blocks),然后做hard-attention。
本文记录了复现transformer module中遇到的一些用法
关键函数
torch.nn.functional.unfold
torch.nn.functional.fold
torch.expand
torch.gather
unfold展开方便做blocks间的attention,然后利用得到的相似图计算索引来提取ref_unfold中的信息,最后用fold还原
1. unfold
unfold
用 与nn.Conv2d
相同的滑动窗口 将输入划分为一个个blocks
import torch
import torch.nn.functional as F
x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape) # torch.Size([1, 3, 5, 5])
print(x_unfold.shape) # torch.Size([1, 27, 25])
x的形状为(batch,channel,H,W),可以看到x_unfold的shape为(batch,k x k x channel, number_blocks)
k是kernel_size,k x k x channel表示一个blocks中的像素个数
number_blocks是在给定kernel_size, padding,stride的情况下,可以滑出几个block
2. fold
fold的用法与unfold相反,是将一个个blocks还原回(batch,channel,H,W)的样子
k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape) # torch.Size([1, 108, 2500])
print(x_fold.shape) # torch.Size([1, 3, 10, 10])
print(x.mean()) # tensor(0.5012)
print(x_fold.mean()) # tensor(4.3924)
可以看到虽然形状是还原了,但x和x_fold的值域发生了变化,这是因为unfold的时候一个位置(1x1xchannel)可以出现在多个blocks中,因此fold的时候会求和这些重叠的位置,导致了数据不一致。因此得出x_fold后还需要除以重叠数才能得出原始数据范围。k=6,s=2时,一个位置会出现在3*3=9个blocks中(窗口上下左右滑动)。
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3.*3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean()) # tensor(0.4998)
print(x_fold.mean()) # tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum()) # tensor(189)
由sum()可以看出只有部分数据被还原了。还有一种准确计算divisor(如3. x 3.)的方法是用torch.ones作输入。
k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100
x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
x_fold = x_fold / ones_fold
print(x.mean()) # tensor(0.5001)
print(x_fold.mean()) # tensor(0.5001)
print((x == x_fold).sum()) # tensor(30000) 每个点都被还原了
3. expand
用法Tensor.expand(*size)
,在size中可以用-1代表保持不变的维度
x = torch.rand((1, 4)) # x = torch.rand(4) 也可以得到同样的结果
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))
print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand1)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914]])
print(x_expand2)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914],
# [0.1745, 0.2331, 0.5449, 0.1914]])
4. gather
用法torch.gather(input, dim, index, *, sparse_grad=False, out=None)
,效果如下
for i in range(dim0):
for j in range(dim1):
for k in range(dim2):
out[i, j, k] = input[index[i][j][k], j, k] # if dim == 0
out[i, j, k] = input[i, index[i][j][k], k] # if dim == 1
out[i, j, k] = input[i, j, index[i][j][k]] # if dim == 2
使用gather时首先用expand使index的size与input相等。
如index.shape == [B, blocks]
,用expand将index.shape变为[B,c x c x k,blocks],这样index[i, :, k]
是一个1D tensor,且每个元素值都等于expand之前的index[i, j]
如此,当 j 变化时index[i][j][k]
就不会变,故循环中的out[i, j, k] = input[i, j, index[i][j][k]]
就将 out中的第k个block 和 input中的第index[i][j][k]
个block 的每个点一一对应(遍历j)起来。
5. 搭建 Features Transfer
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transfer(nn.Module):
def __init__(self):
super(Transfer, self).__init__()
def bis(self, unfold, dim, index):
"""
block index select
args:
unfold: [B, k*k*C, Hr*Wr]
dim: 哪个维度是blocks
index: [B, H*W], value range is [0, Hr*Wr-1]
return: [B, k*k*C, H*W]
"""
views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))] # [B, 1, -1(H*W)]
expanse = list(unfold.size())
expanse[0] = -1
expanse[dim] = -1 # [-1, k*k*C, -1]
index = index.view(views).expand(expanse) # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
return torch.gather(unfold, dim, index) # return[i][j][k] = unfold[i][j][index[i][j][k]]
def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
"""
args:
lrsr_lv3: [B, C, H, W]
refsr_lv3: [B, C, Hr, Wr]
ref_lv1: [B, C, Hr*4, Wr*4]
ref_lv2: [B, C, Hr*2, Wr*2]
ref_lv3: [B, C, Hr, Wr]
"""
H, W = lrsr_lv3.size()[-2:]
lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1) # [B, k*k*C, H*W]
refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2) # [B, Hr*Wr, k*k*C]
lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)
R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold) # [B, Hr*Wr, H*W]
score, index = torch.max(R, dim=1) # [B, H*W]
ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1) # vgg19
ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2) # lv1->lv2, lv2->lv3有一次max pooling
ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4) # kernel_size没有按照真实的感受野计算
# 被除数,记录fold(unfold)时的overlap
divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)
T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index) # [B, k*k*C, H*W]
T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)
divisor_lv3 = self.bis(divisor_lv3, 2, index) # [B, k*k*C, H*W]
divisor_lv2 = self.bis(divisor_lv2, 2, index)
divisor_lv1 = self.bis(divisor_lv1, 2, index)
divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)
T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1
score = score.view(lrsr_lv3.size(0), 1, H, W) # [B, 1, H, W]
return score, T_lv1, T_lv2, T_lv3
**bis中gather的解释:**使用gather时首先用expand使index的size与input相等。
如index.shape == [B, blocks]
,用expand将index.shape
变为[B,c x c x k,blocks],这样index[i, :, k]
是一个1D tensor,且每个元素值都等于expand之前的index[i, j]
如此,当 j 变化时index[i][j][k]
就不会变,故循环中的out[i, j, k] = input[i, j, index[i][j][k]]
就将 out中的第k个block 和 input中的第index[i][j][k]
个block 的每个点一一对应(遍历j)起来。