{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reasoning about Shapes in PyTorch\n=================================\n\nWhen writing models with PyTorch, it is commonly the case that the\nparameters to a given layer depend on the shape of the output of the\nprevious layer. For example, the `in_features` of an `nn.Linear` layer\nmust match the `size(-1)` of the input. For some layers, the shape\ncomputation involves complex equations, for example convolution\noperations.\n\nOne way around this is to run the forward pass with random inputs, but\nthis is wasteful in terms of memory and compute.\n\nInstead, we can make use of the `meta` device to determine the output\nshapes of a layer without materializing any data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport timeit\n\nt = torch.rand(2, 3, 10, 10, device=\"meta\")\nconv = torch.nn.Conv2d(3, 5, 2, device=\"meta\")\nstart = timeit.default_timer()\nout = conv(t)\nend = timeit.default_timer()\n\nprint(out)\nprint(f\"Time taken: {end-start}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Observe that since data is not materialized, passing arbitrarily large\ninputs will not significantly alter the time taken for shape\ncomputation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "t_large = torch.rand(2**10, 3, 2**16, 2**16, device=\"meta\")\nstart = timeit.default_timer()\nout = conv(t_large)\nend = timeit.default_timer()\n\nprint(out)\nprint(f\"Time taken: {end-start}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Consider an arbitrary network such as the following:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Net(nn.Module):\n def __init__(self):\n super().__init__()\n self.conv1 = nn.Conv2d(3, 6, 5)\n self.pool = nn.MaxPool2d(2, 2)\n self.conv2 = nn.Conv2d(6, 16, 5)\n self.fc1 = nn.Linear(16 * 5 * 5, 120)\n self.fc2 = nn.Linear(120, 84)\n self.fc3 = nn.Linear(84, 10)\n\n def forward(self, x):\n x = self.pool(F.relu(self.conv1(x)))\n x = self.pool(F.relu(self.conv2(x)))\n x = torch.flatten(x, 1) # flatten all dimensions except batch\n x = F.relu(self.fc1(x))\n x = F.relu(self.fc2(x))\n x = self.fc3(x)\n return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can view the intermediate shapes within an entire network by\nregistering a forward hook to each layer that prints the shape of the\noutput.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fw_hook(module, input, output):\n print(f\"Shape of output to {module} is {output.shape}.\")\n\n\n# Any tensor created within this torch.device context manager will be\n# on the meta device.\nwith torch.device(\"meta\"):\n net = Net()\n inp = torch.randn((1024, 3, 32, 32))\n\nfor name, layer in net.named_modules():\n layer.register_forward_hook(fw_hook)\n\nout = net(inp)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }