{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Spatial Transformer Networks Tutorial\n=====================================\n\n**Author**: [Ghassen HAMROUNI](https://github.com/GHamrouni)\n\n![](https://pytorch.org/tutorials/_static/img/stn/FSeq.png)\n\nIn this tutorial, you will learn how to augment your network using a\nvisual attention mechanism called spatial transformer networks. You can\nread more about the spatial transformer networks in the [DeepMind\npaper](https://arxiv.org/abs/1506.02025)\n\nSpatial transformer networks are a generalization of differentiable\nattention to any spatial transformation. Spatial transformer networks\n(STN for short) allow a neural network to learn how to perform spatial\ntransformations on the input image in order to enhance the geometric\ninvariance of the model. For example, it can crop a region of interest,\nscale and correct the orientation of an image. It can be a useful\nmechanism because CNNs are not invariant to rotation and scale and more\ngeneral affine transformations.\n\nOne of the best things about STN is the ability to simply plug it into\nany existing CNN with very little modification.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# License: BSD\n# Author: Ghassen Hamrouni\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision\nfrom torchvision import datasets, transforms\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nplt.ion() # interactive mode" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading the data\n================\n\nIn this post we experiment with the classic MNIST dataset. Using a\nstandard convolutional network augmented with a spatial transformer\nnetwork.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from six.moves import urllib\nopener = urllib.request.build_opener()\nopener.addheaders = [('User-agent', 'Mozilla/5.0')]\nurllib.request.install_opener(opener)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Training dataset\ntrain_loader = torch.utils.data.DataLoader(\n datasets.MNIST(root='.', train=True, download=True,\n transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Normalize((0.1307,), (0.3081,))\n ])), batch_size=64, shuffle=True, num_workers=4)\n# Test dataset\ntest_loader = torch.utils.data.DataLoader(\n datasets.MNIST(root='.', train=False, transform=transforms.Compose([\n transforms.ToTensor(),\n transforms.Normalize((0.1307,), (0.3081,))\n ])), batch_size=64, shuffle=True, num_workers=4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Depicting spatial transformer networks\n======================================\n\nSpatial transformer networks boils down to three main components :\n\n- The localization network is a regular CNN which regresses the\n transformation parameters. The transformation is never learned\n explicitly from this dataset, instead the network learns\n automatically the spatial transformations that enhances the global\n accuracy.\n- The grid generator generates a grid of coordinates in the input\n image corresponding to each pixel from the output image.\n- The sampler uses the parameters of the transformation and applies it\n to the input image.\n\n![](https://pytorch.org/tutorials/_static/img/stn/stn-arch.png)\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

We need the latest version of PyTorch that containsaffine_grid and grid_sample modules.

\n```\n```{=html}\n
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Net(nn.Module):\n def __init__(self):\n super(Net, self).__init__()\n self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n self.conv2_drop = nn.Dropout2d()\n self.fc1 = nn.Linear(320, 50)\n self.fc2 = nn.Linear(50, 10)\n\n # Spatial transformer localization-network\n self.localization = nn.Sequential(\n nn.Conv2d(1, 8, kernel_size=7),\n nn.MaxPool2d(2, stride=2),\n nn.ReLU(True),\n nn.Conv2d(8, 10, kernel_size=5),\n nn.MaxPool2d(2, stride=2),\n nn.ReLU(True)\n )\n\n # Regressor for the 3 * 2 affine matrix\n self.fc_loc = nn.Sequential(\n nn.Linear(10 * 3 * 3, 32),\n nn.ReLU(True),\n nn.Linear(32, 3 * 2)\n )\n\n # Initialize the weights/bias with identity transformation\n self.fc_loc[2].weight.data.zero_()\n self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n\n # Spatial transformer network forward function\n def stn(self, x):\n xs = self.localization(x)\n xs = xs.view(-1, 10 * 3 * 3)\n theta = self.fc_loc(xs)\n theta = theta.view(-1, 2, 3)\n\n grid = F.affine_grid(theta, x.size())\n x = F.grid_sample(x, grid)\n\n return x\n\n def forward(self, x):\n # transform the input\n x = self.stn(x)\n\n # Perform the usual forward pass\n x = F.relu(F.max_pool2d(self.conv1(x), 2))\n x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n x = x.view(-1, 320)\n x = F.relu(self.fc1(x))\n x = F.dropout(x, training=self.training)\n x = self.fc2(x)\n return F.log_softmax(x, dim=1)\n\n\nmodel = Net().to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training the model\n==================\n\nNow, let\\'s use the SGD algorithm to train the model. The network is\nlearning the classification task in a supervised way. In the same time\nthe model is learning STN automatically in an end-to-end fashion.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "optimizer = optim.SGD(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n model.train()\n for batch_idx, (data, target) in enumerate(train_loader):\n data, target = data.to(device), target.to(device)\n\n optimizer.zero_grad()\n output = model(data)\n loss = F.nll_loss(output, target)\n loss.backward()\n optimizer.step()\n if batch_idx % 500 == 0:\n print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n epoch, batch_idx * len(data), len(train_loader.dataset),\n 100. * batch_idx / len(train_loader), loss.item()))\n#\n# A simple test procedure to measure the STN performances on MNIST.\n#\n\n\ndef test():\n with torch.no_grad():\n model.eval()\n test_loss = 0\n correct = 0\n for data, target in test_loader:\n data, target = data.to(device), target.to(device)\n output = model(data)\n\n # sum up batch loss\n test_loss += F.nll_loss(output, target, size_average=False).item()\n # get the index of the max log-probability\n pred = output.max(1, keepdim=True)[1]\n correct += pred.eq(target.view_as(pred)).sum().item()\n\n test_loss /= len(test_loader.dataset)\n print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'\n .format(test_loss, correct, len(test_loader.dataset),\n 100. * correct / len(test_loader.dataset)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualizing the STN results\n===========================\n\nNow, we will inspect the results of our learned visual attention\nmechanism.\n\nWe define a small helper function in order to visualize the\ntransformations while training.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def convert_image_np(inp):\n \"\"\"Convert a Tensor to numpy image.\"\"\"\n inp = inp.numpy().transpose((1, 2, 0))\n mean = np.array([0.485, 0.456, 0.406])\n std = np.array([0.229, 0.224, 0.225])\n inp = std * inp + mean\n inp = np.clip(inp, 0, 1)\n return inp\n\n# We want to visualize the output of the spatial transformers layer\n# after the training, we visualize a batch of input images and\n# the corresponding transformed batch using STN.\n\n\ndef visualize_stn():\n with torch.no_grad():\n # Get a batch of training data\n data = next(iter(test_loader))[0].to(device)\n\n input_tensor = data.cpu()\n transformed_input_tensor = model.stn(data).cpu()\n\n in_grid = convert_image_np(\n torchvision.utils.make_grid(input_tensor))\n\n out_grid = convert_image_np(\n torchvision.utils.make_grid(transformed_input_tensor))\n\n # Plot the results side-by-side\n f, axarr = plt.subplots(1, 2)\n axarr[0].imshow(in_grid)\n axarr[0].set_title('Dataset Images')\n\n axarr[1].imshow(out_grid)\n axarr[1].set_title('Transformed Images')\n\nfor epoch in range(1, 20 + 1):\n train(epoch)\n test()\n\n# Visualize the STN transformation on some input batch\nvisualize_stn()\n\nplt.ioff()\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }