{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(beta) Building a Convolution/Batch Norm fuser in FX\n====================================================\n\n**Author**: [Horace He](https://github.com/chillee)\n\nIn this tutorial, we are going to use FX, a toolkit for composable\nfunction transformations of PyTorch, to do the following:\n\n1) Find patterns of conv/batch norm in the data dependencies.\n2) For the patterns found in 1), fold the batch norm statistics into\n   the convolution weights.\n\nNote that this optimization only works for models in inference mode\n(i.e. `model.eval()`).\n\nWe will be building the fuser that exists here:\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let's get some imports out of the way (we will be using all of\nthese later in the code).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Type, Dict, Any, Tuple, Iterable\nimport copy\nimport torch.fx as fx\nimport torch\nimport torch.nn as nn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For this tutorial, we are going to create a model consisting of\nconvolutions and batch norms. 
Note that this model has some tricky\ncomponents - some of the conv/batch norm patterns are hidden within\nSequentials and one of the `BatchNorms` is wrapped in another Module.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class WrappedBatchNorm(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.mod = nn.BatchNorm2d(1)\n    def forward(self, x):\n        return self.mod(x)\n\nclass M(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(1, 1, 1)\n        self.bn1 = nn.BatchNorm2d(1)\n        self.conv2 = nn.Conv2d(1, 1, 1)\n        self.nested = nn.Sequential(\n            nn.BatchNorm2d(1),\n            nn.Conv2d(1, 1, 1),\n        )\n        self.wrapped = WrappedBatchNorm()\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = self.bn1(x)\n        x = self.conv2(x)\n        x = self.nested(x)\n        x = self.wrapped(x)\n        return x\n\nmodel = M()\n\nmodel.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fusing Convolution with Batch Norm\n==================================\n\nOne of the primary challenges with trying to automatically fuse\nconvolution and batch norm in PyTorch is that PyTorch does not provide\nan easy way of accessing the computational graph. FX resolves this\nproblem by symbolically tracing the actual operations called, so that we\ncan track the computations through the `forward` call,\nnested within Sequential modules, or wrapped in a user-defined module.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "traced_model = torch.fx.symbolic_trace(model)\nprint(traced_model.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This gives us a graph representation of our model. Note that both the\nmodules hidden within the sequential as well as the wrapped Module have\nbeen inlined into the graph. This is the default level of abstraction,\nbut it can be configured by the pass writer. 
More information can be\nfound at the FX overview\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fusing Convolution with Batch Norm\n==================================\n\nUnlike some other fusions, fusion of convolution with batch norm does\nnot require any new operators. Instead, as batch norm during inference\nconsists of a pointwise add and multiply, these operations can be\n\"baked\" into the preceding convolution's weights. This allows us to\nremove the batch norm entirely from our model! Read\n for further\ndetails. The code here is copied from\n\nclarity purposes.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def fuse_conv_bn_eval(conv, bn):\n    \"\"\"\n    Given a conv Module `A` and a batch_norm module `B`, returns a conv\n    module `C` such that C(x) == B(A(x)) in inference mode.\n    \"\"\"\n    assert(not (conv.training or bn.training)), \"Fusion only for eval!\"\n    fused_conv = copy.deepcopy(conv)\n\n    fused_conv.weight, fused_conv.bias = \\\n        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,\n                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)\n\n    return fused_conv\n\ndef fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):\n    if conv_b is None:\n        conv_b = torch.zeros_like(bn_rm)\n    if bn_w is None:\n        bn_w = torch.ones_like(bn_rm)\n    if bn_b is None:\n        bn_b = torch.zeros_like(bn_rm)\n    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)\n\n    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))\n    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b\n\n    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "FX Fusion Pass\n==============\n\nNow that we have our computational graph as well as a method for fusing\nconvolution and batch norm, all that remains is to iterate over the FX\ngraph and apply the desired fusions.\n" ] }, { "cell_type": 
"code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def _parent_name(target : str) -> Tuple[str, str]:\n    \"\"\"\n    Splits a ``qualname`` into parent path and last atom.\n    For example, `foo.bar.baz` -> (`foo.bar`, `baz`)\n    \"\"\"\n    *parent, name = target.rsplit('.', 1)\n    return parent[0] if parent else '', name\n\ndef replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module):\n    assert(isinstance(node.target, str))\n    parent_name, name = _parent_name(node.target)\n    setattr(modules[parent_name], name, new_module)\n\n\ndef fuse(model: torch.nn.Module) -> torch.nn.Module:\n    model = copy.deepcopy(model)\n    # The first step of most FX passes is to symbolically trace our model to\n    # obtain a `GraphModule`. This is a representation of our original model\n    # that is functionally identical to our original model, except that we now\n    # also have a graph representation of our forward pass.\n    fx_model: fx.GraphModule = fx.symbolic_trace(model)\n    modules = dict(fx_model.named_modules())\n\n    # The primary representations for working with FX are the `Graph` and the\n    # `Node`. Each `GraphModule` has a `Graph` associated with it - this\n    # `Graph` is also what generates `GraphModule.code`.\n    # The `Graph` itself is represented as a list of `Node` objects. Thus, to\n    # iterate through all of the operations in our graph, we iterate over each\n    # `Node` in our `Graph`.\n    for node in fx_model.graph.nodes:\n        # The FX IR contains several types of nodes, which generally represent\n        # call sites to modules, functions, or methods. The type of node is\n        # determined by `Node.op`.\n        if node.op != 'call_module': # If our current node isn't calling a Module then we can ignore it.\n            continue\n        # For call sites, `Node.target` represents the module/function/method\n        # that's being called. 
Here, we check `Node.target` to see if it's a\n        # batch norm module, and then check `Node.args[0].target` to see if the\n        # input `Node` is a convolution.\n        if type(modules[node.target]) is nn.BatchNorm2d and type(modules[node.args[0].target]) is nn.Conv2d:\n            if len(node.args[0].users) > 1:  # Output of conv is used by other nodes\n                continue\n            conv = modules[node.args[0].target]\n            bn = modules[node.target]\n            fused_conv = fuse_conv_bn_eval(conv, bn)\n            replace_node_module(node.args[0], modules, fused_conv)\n            # As we've folded the batch norm into the conv, we need to replace all uses\n            # of the batch norm with the conv.\n            node.replace_all_uses_with(node.args[0])\n            # Now that all uses of the batch norm have been replaced, we can\n            # safely remove the batch norm.\n            fx_model.graph.erase_node(node)\n    fx_model.graph.lint()\n    # After we've modified our graph, we need to recompile our graph in order\n    # to keep the generated code in sync.\n    fx_model.recompile()\n    return fx_model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<p><strong>NOTE:</strong> We make some simplifications here for demonstration purposes, such as only matching 2D convolutions. View https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/fuser.py for a more usable pass.</p>
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Testing out our Fusion Pass\n===========================\n\nWe can now run this fusion pass on our initial toy model and verify that\nour results are identical. In addition, we can print out the code for\nour fused model and verify that there are no more batch norms.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fused_model = fuse(model)\nprint(fused_model.code)\ninp = torch.randn(5, 1, 1, 1)\ntorch.testing.assert_close(fused_model(inp), model(inp))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Benchmarking our Fusion on ResNet18\n===================================\n\nWe can test our fusion pass on a larger model like ResNet18 and see how\nmuch this pass improves inference performance.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torchvision.models as models\nimport time\n\nrn18 = models.resnet18()\nrn18.eval()\n\ninp = torch.randn(10, 3, 224, 224)\noutput = rn18(inp)\n\ndef benchmark(model, iters=20):\n    for _ in range(10):\n        model(inp)\n    begin = time.time()\n    for _ in range(iters):\n        model(inp)\n    return str(time.time()-begin)\n\nfused_rn18 = fuse(rn18)\nprint(\"Unfused time: \", benchmark(rn18))\nprint(\"Fused time: \", benchmark(fused_rn18))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we previously saw, the output of our FX transformation is\n(\"torchscriptable\") PyTorch code, so we can easily `jit.script` the\noutput to try to increase our performance even more. 
In this way, our\nFX model transformation composes with TorchScript with no issues.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "jit_rn18 = torch.jit.script(fused_rn18)\nprint(\"jit time: \", benchmark(jit_rn18))\n\n\n############\n# Conclusion\n# ----------\n# As we can see, using FX we can easily write static graph transformations on\n# PyTorch code.\n#\n# Since FX is still in beta, we would be happy to hear any\n# feedback you have about using it. Please feel free to use the\n# PyTorch Forums (https://discuss.pytorch.org/) and the issue tracker\n# (https://github.com/pytorch/pytorch/issues) to provide any feedback\n# you might have." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }