Pytorch nn.Fold()的简单理解与用法

1,101 阅读3分钟

本文已参与「新人创作礼」活动,一起开启掘金创作之路。

官方文档:pytorch.org/docs/stable…

这个东西基本上就是绑定Unfold使用的。实际上,在没有overlapping、参数相同的情况下,其与Unfold操作是互逆的。

官方对该函数作用的描述如下: ...This operation combines these local blocks into the large output tensor by summing the overlapping values.... 这一操作通过对重叠的数值进行求和,将这些局部块结合到大的输出tensor中

说的比较含糊,那我们先上代码试一下unfold。对于一张1×1×4×4的特征图:

[[[[  1,  2,  3,  4],
   [  5,  6,  7,  8],
   [  9, 10, 11, 12],
   [ 13, 14, 15, 16]]]]

对其进行2×2,stride=2的滑动窗口操作以unfold,实现如下:

import torch
import torch.nn as nn
x = torch.Tensor([[[[  1,  2,  3,  4],
   					[  5,  6,  7,  8], 
   					[  9, 10, 11, 12],
   					[ 13, 14, 15, 16]]]])
unfold = nn.Unfold((2,2), stride=2)
print(x)
print(x.size())

输出unfold结果为:

tensor([[[ 1.,  3.,  9., 11.],
         [ 2.,  4., 10., 12.],
         [ 5.,  7., 13., 15.],
         [ 6.,  8., 14., 16.]]])
torch.Size([1, 4, 4])

再来看fold。前面我们看到,fold做的其实就是利用h×wh×w的核进行滑动窗口操作,然后将每次滑动得到的结果展平成一个列向量,逐个填充至结果中。那么unfold的话,做的工作就是处理fold得到的列向量。具体而言,unfold每次读取一个列向量,然后将其reshape回一个h×wh×w的块,再填回结果中。这时候就涉及一个问题,如果stride较小的话,reshape得到的块再填回结果时是会有overlapping的,因此只有在无overlapping(对于本例,需要stride=2)的情况下unfold与fold才可逆。

现在我们继续接着上面的例子,从unfold结果中提取第一列数据,将其reshape为2×2:

1 2 
5 6

然后将其填充到3×3结果中,有:

[[[[0+1, 0+2, 0],
   [0+5, 0+6, 0],
   [  0,   0, 0]]]]

继续提取第二列数据,将其reshape为2×2:

3 4
7 8

然后将其填充到3×3结果中。需要注意的是,由于stride=1,因此此时用于填充结果的kernel只会向右移一格,导致结果填充重叠:

[[[[1, 2+3, 0+4],
   [5, 6+7, 0+8],
   [0,   0,   0]]]]

继续提取第三列数据,将其reshape为2×2:

9  10
13 14

然后将其填充到3×3结果中:

[[[[   1,     5, 4],
   [ 5+9, 13+10, 8],
   [0+13,  0+14, 0]]]]

提取第四列数据,将其reshape为2×2:

11 12 
15 16

将其填充到3×3结果中,得到最后结果:

[[[[ 1,     5,    4],
   [14, 23+11, 8+12],
   [13, 14+15, 0+16]]]]

完整编码实现如下:

import torch
import torch.nn as nn

x = torch.Tensor([[[[  1,  2,  3,  4],
   					[  5,  6,  7,  8],
   					[  9, 10, 11, 12],
   					[ 13, 14, 15, 16]]]])

print(x)
unfold = nn.Unfold((2,2), stride=2)
fold = nn.Fold(kernel_size=(2,2), stride=1, output_size=(3,3))
x = unfold(x) 
print(x)
print(x.size())
x = fold(x)
print(x)
print(x.size())