A note on how grid_sample works:

- Image size normalization: the image coordinates are normalized first, so
  that (-1, -1) refers to position (0, 0) of the original image and (1, 1)
  to position (H-1, W-1); feature point locations are normalized to the
  corresponding positions.
- Building the grid: the normalized feature points are stacked into a
  tensor of shape 1 x 1 x K x 2, where K is the number of features and the
  trailing 2 holds the x and y coordinates.
- De-normalizing the feature positions: using the H and W of the input
  tensor, grid(1, 1, 0, :) (the first feature point; the others work the
  same way) is de-normalized, i.e. scaled and translated back, giving the
  feature point's position on one slice (channel) of the tensor. This
  position is generally not an integer pixel, so its value is filled in by
  bilinear interpolation, and the remaining slices are interpolated the
  same way. Note: the actual code performs bilinear interpolation, not the
  bicubic interpolation described in the referenced text.
- Output dimensions: 1 x C x 1 x K.
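For instance, here is a minimal sketch of this convention (the tiny 2 x 2 image and the three sampling points are made up for illustration):

import torch
import torch.nn.functional as F

# Input: a 1 x 1 x 2 x 2 image whose pixels are 0, 1, 2, 3.
img = torch.arange(4, dtype=torch.float).view(1, 1, 2, 2)

# Grid of shape 1 x 1 x K x 2 with K = 3 sampling locations (x, y),
# normalized to [-1, 1]: top-left, bottom-right, and the exact center.
grid = torch.tensor([[[[-1.0, -1.0], [1.0, 1.0], [0.0, 0.0]]]])

# Output has shape 1 x C x 1 x K; the center value is bilinearly
# interpolated from all four pixels: (0 + 1 + 2 + 3) / 4 = 1.5.
out = F.grid_sample(img, grid, mode='bilinear', align_corners=True)
print(out)  # tensor([[[[0.0000, 3.0000, 1.5000]]]])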
#!/usr/bin/env python
# coding: utf-8

# In[ ]:
get_ipython().run_line_magic('matplotlib', 'inline')
Spatial Transformer Networks Tutorial
=====================================
Author: `Ghassen HAMROUNI <https://github.com/GHamrouni>`_
.. figure:: /_static/img/stn/FSeq.png
In this tutorial, you will learn how to augment your network using
a visual attention mechanism called spatial transformer
networks. You can read more about spatial transformer
networks in the `DeepMind paper <https://arxiv.org/abs/1506.02025>`__.
Spatial transformer networks are a generalization of differentiable
attention to any spatial transformation. Spatial transformer networks
(STN for short) allow a neural network to learn how to perform spatial
transformations on the input image in order to enhance the geometric
invariance of the model.
For example, it can crop a region of interest, scale and correct
the orientation of an image. It can be a useful mechanism because CNNs
are not invariant to rotation and scale and more general affine
transformations.
One of the best things about STN is the ability to simply plug it into
any existing CNN with very little modification.
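As a minimal sketch of this plug-in property (the STNWrapper name and the generic stn_module argument here are hypothetical, not part of this tutorial's model), an STN can simply be prepended to an existing backbone:

import torch.nn as nn

class STNWrapper(nn.Module):
    """Prepend a spatial transformer to an existing, unmodified CNN."""
    def __init__(self, stn_module, backbone):
        super(STNWrapper, self).__init__()
        self.stn = stn_module      # warps the input image
        self.backbone = backbone   # pre-existing CNN, left untouched

    def forward(self, x):
        return self.backbone(self.stn(x))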
# In[ ]:
# License: BSD
# Author: Ghassen Hamrouni

from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
plt.ion() # interactive mode
Loading the data
----------------
In this post, we experiment with the classic MNIST dataset, using a
standard convolutional network augmented with a spatial transformer
network.
# In[ ]:
from six.moves import urllib
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Training dataset
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)
# Test dataset
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root='.', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=64, shuffle=True, num_workers=4)
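As a quick optional check (not part of the original tutorial), one batch from the training loader should contain 64 grayscale 28 x 28 images:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])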
Depicting spatial transformer networks
--------------------------------------
Spatial transformer networks boil down to three main components:

- The localization network is a regular CNN which regresses the
  transformation parameters. The transformation is never learned
  explicitly from this dataset; instead, the network automatically learns
  the spatial transformations that enhance the global accuracy.
- The grid generator generates a grid of coordinates in the input
image corresponding to each pixel from the output image.
- The sampler uses the parameters of the transformation and applies
it to the input image.
.. figure:: /_static/img/stn/stn-arch.png
.. Note::
We need the latest version of PyTorch that contains
affine_grid and grid_sample modules.
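To make the grid generator and the sampler concrete before building the model, here is a minimal standalone sketch (the 45-degree angle and the dummy batch are made up for illustration) that warps a batch with an explicit affine matrix:

import math
import torch
import torch.nn.functional as F

# One 2 x 3 affine matrix per image: a pure 45-degree rotation.
angle = math.pi / 4
theta = torch.tensor([[math.cos(angle), -math.sin(angle), 0.0],
                      [math.sin(angle),  math.cos(angle), 0.0]])
theta = theta.unsqueeze(0).repeat(4, 1, 1)  # batch of 4 identical matrices

imgs = torch.randn(4, 1, 28, 28)            # dummy MNIST-sized batch

# Grid generator: one (x, y) sampling location per output pixel.
grid = F.affine_grid(theta, imgs.size(), align_corners=False)
# Sampler: bilinearly sample the input at those locations.
rotated = F.grid_sample(imgs, grid, align_corners=False)
print(rotated.shape)  # torch.Size([4, 1, 28, 28])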
# In[ ]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization-network
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )

        # Regressor for the 3 * 2 affine matrix
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * 3 * 3, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialize the weights/bias with identity transformation
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))
    # Spatial transformer network forward function
    def stn(self, x):
        xs = self.localization(x)
        xs = xs.view(-1, 10 * 3 * 3)
        theta = self.fc_loc(xs)
        theta = theta.view(-1, 2, 3)

        grid = F.affine_grid(theta, x.size())
        x = F.grid_sample(x, grid)

        return x
    def forward(self, x):
        # transform the input
        x = self.stn(x)

        # Perform the usual forward pass
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model = Net().to(device)
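Because fc_loc is initialized to the identity transformation, the untrained STN should return its input essentially unchanged. A quick sanity check (a sketch, not part of the original tutorial):

with torch.no_grad():
    sample = torch.randn(1, 1, 28, 28).to(device)
    warped = model.stn(sample)
    print(torch.allclose(sample, warped, atol=1e-5))  # expected: True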
Training the model
------------------
Now, let's use the SGD algorithm to train the model. The network is
learning the classification task in a supervised way. At the same time,
the model is learning the STN automatically in an end-to-end fashion.
# In[ ]:
optimizer = optim.SGD(model.parameters(), lr=0.01)
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 500 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
A simple test procedure to measure the STN performance on MNIST.
def test():
    with torch.no_grad():
        model.eval()
        test_loss = 0
        correct = 0
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)

            # sum up batch loss
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max log-probability
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= len(test_loader.dataset)
        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'
              .format(test_loss, correct, len(test_loader.dataset),
                      100. * correct / len(test_loader.dataset)))
Visualizing the STN results
---------------------------
Now, we will inspect the results of our learned visual attention
mechanism.
We define a small helper function in order to visualize the
transformations while training.
# In[ ]:
def convert_image_np(inp):
    """Convert a Tensor to numpy image."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    return inp
We want to visualize the output of the spatial transformer layer
after the training; we visualize a batch of input images and
the corresponding transformed batch using the STN.
def visualize_stn():
    with torch.no_grad():
        # Get a batch of training data
        data = next(iter(test_loader))[0].to(device)

        input_tensor = data.cpu()
        transformed_input_tensor = model.stn(data).cpu()

        in_grid = convert_image_np(
            torchvision.utils.make_grid(input_tensor))

        out_grid = convert_image_np(
            torchvision.utils.make_grid(transformed_input_tensor))

        # Plot the results side-by-side
        f, axarr = plt.subplots(1, 2)
        axarr[0].imshow(in_grid)
        axarr[0].set_title('Dataset Images')
        axarr[1].imshow(out_grid)
        axarr[1].set_title('Transformed Images')
for epoch in range(1, 20 + 1):
    train(epoch)
    test()
# Visualize the STN transformation on some input batch
visualize_stn()
plt.ioff()
plt.show()