Recurrent Models of Visual Attention (Glimpse Sensor code)


References:

zhuanlan.zhihu.com/p/47809274

"Recurrent Models of Visual Attention" (explained) – David Wolfowitz's blog on CSDN

Code: kevinzakka/recurrent-visual-attention: A PyTorch implementation of "Recurrent Models of Visual Attention" (github.com)

Paper: arxiv.org/pdf/1406.62…

Glimpse Sensor (similar to patch embedding)

The Glimpse Sensor extracts the network's input from the image.

Here l_{t-1} is the location passed in from the previous timestep, and x_t is the input image. When the location l_{t-1} is fed in together with x_t, the sensor samples x_t around the location l_{t-1}.

Centered at the location l_{t-1}, it extracts 3 patches whose side lengths are successive multiples of w, rescales each extracted region to w*w, and concatenates them to form the glimpse input. In this way information from different scales is combined. The process is illustrated in the figure below:

(Figure: the glimpse sensor extracts three nested regions of increasing size centered at l_{t-1} and rescales each to w*w)

Code


import torch
import torch.nn as nn
import torch.nn.functional as F


class Retina:
    """A visual retina.

    Extracts a foveated glimpse `phi` around location `l`
    from an image `x`.

    Concretely, encodes the region around `l` at a
    high-resolution but uses a progressively lower
    resolution for pixels further from `l`, resulting
    in a compressed representation of the original
    image `x`.

    Args:
        x: a 4D Tensor of shape (B, C, H, W). The minibatch
            of images.
        l: a 2D Tensor of shape (B, 2). Contains normalized
            coordinates in the range [-1, 1].
        g: size of the first square patch.
        k: number of patches to extract in the glimpse.
        s: scaling factor that controls the size of
            successive patches.

    Returns:
        phi: a 5D tensor of shape (B, k, g, g, C). The
            foveated glimpse of the image.
    """

    def __init__(self, g, k, s):
        self.g = g
        self.k = k
        self.s = s

    def foveate(self, x, l):
        """Extract `k` square patches of size `g`, centered
        at location `l`. The initial patch is a square of
        size `g`, and each subsequent patch is a square
        whose side is `s` times the size of the previous
        patch.

        The `k` patches are finally resized to (g, g) and
        concatenated into a tensor of shape (B, k, g, g, C).
        """
        phi = []
        size = self.g  # patch_size

        # extract k patches of increasing size
        for i in range(self.k):  # loop over the k patch scales
            phi.append(self.extract_patch(x, l, size))  # e.g. (128, 1, 8, 8): one patch per image at this scale
            size = int(self.s * size)  # grow the patch size for the next scale

        # resize the patches to squares of size g via average pooling; treating each
        # scale as a channel keeps both local (fine) and global (coarse) information
        for i in range(1, len(phi)):
            k = phi[i].shape[-1] // self.g
            phi[i] = F.avg_pool2d(phi[i], k)

        # concatenate into a single tensor and flatten
        phi = torch.cat(phi, 1)
        phi = phi.view(phi.shape[0], -1)

        return phi

    def extract_patch(self, x, l, size):
        """Extract a single patch for each image in `x`.

        Args:
        x: a 4D Tensor of shape (B, C, H, W). The minibatch
            of images.
        l: a 2D Tensor of shape (B, 2).
        size: a scalar defining the size of the extracted patch.

        Returns:
            patch: a 4D Tensor of shape (B, size, size, C)
        """
        B, C, H, W = x.shape

        start = self.denormalize(H, l)  # map coordinates from [-1, 1] to [0, H]
        end = start + size  # start is the patch's top-left corner; start + patch_size gives the bottom-right

        # pad with zeros
        x = F.pad(x, (size // 2, size // 2, size // 2, size // 2))  # zero-pad so patches near the image border stay in bounds

        # loop through mini-batch and extract patches
        patch = []
        for i in range(B):  # loop over the images in the mini-batch
            patch.append(x[i, :, start[i, 1] : end[i, 1], start[i, 0] : end[i, 0]])
        return torch.stack(patch)

    # map coordinates from [-1, 1] to [0, T], where T is the image width/height
    def denormalize(self, T, coords):
        """Convert coordinates in the range [-1, 1] to
        coordinates in the range [0, T] where `T` is
        the size of the image.
        """
        return (0.5 * ((coords + 1.0) * T)).long()  # coords + 1 lies in [0, 2]; times T gives [0, 2T]; the 0.5 brings it back to [0, T]


class GlimpseNetwork(nn.Module):
    """The glimpse network.

    Combines the "what" and the "where" into a glimpse
    feature vector `g_t`.

    - "what": glimpse extracted from the retina.
    - "where": location tuple where the glimpse was extracted.

    Concretely, feeds the output of the retina `phi` to
    a fc layer and the glimpse location vector `l_t_prev`
    to a fc layer. Finally, these outputs are fed each
    through a fc layer and their sum is rectified.

    In other words:

        `g_t = relu( fc( fc(l) ) + fc( fc(phi) ) )`

    Args:
        h_g: hidden layer size of the fc layer for `phi`.
        h_l: hidden layer size of the fc layer for `l`.
        g: size of the square patches in the glimpses extracted
            by the retina (patch_size).
        k: number of patches to extract per glimpse (patch_num).
        s: scaling factor that controls the size of successive
            patches (patch_scale).
        c: number of channels in each image (channels).
        x: a 4D Tensor of shape (B, C, H, W). The minibatch
            of images.
        l_t_prev: a 2D tensor of shape (B, 2). Contains the glimpse
            coordinates [x, y] for the previous timestep `t-1`.

    Returns:
        g_t: a 2D tensor of shape (B, hidden_size).
            The glimpse representation returned by
            the glimpse network for the current
            timestep `t`.
    """

    def __init__(self, h_g, h_l, g, k, s, c):
        super().__init__()

        self.retina = Retina(g, k, s)  # (patch_size, patch_num, patch_scale)

        # glimpse layer
        D_in = k * g * g * c
        self.fc1 = nn.Linear(D_in, h_g)

        # location layer
        D_in = 2
        self.fc2 = nn.Linear(D_in, h_l)

        self.fc3 = nn.Linear(h_g, h_g + h_l)
        self.fc4 = nn.Linear(h_l, h_g + h_l)

    def forward(self, x, l_t_prev):
        # generate glimpse phi from image x
        phi = self.retina.foveate(x, l_t_prev)  # extract the multi-scale glimpse of x at location l_t_prev
        # e.g. torch.Size([128, 64]) when k=1, g=8, c=1 (64 = patch_size * patch_size)
        # flatten the location vector
        l_t_prev = l_t_prev.view(l_t_prev.size(0), -1)

        # feed phi and l to their respective fc layers
        phi_out = F.relu(self.fc1(phi))  # (B, k*g*g*c) -> (B, h_g)
        l_out = F.relu(self.fc2(l_t_prev))  # (B, 2) -> (B, h_l)

        what = self.fc3(phi_out)  # (B, h_g) -> (B, h_g + h_l)
        where = self.fc4(l_out)  # (B, h_l) -> (B, h_g + h_l)

        # feed to fc layer
        g_t = F.relu(what + where)

        return g_t
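To see the whole sensor at a glance, the multi-scale extraction can be condensed into a single self-contained function. This is a sketch, not the repository's API; the sizes g=8, k=3, s=2 and the 28×28 input are assumed example values:

```python
import torch
import torch.nn.functional as F


def glimpse(x, l, g=8, k=3, s=2):
    """Extract k concentric patches around l and pool each back to g x g."""
    B, C, H, W = x.shape
    phi, size = [], g
    for _ in range(k):
        start = (0.5 * ((l + 1.0) * H)).long()  # [-1, 1] -> [0, H]
        pad = size // 2
        xp = F.pad(x, (pad, pad, pad, pad))     # keep border patches in bounds
        patches = torch.stack([
            xp[i, :, start[i, 1]:start[i, 1] + size,
                     start[i, 0]:start[i, 0] + size]
            for i in range(B)
        ])
        # pool larger patches back down to g x g
        phi.append(F.avg_pool2d(patches, size // g) if size > g else patches)
        size *= s
    return torch.cat(phi, 1).view(B, -1)        # (B, k * g * g * C)


x = torch.randn(4, 1, 28, 28)   # assumed toy minibatch
l = torch.zeros(4, 2)           # glimpse at the center of each image
out = glimpse(x, l)
print(out.shape)                # torch.Size([4, 192]) since k*g*g*C = 3*8*8*1
```

The flattened output dimension matches the `D_in = k * g * g * c` used by `GlimpseNetwork.fc1`.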

Code walkthrough

Main class: GlimpseNetwork

__init__(self, h_g, h_l, g, k, s, c):

Key parameters:

g = patch_size, the width and height of each patch (w = h)

k = patch_num, how many patches to extract from each image

s = patch_scale: when multiple patches are extracted from an image, each successive patch's size is the previous patch_size multiplied by patch_scale. The figure below shows three such patches.

(Figure: three nested patches of increasing size centered on the same location)

c = channels, the number of image channels

self.retina = Retina(g, k, s)  # (patch_size, patch_num, patch_scale), analogous to patch embedding in a Transformer
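Plugging in assumed example values (g=8, k=3, s=2, c=1, matching the shape comments in the code) shows how the successive patch sizes and the flattened input dimension D_in work out:

```python
g, k, s, c = 8, 3, 2, 1      # assumed: patch_size, patch_num, patch_scale, channels

sizes, size = [], g
for _ in range(k):
    sizes.append(size)
    size = int(s * size)     # each successive patch is s times larger

D_in = k * g * g * c         # every patch is pooled back to g x g before flattening
print(sizes, D_in)           # [8, 16, 32] 192
```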

The core of Retina is the extract_patch() method.

This method crops a patch out of each image using the patch's top-left and bottom-right corner coordinates, and returns one patch per image.
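Those corner coordinates come from denormalize plus the patch size. A plain-Python sketch of the mapping (the image size 28 and the location values are assumed examples); note that after the zero padding of size // 2, the slice [start : start + size] ends up centered on the denormalized point:

```python
def denormalize(T, coord):
    """Map a coordinate from [-1, 1] to [0, T], where T is the image height/width."""
    return int(0.5 * (coord + 1.0) * T)


H, size = 28, 8                                    # assumed example values
lx, ly = 0.0, -1.0                                 # normalized location
x0, y0 = denormalize(H, lx), denormalize(H, ly)    # top-left corner (in padded coordinates)
x1, y1 = x0 + size, y0 + size                      # bottom-right corner
print((x0, y0), (x1, y1))                          # (14, 0) (22, 8)
```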

The foveate() method calls extract_patch() once per scale and returns multiple patches for each image.

The rest of the network flattens the patches and, after a series of fully connected layers, adds the location information to the patch information.

The result is a vector that combines location and patch content for each image.
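This "what"/"where" fusion boils down to four linear layers and a sum. A standalone sketch with assumed layer sizes (h_g = h_l = 128, D_in = 192):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

h_g, h_l, D_in = 128, 128, 192       # assumed sizes; D_in = k * g * g * c = 3 * 8 * 8 * 1
fc1, fc2 = nn.Linear(D_in, h_g), nn.Linear(2, h_l)
fc3, fc4 = nn.Linear(h_g, h_g + h_l), nn.Linear(h_l, h_g + h_l)

phi = torch.randn(4, D_in)           # flattened glimpse: the "what"
l = torch.rand(4, 2) * 2 - 1         # normalized location in [-1, 1]: the "where"

# g_t = relu( fc3(relu(fc1(phi))) + fc4(relu(fc2(l))) )
g_t = F.relu(fc3(F.relu(fc1(phi))) + fc4(F.relu(fc2(l))))
print(g_t.shape)                     # torch.Size([4, 256])
```

Summing the two projections (rather than concatenating) forces the "what" and "where" streams into a shared feature space before the ReLU.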

Thoughts

What I found most valuable is this patch-embedding-style extraction: it captures both the global and the local information of an image. It can also be seen as placing attention boxes of several sizes around a location, which helps the model cope with objects of different scales.

----------- (more to be added)