Building SearchTransfer in PyTorch



SearchTransfer comes from the paper Learning Texture Transformer Network for Image Super-Resolution.

[paper] [code]

Its main idea is similar to self-attention, but self-attention only computes a batch-wise matrix multiplication between (B, HW, C) and (B, C, HW). Instead of directly viewing the input as (B, HW, C), this paper unfolds it with a convolution-style sliding window into (B, pixels per block, num_blocks) and then applies hard attention.
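A minimal sketch of the difference between the two flattenings (the 3x3 kernel, padding and sizes here are just illustrative choices, not the paper's settings):

import torch
import torch.nn.functional as F

B, C, H, W = 1, 64, 8, 8
x = torch.rand(B, C, H, W)

# self-attention style: each "token" is a single pixel
tokens = x.view(B, C, H * W).transpose(1, 2)     # [B, H*W, C]

# SearchTransfer style: each "token" is a k x k block of pixels
blocks = F.unfold(x, kernel_size=3, padding=1)   # [B, k*k*C, num_blocks] = [1, 576, 64]

print(tokens.shape, blocks.shape)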

This post records some of the operations I ran into while reproducing the transformer module.

Key functions

  • torch.nn.functional.unfold
  • torch.nn.functional.fold
  • Tensor.expand
  • torch.gather

unfold makes it convenient to compute attention between blocks; the resulting similarity map gives the indices used to pull information out of ref_unfold, and finally fold restores the spatial layout.

1. unfold

unfold splits the input into blocks using the same sliding window as nn.Conv2d.

import torch
import torch.nn.functional as F

x = torch.rand((1, 3, 5, 5))
x_unfold = F.unfold(x, kernel_size=3, padding=1, stride=1)
print(x.shape)	# torch.Size([1, 3, 5, 5])
print(x_unfold.shape)	# torch.Size([1, 27, 25])

x has shape (batch, channel, H, W); x_unfold has shape (batch, k x k x channel, num_blocks).

k is the kernel_size, so k x k x channel is the number of values in one block.

num_blocks is how many blocks the window can slide over given the kernel_size, padding and stride.
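num_blocks follows the same output-size formula as a convolution; a quick check of the example above:

import math

k, p, s = 3, 1, 1
H = W = 5
blocks_h = math.floor((H + 2 * p - k) / s) + 1   # sliding-window output height
blocks_w = math.floor((W + 2 * p - k) / s) + 1
print(blocks_h * blocks_w)   # 25, matching x_unfold.shape[-1] above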

2. fold

fold is the inverse of unfold: it reassembles the blocks back into a (batch, channel, H, W) tensor.

k = 6
s = 2
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)
print(x_unfold.shape)	# torch.Size([1, 108, 2500])
print(x_fold.shape)		# torch.Size([1, 3, 100, 100])
print(x.mean())			# tensor(0.5012)
print(x_fold.mean())	# tensor(4.3924)

The shape is restored, but the values of x and x_fold no longer match. During unfold a single position (1 x 1 x channel) can appear in several blocks, and fold sums those overlapping contributions, which makes the data inconsistent. To get back to the original range, x_fold still has to be divided by the overlap count. With k=6 and s=2, each position appears in 3*3=9 blocks (the window passes over it 3 times horizontally and 3 times vertically).

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p) / (3.*3.)
print(x_unfold.shape)
print(x_fold.shape)
print(x.mean())			# tensor(0.4998)
print(x_fold.mean())	# tensor(0.4866)
print((x[:, :, 30:40, 30:40] == x_fold[:, :, 30:40, 30:40]).sum()) # tensor(189)

The sum() shows that only part of the data was restored exactly. A more precise way to obtain the divisor (the 3.*3. above) is to run the same unfold/fold on torch.ones.

k = 5
s = 3
p = (k - s) // 2
H, W = 100, 100

x = torch.rand((1, 3, H, W))
x_unfold = F.unfold(x, kernel_size=k, stride=s, padding=p)
x_fold = F.fold(x_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

ones = torch.ones((1, 3, H, W))
ones_unfold = F.unfold(ones, kernel_size=k, stride=s, padding=p)
ones_fold = F.fold(ones_unfold, output_size=(H, W), kernel_size=k, stride=s, padding=p)

x_fold = x_fold / ones_fold
print(x.mean())			# tensor(0.5001)
print(x_fold.mean())	# tensor(0.5001)
print((x == x_fold).sum())	# tensor(30000), every point is restored

3. expand

Usage: Tensor.expand(*sizes); a -1 in sizes means that dimension stays unchanged.

x = torch.rand((1, 4))	# x = torch.rand(4) gives the same result
x_expand1 = x.expand((3, 4))
x_expand2 = x.expand((3, -1))

print(x)
# tensor([[0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand1)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])

print(x_expand2)
#tensor([[0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914],
#        [0.1745, 0.2331, 0.5449, 0.1914]])
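In the transfer module below, expand is used to broadcast a 2D index tensor to 3D before gather; a minimal sketch of that pattern (B, N and the 27 channels are toy numbers):

import torch

B, N = 2, 25                                        # toy batch size and number of blocks
index = torch.randint(0, N, (B, N))                 # [B, N]
index3d = index.view(B, 1, N).expand(-1, 27, -1)    # [B, 27, N], no data is copied
print(index3d.shape)                                # torch.Size([2, 27, 25])
# along dim=1 every row is identical: index3d[i, j, k] == index[i, k] for all j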

4. gather

Usage: torch.gather(input, dim, index, *, sparse_grad=False, out=None). Its effect is:

for i in range(dim0):
    for j in range(dim1):
        for k in range(dim2):
            out[i, j, k] = input[index[i][j][k], j, k]  # if dim == 0
            out[i, j, k] = input[i, index[i][j][k], k]  # if dim == 1
            out[i, j, k] = input[i, j, index[i][j][k]]  # if dim == 2

When using gather here, first expand index so that its size matches input.

index.shape == [B, num_blocks]; expand turns it into [B, k x k x C, num_blocks], so index[i, :, k] is a 1D tensor whose every element equals index[i, k] before the expand.

Therefore index[i][j][k] does not change as j varies, so out[i, j, k] = input[i, j, index[i][j][k]] in the loop maps every element (iterating over j) of the k-th block of out to the corresponding element of the index[i][j][k]-th block of input.
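A minimal sketch of this block-select pattern with toy sizes (not the real module's shapes):

import torch

B, kkC, N = 1, 4, 3                      # toy sizes: batch, k*k*C, num_blocks
unfold = torch.arange(B * kkC * N, dtype=torch.float).view(B, kkC, N)  # [B, k*k*C, N]
index = torch.tensor([[2, 0, 1]])        # [B, N]: which input block to take for each output block

index3d = index.view(B, 1, N).expand(-1, kkC, -1)   # [B, k*k*C, N]
out = torch.gather(unfold, 2, index3d)   # out[:, :, k] == unfold[:, :, index[:, k]]

print(unfold[0])
# tensor([[ 0.,  1.,  2.],
#         [ 3.,  4.,  5.],
#         [ 6.,  7.,  8.],
#         [ 9., 10., 11.]])
print(out[0])
# tensor([[ 2.,  0.,  1.],
#         [ 5.,  3.,  4.],
#         [ 8.,  6.,  7.],
#         [11.,  9., 10.]])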

5. Building the Features Transfer module

import torch
import torch.nn as nn
import torch.nn.functional as F


class Transfer(nn.Module):
    def __init__(self):
        super(Transfer, self).__init__()

    def bis(self, unfold, dim, index):
        """
        block index select
        args:
            unfold: [B, k*k*C, Hr*Wr]
            dim: which dimension indexes the blocks
            index: [B, H*W],  value range is [0, Hr*Wr-1]
            return: [B, k*k*C, H*W]
        """
        views = [unfold.size(0)] + [-1 if i == dim else 1 for i in range(1, len(unfold.size()))]  # [B, 1, -1(H*W)]
        expanse = list(unfold.size())
        expanse[0] = -1
        expanse[dim] = -1   # [-1, k*k*C, -1]
        index = index.view(views).expand(expanse)   # [B, H*W] -> [B, 1, H*W] -> [B, k*k*C, H*W]
        return torch.gather(unfold, dim, index)    # return[i][j][k] = unfold[i][j][index[i][j][k]]

    def forward(self, lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3):
        """
            args:
                lrsr_lv3: [B, C, H, W]
                refsr_lv3: [B, C, Hr, Wr]
                ref_lv1: [B, C, Hr*4, Wr*4]
                ref_lv2: [B, C, Hr*2, Wr*2]
                ref_lv3: [B, C, Hr, Wr]
        """
        H, W = lrsr_lv3.size()[-2:]

        lrsr_lv3_unfold = F.unfold(lrsr_lv3, kernel_size=3, padding=1, stride=1)    # [B, k*k*C, H*W]
        refsr_lv3_unfold = F.unfold(refsr_lv3, kernel_size=3, padding=1, stride=1).transpose(1, 2)  # [B, Hr*Wr, k*k*C]

        lrsr_lv3_unfold = F.normalize(lrsr_lv3_unfold, dim=1)
        refsr_lv3_unfold = F.normalize(refsr_lv3_unfold, dim=2)

        R = torch.bmm(refsr_lv3_unfold, lrsr_lv3_unfold)  # [B, Hr*Wr, H*W]
        score, index = torch.max(R, dim=1)  # [B, H*W]
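        # hard attention: index[b, n] is the ref block most similar to the n-th LR block, score is that similarity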

        ref_lv3_unfold = F.unfold(ref_lv3, kernel_size=3, padding=1, stride=1)      # vgg19
        ref_lv2_unfold = F.unfold(ref_lv2, kernel_size=6, padding=2, stride=2)      # one max pooling between lv1->lv2 and lv2->lv3
        ref_lv1_unfold = F.unfold(ref_lv1, kernel_size=12, padding=4, stride=4)     # kernel_size is not computed from the exact receptive field

        # divisors: record the overlap counts produced by fold(unfold(...))
        divisor_lv3 = F.unfold(torch.ones_like(ref_lv3), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.unfold(torch.ones_like(ref_lv2), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.unfold(torch.ones_like(ref_lv1), kernel_size=12, padding=4, stride=4)

        T_lv3_unfold = self.bis(ref_lv3_unfold, 2, index)   # [B, k*k*C, H*W]
        T_lv2_unfold = self.bis(ref_lv2_unfold, 2, index)
        T_lv1_unfold = self.bis(ref_lv1_unfold, 2, index)

        divisor_lv3 = self.bis(divisor_lv3, 2, index)  # [B, k*k*C, H*W]
        divisor_lv2 = self.bis(divisor_lv2, 2, index)
        divisor_lv1 = self.bis(divisor_lv1, 2, index)

        divisor_lv3 = F.fold(divisor_lv3, (H, W), kernel_size=3, padding=1, stride=1)
        divisor_lv2 = F.fold(divisor_lv2, (2*H, 2*W), kernel_size=6, padding=2, stride=2)
        divisor_lv1 = F.fold(divisor_lv1, (4*H, 4*W), kernel_size=12, padding=4, stride=4)

        T_lv3 = F.fold(T_lv3_unfold, (H, W), kernel_size=3, padding=1, stride=1) / divisor_lv3
        T_lv2 = F.fold(T_lv2_unfold, (2*H, 2*W), kernel_size=6, padding=2, stride=2) / divisor_lv2
        T_lv1 = F.fold(T_lv1_unfold, (4*H, 4*W), kernel_size=12, padding=4, stride=4) / divisor_lv1

        score = score.view(lrsr_lv3.size(0), 1, H, W)   # [B, 1, H, W]

        return score, T_lv1, T_lv2, T_lv3
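A quick smoke test with random feature maps (the channel count and spatial sizes here are arbitrary; in the paper these features come from a VGG feature extractor):

transfer = Transfer()

B, C = 1, 64
H, W = 10, 10          # lrsr_lv3 spatial size
Hr, Wr = 12, 12        # ref spatial size at level 3

lrsr_lv3 = torch.rand(B, C, H, W)
refsr_lv3 = torch.rand(B, C, Hr, Wr)
ref_lv3 = torch.rand(B, C, Hr, Wr)
ref_lv2 = torch.rand(B, C, Hr * 2, Wr * 2)
ref_lv1 = torch.rand(B, C, Hr * 4, Wr * 4)

score, T_lv1, T_lv2, T_lv3 = transfer(lrsr_lv3, refsr_lv3, ref_lv1, ref_lv2, ref_lv3)
print(score.shape)   # torch.Size([1, 1, 10, 10])
print(T_lv3.shape)   # torch.Size([1, 64, 10, 10])
print(T_lv2.shape)   # torch.Size([1, 64, 20, 20])
print(T_lv1.shape)   # torch.Size([1, 64, 40, 40])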


References

pytorch.org/docs/stable…

github.com/researchmm/…