{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Spatial Transformer Networks Tutorial\n=====================================\n\n**Author**: [Ghassen HAMROUNI](https://github.com/GHamrouni)\n\n\n\nIn this tutorial, you will learn how to augment your network using a\nvisual attention mechanism called spatial transformer networks. You can\nread more about the spatial transformer networks in the [DeepMind\npaper](https://arxiv.org/abs/1506.02025)\n\nSpatial transformer networks are a generalization of differentiable\nattention to any spatial transformation. Spatial transformer networks\n(STN for short) allow a neural network to learn how to perform spatial\ntransformations on the input image in order to enhance the geometric\ninvariance of the model. For example, it can crop a region of interest,\nscale and correct the orientation of an image. It can be a useful\nmechanism because CNNs are not invariant to rotation and scale and more\ngeneral affine transformations.\n\nOne of the best things about STN is the ability to simply plug it into\nany existing CNN with very little modification.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# License: BSD\n# Author: Ghassen Hamrouni\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision\nfrom torchvision import datasets, transforms\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nplt.ion() # interactive mode" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading the data\n================\n\nIn this post we experiment with the classic MNIST dataset. Using a\nstandard convolutional network augmented with a spatial transformer\nnetwork.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from six.moves import urllib\nopener = urllib.request.build_opener()\nopener.addheaders = [('User-agent', 'Mozilla/5.0')]\nurllib.request.install_opener(opener)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Training dataset\ntrain_loader = torch.utils.data.DataLoader(\n datasets.MNIST(root='.', train=True, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Normalize((0.1307,), (0.3081,))\n ])), batch_size=64, shuffle=True, num_workers=4)\n# Test dataset\ntest_loader = torch.utils.data.DataLoader(\n datasets.MNIST(root='.', train=False, transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Normalize((0.1307,), (0.3081,))\n ])), batch_size=64, shuffle=True, num_workers=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Depicting spatial transformer networks\n======================================\n\nSpatial transformer networks boils down to three main components :\n\n- The localization network is a regular CNN which regresses the\n transformation parameters. 
The transformation is never learned\n  explicitly from this dataset; instead, the network automatically\n  learns the spatial transformations that enhance the global\n  accuracy.\n- The grid generator generates a grid of coordinates in the input\n  image corresponding to each pixel of the output image.\n- The sampler uses the parameters of the transformation and applies\n  it to the input image, as shown in the sketch below.\n
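\nTo make these three components concrete, the following cell gives a\nminimal, illustrative sketch of an STN module. It is only a sketch:\nthe layer sizes are assumptions chosen for 28x28 single-channel MNIST\ninputs, not a definitive implementation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# A minimal STN sketch (illustrative only): the layer sizes below are\n# assumptions chosen for 28x28 single-channel MNIST inputs.\nclass MiniSTN(nn.Module):\n    def __init__(self):\n        super().__init__()\n        # Localization network: a small CNN that regresses the 2x3 affine\n        # transformation matrix theta\n        self.localization = nn.Sequential(\n            nn.Conv2d(1, 8, kernel_size=7), nn.MaxPool2d(2), nn.ReLU(),\n            nn.Conv2d(8, 10, kernel_size=5), nn.MaxPool2d(2), nn.ReLU())\n        self.fc_loc = nn.Sequential(\n            nn.Flatten(), nn.Linear(10 * 3 * 3, 32), nn.ReLU(),\n            nn.Linear(32, 2 * 3))\n        # Initialize to the identity transform so the STN starts by doing nothing\n        self.fc_loc[-1].weight.data.zero_()\n        self.fc_loc[-1].bias.data.copy_(\n            torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n\n    def forward(self, x):\n        theta = self.fc_loc(self.localization(x)).view(-1, 2, 3)\n        # Grid generator: one input-space sampling location per output pixel\n        grid = F.affine_grid(theta, x.size(), align_corners=False)\n        # Sampler: bilinearly sample the input image at those locations\n        return F.grid_sample(x, grid, align_corners=False)\n\n# Shape check on one MNIST batch: the output shape matches the input shape\nx, _ = next(iter(train_loader))\nprint(MiniSTN()(x).shape)  # torch.Size([64, 1, 28, 28])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "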
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>We need the latest version of PyTorch that contains the\naffine_grid and grid_sample modules.</p></div>\n```\n```{=html}\n