{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "torch.export Tutorial\n=====================\n\n**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<p><strong>WARNING:</strong> torch.export and its related features are in prototype status and are subject to backwards compatibility-breaking changes. This tutorial provides a snapshot of torch.export usage as of PyTorch 2.5.</p>
\n```\n`torch.export`{.interpreted-text role=\"func\"} is the PyTorch 2.X way to\nexport PyTorch models into standardized model representations, intended\nto be run on different (i.e. Python-less) environments. The official\ndocumentation can be found\n[here](https://pytorch.org/docs/main/export.html).\n\nIn this tutorial, you will learn how to use\n`torch.export`{.interpreted-text role=\"func\"} to extract\n`ExportedProgram`\\'s (i.e. single-graph representations) from PyTorch\nprograms. We also detail some considerations/modifications that you may\nneed to make in order to make your model compatible with `torch.export`.\n\n**Contents**\n\n::: {.contents local=\"\"}\n:::\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basic Usage\n===========\n\n`torch.export` extracts single-graph representations from PyTorch\nprograms by tracing the target function, given example inputs.\n`torch.export.export()` is the main entry point for `torch.export`.\n\nIn this tutorial, `torch.export` and `torch.export.export()` are\npractically synonymous, though `torch.export` generally refers to the\nPyTorch 2.X export process, and `torch.export.export()` generally refers\nto the actual function call.\n\nThe signature of `torch.export.export()` is:\n\n``` {.python}\nexport(\n mod: torch.nn.Module,\n args: Tuple[Any, ...],\n kwargs: Optional[Dict[str, Any]] = None,\n *,\n dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None\n) -> ExportedProgram\n```\n\n`torch.export.export()` traces the tensor computation graph from calling\n`mod(*args, **kwargs)` and wraps it in an `ExportedProgram`, which can\nbe serialized or executed later with different inputs. To execute the\n`ExportedProgram` we can call `.module()` on it to return a\n`torch.nn.Module` which is callable, just like the original program. We\nwill detail the `dynamic_shapes` argument later in the tutorial.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.export import export\n\nclass MyModule(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.lin = torch.nn.Linear(100, 10)\n\n def forward(self, x, y):\n return torch.nn.functional.relu(self.lin(x + y), inplace=True)\n\nmod = MyModule()\nexported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))\nprint(type(exported_mod))\nprint(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s review some attributes of `ExportedProgram` that are of interest.\n\nThe `graph` attribute is an [FX\ngraph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph) traced\nfrom the function we exported, that is, the computation graph of all\nPyTorch operations. 
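\n\nFor a quick look at the operations recorded in this graph, we can iterate over its nodes (a minimal sketch using the FX graph API):\n\n``` {.python}\nfor node in exported_mod.graph.nodes:\n    # node.op is the node kind (placeholder, call_function, output, ...);\n    # node.target is the ATen operator for call_function nodes\n    print(node.op, node.target)\n```\n\n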
The FX graph is in \\\"ATen IR\\\" meaning that it\ncontains only \\\"ATen-level\\\" operations.\n\nThe `graph_signature` attribute gives a more detailed description of the\ninput and output nodes in the exported graph, describing which ones are\nparameters, buffers, user inputs, or user outputs.\n\nThe `range_constraints` attributes will be covered later.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(exported_mod)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See the `torch.export`\n[documentation](https://pytorch.org/docs/main/export.html#torch.export.export)\nfor more details.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Graph Breaks\n============\n\nAlthough `torch.export` shares components with `torch.compile`, the key\nlimitation of `torch.export`, especially when compared to\n`torch.compile`, is that it does not support graph breaks. This is\nbecause handling graph breaks involves interpreting the unsupported\noperation with default Python evaluation, which is incompatible with the\nexport use case. Therefore, in order to make your model code compatible\nwith `torch.export`, you will need to modify your code to remove graph\nbreaks.\n\nA graph break is necessary in cases such as:\n\n- data-dependent control flow\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Bad1(torch.nn.Module):\n def forward(self, x):\n if x.sum() > 0:\n return torch.sin(x)\n return torch.cos(x)\n\nimport traceback as tb\ntry:\n export(Bad1(), (torch.randn(3, 3),))\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- accessing tensor data with `.data`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Bad2(torch.nn.Module):\n def forward(self, x):\n x.data[0, 0] = 3\n return x\n\ntry:\n export(Bad2(), (torch.randn(3, 3),))\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- calling unsupported functions (such as many built-in functions)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Bad3(torch.nn.Module):\n def forward(self, x):\n x = x + 1\n return x + id(x)\n\ntry:\n export(Bad3(), (torch.randn(3, 3),))\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Non-Strict Export\n=================\n\nTo trace the program, `torch.export` uses TorchDynamo by default, a byte\ncode analysis engine, to symbolically analyze the Python code and build\na graph based on the results. This analysis allows `torch.export` to\nprovide stronger guarantees about safety, but not all Python code is\nsupported, causing these graph breaks.\n\nTo address this issue, in PyTorch 2.3, we introduced a new mode of\nexporting called non-strict mode, where we trace through the program\nusing the Python interpreter executing it exactly as it would in eager\nmode, allowing us to skip over unsupported Python features. This is done\nthrough adding a `strict=False` flag.\n\nLooking at some of the previous examples which resulted in graph breaks:\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\\- Calling unsupported functions (such as many built-in functions)\ntraces through, but in this case, `id(x)` gets specialized as a constant\ninteger in the graph. 
This is because `id(x)` is not a tensor operation,\nso the operation is not recorded in the graph.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Bad3(torch.nn.Module):\n def forward(self, x):\n x = x + 1\n return x + id(x)\n\nbad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)\nprint(bad3_nonstrict)\nprint(bad3_nonstrict.module()(torch.ones(3, 3)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, there are still some features that require rewrites to the\noriginal module:\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Control Flow Ops\n================\n\n`torch.export` actually does support data-dependent control flow. But\nthese need to be expressed using control flow ops. For example, we can\nfix the control flow example above using the `cond` op, like so:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Bad1Fixed(torch.nn.Module):\n def forward(self, x):\n def true_fn(x):\n return torch.sin(x)\n def false_fn(x):\n return torch.cos(x)\n return torch.cond(x.sum() > 0, true_fn, false_fn, [x])\n\nexported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))\nprint(exported_bad1_fixed)\nprint(exported_bad1_fixed.module()(torch.ones(3, 3)))\nprint(exported_bad1_fixed.module()(-torch.ones(3, 3)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are limitations to `cond` that one should be aware of:\n\n- The predicate (i.e. `x.sum() > 0`) must result in a boolean or a\n single-element tensor.\n- The operands (i.e. `[x]`) must be tensors.\n- The branch function (i.e. `true_fn` and `false_fn`) signature must\n match with the operands and they must both return a single tensor\n with the same metadata (for example, `dtype`, `shape`, etc.).\n- Branch functions cannot mutate input or global variables.\n- Branch functions cannot access closure variables, except for `self`\n if the function is defined in the scope of a method.\n\nFor more details about `cond`, check out the [cond\ndocumentation](https://pytorch.org/docs/main/cond.html).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also use `map`, which applies a function across the first\ndimension of the first tensor argument.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch._higher_order_ops.map import map as torch_map\n\nclass MapModule(torch.nn.Module):\n def forward(self, xs, y, z):\n def body(x, y, z):\n return x + y + z\n\n return torch_map(body, xs, y, z)\n\ninps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))\nexported_map_example = export(MapModule(), inps)\nprint(exported_map_example)\nprint(exported_map_example.module()(*inps))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Other control flow ops include `while_loop`, `associative_scan`, and\n`scan`. For more documentation on each operator, please refer to [this\npage](https://github.com/pytorch/pytorch/tree/main/torch/_higher_order_ops).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Constraints/Dynamic Shapes\n==========================\n\nThis section covers dynamic behavior and representation of exported\nprograms. 
Dynamic behavior is subjective to the particular model being\nexported, so for the most part of this tutorial, we\\'ll focus on this\nparticular toy model (with the resulting tensor shapes annotated):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class DynamicModel(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.l = torch.nn.Linear(5, 3)\n\n def forward(\n self,\n w: torch.Tensor, # [6, 5]\n x: torch.Tensor, # [4]\n y: torch.Tensor, # [8, 4]\n z: torch.Tensor, # [32]\n ):\n x0 = x + y # [8, 4]\n x1 = self.l(w) # [6, 3]\n x2 = x0.flatten() # [32]\n x3 = x2 + z # [32]\n return x1, x3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, `torch.export` produces a static program. One consequence of\nthis is that at runtime, the program won\\'t work on inputs with\ndifferent shapes, even if they\\'re valid in eager mode.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "w = torch.randn(6, 5)\nx = torch.randn(4)\ny = torch.randn(8, 4)\nz = torch.randn(32)\nmodel = DynamicModel()\nep = export(model, (w, x, y, z))\nmodel(w, x, torch.randn(3, 4), torch.randn(12))\ntry:\n ep.module()(w, x, torch.randn(3, 4), torch.randn(12))\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basic concepts: symbols and guards\n==================================\n\nTo enable dynamism, `export()` provides a `dynamic_shapes` argument. The\neasiest way to work with dynamic shapes is using `Dim.AUTO` and looking\nat the program that\\'s returned. Dynamic behavior is specified at a\ninput dimension-level; for each input we can specify a tuple of values:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.export.dynamic_shapes import Dim\n\ndynamic_shapes = {\n \"w\": (Dim.AUTO, Dim.AUTO),\n \"x\": (Dim.AUTO,),\n \"y\": (Dim.AUTO, Dim.AUTO),\n \"z\": (Dim.AUTO,),\n}\nep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Before we look at the program that\\'s produced, let\\'s understand what\nspecifying `dynamic_shapes` entails, and how that interacts with export.\nFor every input dimension where a `Dim` object is specified, a symbol is\n[allocated](https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes),\ntaking on a range of `[2, inf]` (why not `[0, inf]` or `[1, inf]`?\nwe\\'ll explain later in the 0/1 specialization section).\n\nExport then runs model tracing, looking at each operation that\\'s\nperformed by the model. Each individual operation can emit what\\'s\ncalled \\\"guards\\\"; basically boolean condition that are required to be\ntrue for the program to be valid. When guards involve symbols allocated\nfor input dimensions, the program contains restrictions on what input\nshapes are valid; i.e. the program\\'s dynamic behavior. The symbolic\nshapes subsystem is the part responsible for taking in all the emitted\nguards and producing a final program representation that adheres to all\nof these guards. 
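\n\nAs a quick sanity check (a sketch; the printed values will differ because the inputs are random), the exported program now accepts any input shapes that satisfy these guards:\n\n``` {.python}\n# shapes consistent with the guards summarized below:\n# w: [s0, 5], x: [s2], y: [s3, s2], z: [s2*s3]\nprint(ep.module()(torch.randn(10, 5), torch.randn(6), torch.randn(7, 6), torch.randn(42)))\n```\n\n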
Before we see this \\\"final representation\\\" in an\n`ExportedProgram`, let\\'s look at the guards emitted by the toy model\nwe\\'re tracing.\n\nHere, each forward input tensor is annotated with the symbol allocated\nat the start of tracing:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class DynamicModel(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.l = torch.nn.Linear(5, 3)\n\n def forward(\n self,\n w: torch.Tensor, # [s0, s1]\n x: torch.Tensor, # [s2]\n y: torch.Tensor, # [s3, s4]\n z: torch.Tensor, # [s5]\n ):\n x0 = x + y # guard: s2 == s4\n x1 = self.l(w) # guard: s1 == 5\n x2 = x0.flatten() # no guard added here\n x3 = x2 + z # guard: s3 * s4 == s5\n return x1, x3" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s understand each of the operations and the emitted guards:\n\n- `x0 = x + y`: This is an element-wise add with broadcasting, since\n `x` is a 1-d tensor and `y` a 2-d tensor. `x` is broadcasted along\n the last dimension of `y`, emitting the guard `s2 == s4`.\n- `x1 = self.l(w)`: Calling `nn.Linear()` performs a matrix\n multiplication with model parameters. In export, parameters,\n buffers, and constants are considered program state, which is\n considered static, and so this is a matmul between a dynamic input\n (`w: [s0, s1]`), and a statically-shaped tensor. This emits the\n guard `s1 == 5`.\n- `x2 = x0.flatten()`: This call actually doesn\\'t emit any guards!\n (at least none relevant to input shapes)\n- `x3 = x2 + z`: `x2` has shape `[s3*s4]` after flattening, and this\n element-wise add emits `s3 * s4 == s5`.\n\nWriting all of these guards down and summarizing is almost like a\nmathematical proof, which is what the symbolic shapes subsystem tries to\ndo! In summary, we can conclude that the program must have the following\ninput shapes to be valid:\n\n- `w: [s0, 5]`\n- `x: [s2]`\n- `y: [s3, s2]`\n- `z: [s2*s3]`\n\nAnd when we do finally print out the exported program to see our result,\nthose shapes are what we see annotated on the corresponding inputs:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(ep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another feature to notice is the range\\_constraints field above, which\ncontains a valid range for each symbol. This isn\\'t so interesting\ncurrently, since this export call doesn\\'t emit any guards related to\nsymbol bounds and each base symbol has a generic bound, but this will\ncome up later.\n\nSo far, because we\\'ve been exporting this toy model, this experience\nhas not been representative of how hard it typically is to debug dynamic\nshapes guards & issues. In most cases it isn\\'t obvious what guards are\nbeing emitted, and which operations and parts of user code are\nresponsible. For this toy model we pinpoint the exact lines, and the\nguards are rather intuitive.\n\nIn more complicated cases, a helpful first step is always to enable\nverbose logging. 
This can be done either with the environment variable\n`TORCH_LOGS=\"+dynamic\"`, or interactively with\n`torch._logging.set_logs(dynamic=10)`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch._logging.set_logs(dynamic=10)\nep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This spits out quite a handful, even with this simple toy model. The log\nlines here have been cut short at front and end to ignore unnecessary\ninfo, but looking through the logs we can see the lines relevant to what\nwe described above; e.g. the allocation of symbols:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\"\"\"\ncreate_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\nruntime_assert True == True [statically known]\ncreate_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\n\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The lines with [create\\_symbol]{.title-ref} show when a new symbol has\nbeen allocated, and the logs also identify the tensor variable names and\ndimensions they\\'ve been allocated for. In other lines we can also see\nthe guards emitted:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "\"\"\"\nruntime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s2, s4)\"\nruntime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s1, 5)\"\nruntime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"Eq(s2*s3, s5)\"\n\"\"\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next to the `[guard added]` messages, we also see the responsible user\nlines of code - luckily here the model is simple enough. In many\nreal-world cases it\\'s not so straightforward: high-level torch\noperations can have complicated fake-kernel implementations or operator\ndecompositions that complicate where and what guards are emitted. In\nsuch cases the best way to dig deeper and investigate is to follow the\nlogs\\' suggestion, and re-run with environment variable\n`TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED=\"...\"`, to further attribute the\nguard of interest.\n\n`Dim.AUTO` is just one of the available options for interacting with\n`dynamic_shapes`; as of writing this 2 other options are available:\n`Dim.DYNAMIC`, and `Dim.STATIC`. 
`Dim.STATIC` simply marks a dimension\nstatic, while `Dim.DYNAMIC` is similar to `Dim.AUTO` in all ways except\none: it raises an error when specializing to a constant; this is\ndesigned to maintain dynamism. See for example what happens when a\nstatic guard is emitted on a dynamically-marked dimension:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dynamic_shapes[\"w\"] = (Dim.AUTO, Dim.DYNAMIC)\ntry:\n export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Static guards also aren\\'t always inherent to the model; they can also\ncome from user specifications. In fact, a common pitfall leading to\nshape specializations is when the user specifies conflicting markers for\nequivalent dimensions; one dynamic and another static. The same error\ntype is raised when this is the case for `x.shape[0]` and `y.shape[1]`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dynamic_shapes[\"w\"] = (Dim.AUTO, Dim.AUTO)\ndynamic_shapes[\"x\"] = (Dim.STATIC,)\ndynamic_shapes[\"y\"] = (Dim.AUTO, Dim.DYNAMIC)\ntry:\n export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here you might ask why export \\\"specializes\\\", i.e. why we resolve this\nstatic/dynamic conflict by going with the static route. The answer is\nbecause of the symbolic shapes system described above, of symbols and\nguards. When `x.shape[0]` is marked static, we don\\'t allocate a symbol,\nand compile treating this shape as a concrete integer 4. A symbol is\nallocated for `y.shape[1]`, and so we finally emit the guard `s3 == 4`,\nleading to specialization.\n\nOne feature of export is that during tracing, statements like asserts,\n`torch._check()`, and `if/else` conditions will also emit guards. See\nwhat happens when we augment the existing model with such statements:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class DynamicModel(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.l = torch.nn.Linear(5, 3)\n\n def forward(self, w, x, y, z):\n assert w.shape[0] <= 512\n torch._check(x.shape[0] >= 4)\n if w.shape[0] == x.shape[0] + 2:\n x0 = x + y\n x1 = self.l(w)\n x2 = x0.flatten()\n x3 = x2 + z\n return x1, x3\n else:\n return w\n\ndynamic_shapes = {\n \"w\": (Dim.AUTO, Dim.AUTO),\n \"x\": (Dim.AUTO,),\n \"y\": (Dim.AUTO, Dim.AUTO),\n \"z\": (Dim.AUTO,),\n}\ntry:\n ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of these statements emits an additional guard, and the exported\nprogram shows the changes; `s0` is eliminated in favor of `s2 + 2`, and\n`s2` now contains lower and upper bounds, reflected in\n`range_constraints`.\n\nFor the if/else condition, you might ask why the True branch was taken,\nand why it wasn\\'t the `w.shape[0] != x.shape[0] + 2` guard that got\nemitted from tracing. The answer is that export is guided by the sample\ninputs provided by tracing, and specializes on the branches taken. If\ndifferent sample input shapes were provided that fail the `if`\ncondition, export would trace and emit guards corresponding to the\n`else` branch. 
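\n\nFor instance, re-exporting with a sample `w` whose first dimension fails the condition (a hypothetical sketch; the exact guards emitted may differ) would capture the `else` branch instead:\n\n``` {.python}\nw_else = torch.randn(10, 5)  # 10 != x.shape[0] + 2, so tracing takes the else branch\nep_else = export(DynamicModel(), (w_else, x, y, z), dynamic_shapes=dynamic_shapes)\nprint(ep_else)  # the captured graph just returns w (plus any runtime assertions)\n```\n\n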
Additionally, you might ask why we traced only the `if`\nbranch, and if it\\'s possible to maintain control-flow in your program\nand keep both branches alive. For that, refer to rewriting your model\ncode following the `Control Flow Ops` section above.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "0/1 specialization\n==================\n\nSince we\\'re talking about guards and specializations, it\\'s a good time\nto talk about the 0/1 specialization issue we brought up earlier. The\nbottom line is that export will specialize on sample input dimensions\nwith value 0 or 1, because these shapes have trace-time properties that\ndon\\'t generalize to other shapes. For example, size 1 tensors can\nbroadcast while other sizes fail; and size 0 \\... . This just means that\nyou should specify 0/1 sample inputs when you\\'d like your program to\nhardcode them, and non-0/1 sample inputs when dynamic behavior is\ndesirable. See what happens at runtime when we export this linear layer:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ep = export(\n torch.nn.Linear(4, 3),\n (torch.randn(1, 4),),\n dynamic_shapes={\n \"input\": (Dim.AUTO, Dim.STATIC),\n },\n)\ntry:\n ep.module()(torch.randn(2, 4))\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Named Dims\n==========\n\nSo far we\\'ve only been talking about 3 ways to specify dynamic shapes:\n`Dim.AUTO`, `Dim.DYNAMIC`, and `Dim.STATIC`. The attraction of these is\nthe low-friction user experience; all the guards emitted during model\ntracing are adhered to, and dynamic behavior like min/max ranges,\nrelations, and static/dynamic dimensions are automatically figured out\nunderneath export. The dynamic shapes subsystem essentially acts as a\n\\\"discovery\\\" process, summarizing these guards and presenting what\nexport believes is the overall dynamic behavior of the program. The\ndrawback of this design appears once the user has stronger expectations\nor beliefs about the dynamic behavior of these models - maybe there is a\nstrong desire on dynamism and specializations on particular dimensions\nare to be avoided at all costs, or maybe we just want to catch changes\nin dynamic behavior with changes to the original model code, or possibly\nunderlying decompositions or meta-kernels. These changes won\\'t be\ndetected and the `export()` call will most likely succeed, unless tests\nare in place that check the resulting `ExportedProgram` representation.\n\nFor such cases, our stance is to recommend the \\\"traditional\\\" way of\nspecifying dynamic shapes, which longer-term users of export might be\nfamiliar with: named `Dims`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dx = Dim(\"dx\", min=4, max=256)\ndh = Dim(\"dh\", max=512)\ndynamic_shapes = {\n \"x\": (dx, None),\n \"y\": (2 * dx, dh),\n}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This style of dynamic shapes allows the user to specify what symbols are\nallocated for input dimensions, min/max bounds on those symbols, and\nplaces restrictions on the dynamic behavior of the `ExportedProgram`\nproduced; `ConstraintViolation` errors will be raised if model tracing\nemits guards that conflict with the relations or static/dynamic\nspecifications given. 
For example, in the above specification, the\nfollowing is asserted:\n\n- `x.shape[0]` is to have range `[4, 256]`, and related to\n `y.shape[0]` by `y.shape[0] == 2 * x.shape[0]`.\n- `x.shape[1]` is static.\n- `y.shape[1]` has range `[2, 512]`, and is unrelated to any other\n dimension.\n\nIn this design, we allow relations between dimensions to be specified\nwith univariate linear expressions: `A * dim + B` can be specified for\nany dimension. This allows users to specify more complex constraints\nlike integer divisibility for dynamic dimensions:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dx = Dim(\"dx\", min=4, max=512)\ndynamic_shapes = {\n \"x\": (4 * dx, None) # x.shape[0] has range [16, 2048], and is divisible by 4.\n}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Constraint violations, suggested fixes\n======================================\n\nOne common issue with this specification style (before `Dim.AUTO` was\nintroduced), is that the specification would often be mismatched with\nwhat was produced by model tracing. That would lead to\n`ConstraintViolation` errors and export suggested fixes - see for\nexample with this model & specification, where the model inherently\nrequires equality between dimensions 0 of `x` and `y`, and requires\ndimension 1 to be static.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n w = x + y\n return w + torch.ones(4)\n\ndx, dy, d1 = torch.export.dims(\"dx\", \"dy\", \"d1\")\ntry:\n ep = export(\n Foo(),\n (torch.randn(6, 4), torch.randn(6, 4)),\n dynamic_shapes={\n \"x\": (dx, d1),\n \"y\": (dy, d1),\n },\n )\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The expectation with suggested fixes is that the user can interactively\ncopy-paste the changes into their dynamic shapes specification, and\nsuccessfully export afterwards.\n\nLastly, there\\'s couple nice-to-knows about the options for\nspecification:\n\n- `None` is a good option for static behavior:\n - `dynamic_shapes=None` (default) exports with the entire model\n being static.\n - specifying `None` at an input-level exports with all tensor\n dimensions static, and is also required for non-tensor inputs.\n - specifying `None` at a dimension-level specializes that\n dimension, though this is deprecated in favor of `Dim.STATIC`.\n- specifying per-dimension integer values also produces static\n behavior, and will additionally check that the provided sample input\n matches the specification.\n\nThese options are combined in the inputs & dynamic shapes spec below:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "inputs = (\n torch.randn(4, 4),\n torch.randn(3, 3),\n 16,\n False,\n)\ndynamic_shapes = {\n \"tensor_0\": (Dim.AUTO, None),\n \"tensor_1\": None,\n \"int_val\": None,\n \"bool_val\": None,\n}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data-dependent errors\n=====================\n\nWhile trying to export models, you have may have encountered errors like\n\\\"Could not guard on data-dependent expression\\\", or Could not extract\nspecialized integer from data-dependent expression\\\". These errors exist\nbecause `torch.export()` compiles programs using FakeTensors, which\nsymbolically represent their real tensor counterparts. 
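\n\nTo make this concrete, here is a minimal illustration of a FakeTensor (a sketch that uses the internal `FakeTensorMode` API, which may change between releases):\n\n``` {.python}\nfrom torch._subclasses.fake_tensor import FakeTensorMode\n\nfake_mode = FakeTensorMode()\nfake = fake_mode.from_tensor(torch.randn(4, 4))\nprint(fake.shape, fake.dtype)  # metadata is available\n# there is no underlying data, so value-dependent calls such as\n# fake.item() cannot be resolved during tracing\n```\n\n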
While these have\nequivalent symbolic properties (e.g. sizes, strides, dtypes), they\ndiverge in that FakeTensors do not contain any data values. While this\navoids unnecessary memory usage and expensive computation, it does mean\nthat export may be unable to out-of-the-box compile parts of user code\nwhere compilation relies on data values. In short, if the compiler\nrequires a concrete, data-dependent value in order to proceed, it will\nerror out, complaining that the value is not available.\n\nData-dependent values appear in many places, and common sources are\ncalls like `item()`, `tolist()`, or `torch.unbind()` that extract scalar\nvalues from tensors. How are these values represented in the exported\nprogram? In the [Constraints/Dynamic\nShapes](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes)\nsection, we talked about allocating symbols to represent dynamic input\ndimensions. The same happens here: we allocate symbols for every\ndata-dependent value that appears in the program. The important\ndistinction is that these are \\\"unbacked\\\" symbols, in contrast to the\n\\\"backed\\\" symbols allocated for input dimensions. The\n[\\\"backed/unbacked\\\"](https://pytorch.org/docs/main/export.programming_model.html#basics-of-symbolic-shapes)\nnomenclature refers to the presence/absence of a \\\"hint\\\" for the\nsymbol: a concrete value backing the symbol, that can inform the\ncompiler on how to proceed.\n\nIn the input shape symbol case (backed symbols), these hints are simply\nthe sample input shapes provided, which explains why control-flow\nbranching is determined by the sample input properties. For\ndata-dependent values, the symbols are taken from FakeTensor \\\"data\\\"\nduring tracing, and so the compiler doesn\\'t know the actual values\n(hints) that these symbols would take on.\n\nLet\\'s see how these show up in exported programs:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n b = y.tolist()\n return b + [a]\n\ninps = (\n torch.tensor(1),\n torch.tensor([2, 3]),\n)\nep = export(Foo(), inps)\nprint(ep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result is that 3 unbacked symbols (notice they\\'re prefixed with\n\\\"u\\\", instead of the usual \\\"s\\\" for input shape/backed symbols) are\nallocated and returned: 1 for the `item()` call, and 1 for each of the\nelements of `y` with the `tolist()` call. Note from the range\nconstraints field that these take on ranges of `[-int_oo, int_oo]`, not\nthe default `[0, int_oo]` range allocated to input shape symbols, since\nwe have no information on what these values are - they don\\'t represent\nsizes, so don\\'t necessarily have positive values.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Guards, torch.\\_check()\n=======================\n\nBut the case above is easy to export, because the concrete values of\nthese symbols aren\\'t used in any compiler decision-making; all that\\'s\nrelevant is that the return values are unbacked symbols. 
The\ndata-dependent errors highlighted in this section are cases like the\nfollowing, where [data-dependent\nguards](https://pytorch.org/docs/main/export.programming_model.html#control-flow-static-vs-dynamic)\nare encountered:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n if a // 2 >= 5:\n return y + 2\n else:\n return y * 5" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we actually need the \\\"hint\\\", or the concrete value of `a` for the\ncompiler to decide whether to trace `return y + 2` or `return y * 5` as\nthe output. Because we trace with FakeTensors, we don\\'t know what\n`a // 2 >= 5` actually evaluates to, and export errors out with \\\"Could\nnot guard on data-dependent expression `u0 // 2 >= 5 (unhinted)`\\\".\n\nSo how do we export this toy model? Unlike `torch.compile()`, export\nrequires full graph compilation, and we can\\'t just graph break on this.\nHere are some basic options:\n\n1. Manual specialization: we could intervene by selecting the branch to\n trace, either by removing the control-flow code to contain only the\n specialized branch, or using `torch.compiler.is_compiling()` to\n guard what\\'s traced at compile-time.\n2. `torch.cond()`: we could rewrite the control-flow code to use\n `torch.cond()` so we don\\'t specialize on a branch.\n\nWhile these options are valid, they have their pitfalls. Option 1\nsometimes requires drastic, invasive rewrites of the model code to\nspecialize, and `torch.cond()` is not a comprehensive system for\nhandling data-dependent errors. As we will see, there are data-dependent\nerrors that do not involve control-flow.\n\nThe generally recommended approach is to start with `torch._check()`\ncalls. While these give the impression of purely being assert\nstatements, they are in fact a system of informing the compiler on\nproperties of symbols. While a `torch._check()` call does act as an\nassertion at runtime, when traced at compile-time, the checked\nexpression is sent to the symbolic shapes subsystem for reasoning, and\nany symbol properties that follow from the expression being true, are\nstored as symbol properties (provided it\\'s smart enough to infer those\nproperties). So even if unbacked symbols don\\'t have hints, if we\\'re\nable to communicate properties that are generally true for these symbols\nvia `torch._check()` calls, we can potentially bypass data-dependent\nguards without rewriting the offending model code.\n\nFor example in the model above, inserting `torch._check(a >= 10)` would\ntell the compiler that `y + 2` can always be returned, and\n`torch._check(a == 4)` tells it to return `y * 5`. See what happens when\nwe re-export this model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n torch._check(a >= 10)\n torch._check(a <= 60)\n if a // 2 >= 5:\n return y + 2\n else:\n return y * 5\n\ninps = (\n torch.tensor(32),\n torch.randn(4),\n)\nep = export(Foo(), inps)\nprint(ep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Export succeeds, and note from the range constraints field that `u0`\ntakes on a range of `[10, 60]`.\n\nSo what information do `torch._check()` calls actually communicate? 
This\nvaries as the symbolic shapes subsystem gets smarter, but at a\nfundamental level, these are generally true:\n\n1. Equality with non-data-dependent expressions: `torch._check()` calls\n that communicate equalities like `u0 == s0 + 4` or `u0 == 5`.\n2. Range refinement: calls that provide lower or upper bounds for\n symbols, like the above.\n3. Some basic reasoning around more complicated expressions: inserting\n `torch._check(a < 4)` will typically tell the compiler that `a >= 4`\n is false. Checks on complex expressions like\n `torch._check(a ** 2 - 3 * a <= 10)` will typically get you past\n identical guards.\n\nAs mentioned previously, `torch._check()` calls have applicability\noutside of data-dependent control flow. For example, here\\'s a model\nwhere `torch._check()` insertion prevails while manual specialization &\n`torch.cond()` do not:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n return y[a]\n\ninps = (\n torch.tensor(32),\n torch.randn(60),\n)\ntry:\n export(Foo(), inps)\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a scenario where `torch._check()` insertion is required simply\nto prevent an operation from failing. The export call will fail with\n\\\"Could not guard on data-dependent expression `-u0 > 60`\\\", implying\nthat the compiler doesn\\'t know if this is a valid indexing operation\n-if the value of `x` is out-of-bounds for `y` or not. Here, manual\nspecialization is too prohibitive, and `torch.cond()` has no place.\nInstead, informing the compiler of `u0`\\'s range is sufficient:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n torch._check(a >= 0)\n torch._check(a < y.shape[0])\n return y[a]\n\ninps = (\n torch.tensor(32),\n torch.randn(60),\n)\nep = export(Foo(), inps)\nprint(ep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Specialized values\n==================\n\nAnother category of data-dependent error happens when the program\nattempts to extract a concrete data-dependent integer/float value while\ntracing. This looks something like \\\"Could not extract specialized\ninteger from data-dependent expression\\\", and is analogous to the\nprevious class of errors - if these occur when attempting to evaluate\nconcrete integer/float values, data-dependent guard errors arise with\nevaluating concrete boolean values.\n\nThis error typically occurs when there is an explicit or implicit\n`int()` cast on a data-dependent expression. For example, this list\ncomprehension has a [range()]{.title-ref} call that implicitly does an\n`int()` cast on the size of the list:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n b = torch.cat([y for y in range(a)], dim=0)\n return b + int(a)\n\ninps = (\n torch.tensor(32),\n torch.randn(60),\n)\ntry:\n export(Foo(), inps, strict=False)\nexcept Exception:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For these errors, some basic options you have are:\n\n1. Avoid unnecessary `int()` cast calls, in this case the `int(a)` in\n the return statement.\n2. 
Use `torch._check()` calls; unfortunately all you may be able to do\n in this case is specialize (with `torch._check(a == 60)`).\n3. Rewrite the offending code at a higher level. For example, the list\n comprehension is semantically a `repeat()` op, which doesn\\'t\n involve an `int()` cast. The following rewrite avoids data-dependent\n errors:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Foo(torch.nn.Module):\n def forward(self, x, y):\n a = x.item()\n b = y.unsqueeze(0).repeat(a, 1)\n return b + a\n\ninps = (\n torch.tensor(32),\n torch.randn(60),\n)\nep = export(Foo(), inps, strict=False)\nprint(ep)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Data-dependent errors can be much more involved, and there are many more\noptions in your toolkit to deal with them: `torch._check_is_size()`,\n`guard_size_oblivious()`, or real-tensor tracing, as starters. For more\nin-depth guides, please refer to the [Export Programming\nModel](https://pytorch.org/docs/main/export.programming_model.html), or\n[Dealing with GuardOnDataDependentSymNode\nerrors](https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Custom Ops\n==========\n\n`torch.export` can export PyTorch programs with custom operators. Please\nrefer to [this\npage](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)\non how to author a custom operator in either C++ or Python.\n\nThe following is an example of registering a custom operator in python\nto be used by `torch.export`. The important thing to note is that the\ncustom op must have a [FakeTensor\nkernel](https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit?tab=t.0#heading=h.xvrg7clz290).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.custom_op(\"my_custom_library::custom_op\", mutates_args={})\ndef custom_op(x: torch.Tensor) -> torch.Tensor:\n print(\"custom_op called!\")\n return torch.relu(x)\n\n@custom_op.register_fake\ndef custom_op_meta(x):\n # Returns an empty tensor with the same shape as the expected output\n return torch.empty_like(x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is an example of exporting a program with the custom op.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class CustomOpExample(torch.nn.Module):\n def forward(self, x):\n x = torch.sin(x)\n x = torch.ops.my_custom_library.custom_op(x)\n x = torch.cos(x)\n return x\n\nexported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))\nprint(exported_custom_op_example)\nprint(exported_custom_op_example.module()(torch.randn(3, 3)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that in the `ExportedProgram`, the custom operator is included in\nthe graph.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "IR/Decompositions\n=================\n\nThe graph produced by `torch.export` returns a graph containing only\n[ATen operators](https://pytorch.org/cppdocs/#aten), which are the basic\nunit of computation in PyTorch. 
As there are over 3000 ATen operators,\nexport provides a way to narrow down the operator set used in the graph\nbased on certain characteristics, creating different IRs.\n\nBy default, export produces the most generic IR which contains all ATen\noperators, including both functional and non-functional operators. A\nfunctional operator is one that does not contain any mutations or\naliasing of the inputs. You can find a list of all ATen operators\n[here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml)\nand you can inspect if an operator is functional by checking\n`op._schema.is_mutable`, for example:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(torch.ops.aten.add.Tensor._schema.is_mutable)\nprint(torch.ops.aten.add_.Tensor._schema.is_mutable)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This generic IR can be used to train in eager PyTorch Autograd. This IR\ncan be more explicitly reached through the API\n`torch.export.export_for_training`, which was introduced in PyTorch 2.5,\nbut calling `torch.export.export` should produce the same graph as of\nPyTorch 2.6.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class DecompExample(torch.nn.Module):\n def __init__(self) -> None:\n super().__init__()\n self.conv = torch.nn.Conv2d(1, 3, 1, 1)\n self.bn = torch.nn.BatchNorm2d(3)\n\n def forward(self, x):\n x = self.conv(x)\n x = self.bn(x)\n return (x,)\n\nep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))\nprint(ep_for_training.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then lower this exported program to an operator set which only\ncontains functional ATen operators through the API `run_decompositions`,\nwhich decomposes the ATen operators into the ones specified in the\ndecomposition table, and functionalizes the graph. By specifying an\nempty set, we\\'re only performing functionalization, and does not do any\nadditional decompositions. This results in an IR which contains \\~2000\noperators (instead of the 3000 operators above), and is ideal for\ninference cases.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ep_for_inference = ep_for_training.run_decompositions(decomp_table={})\nprint(ep_for_inference.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the previously mutable operator,\n`torch.ops.aten.add_.default` has now been replaced with\n`torch.ops.aten.add.default`, a l operator.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also further lower this exported program to an operator set which\nonly contains the [Core ATen Operator\nSet](https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir),\nwhich is a collection of only \\~180 operators. 
This IR is optimal for\nbackends who do not want to reimplement all ATen operators.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.export import default_decompositions\n\ncore_aten_decomp_table = default_decompositions()\ncore_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)\nprint(core_aten_ep.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now see that `torch.ops.aten.conv2d.default` has been decomposed into\n`torch.ops.aten.convolution.default`. This is because `convolution` is a\nmore \\\"core\\\" operator, as operations like `conv1d` and `conv2d` can be\nimplemented using the same op.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also specify our own decomposition behaviors:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "my_decomp_table = torch.export.default_decompositions()\n\ndef my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):\n return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)\n\nmy_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function\nmy_ep = ep_for_training.run_decompositions(my_decomp_table)\nprint(my_ep.graph)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that instead of `torch.ops.aten.conv2d.default` being decomposed\ninto `torch.ops.aten.convolution.default`, it is now decomposed into\n`torch.ops.aten.convolution.default` and `torch.ops.aten.mul.Tensor`,\nwhich matches our custom decomposition rule.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "ExportDB\n========\n\n`torch.export` will only ever export a single computation graph from a\nPyTorch program. Because of this requirement, there will be Python or\nPyTorch features that are not compatible with `torch.export`, which will\nrequire users to rewrite parts of their model code. We have seen\nexamples of this earlier in the tutorial \\-- for example, rewriting\nif-statements using `cond`.\n\n[ExportDB](https://pytorch.org/docs/main/generated/exportdb/index.html)\nis the standard reference that documents supported and unsupported\nPython/PyTorch features for `torch.export`. It is essentially a list a\nprogram samples, each of which represents the usage of one particular\nPython/PyTorch feature and its interaction with `torch.export`. Examples\nare also tagged by category so that they can be more easily searched.\n\nFor example, let\\'s use ExportDB to get a better understanding of how\nthe predicate works in the `cond` operator. We can look at the example\ncalled `cond_predicate`, which has a `torch.cond` tag. The example code\nlooks like:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def cond_predicate(x):\n \"\"\"\n The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:\n - ``torch.Tensor`` with a single element\n - boolean expression\n NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.\n \"\"\"\n pred = x.dim() > 2 and x.shape[2] > 10\n return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "More generally, ExportDB can be used as a reference when one of the\nfollowing occurs:\n\n1. 
Before attempting `torch.export`, you know ahead of time that your\n model uses some tricky Python/PyTorch features and you want to know\n if `torch.export` covers that feature.\n2. When attempting `torch.export`, there is a failure and it\\'s unclear\n how to work around it.\n\nExportDB is not exhaustive, but is intended to cover all use cases found\nin typical PyTorch code. Feel free to reach out if there is an important\nPython/PyTorch feature that should be added to ExportDB or supported by\n`torch.export`.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Running the Exported Program\n============================\n\nAs `torch.export` is only a graph capturing mechanism, calling the\nartifact produced by `torch.export` eagerly will be equivalent to\nrunning the eager module. To optimize the execution of the Exported\nProgram, we can pass this exported artifact to backends such as Inductor\nthrough `torch.compile`,\n[AOTInductor](https://pytorch.org/docs/main/torch.compiler_aot_inductor.html),\nor [TensorRT](https://pytorch.org/TensorRT/dynamo/dynamo_export.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class M(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.linear = torch.nn.Linear(3, 3)\n\n def forward(self, x):\n x = self.linear(x)\n return x\n\ninp = torch.randn(2, 3, device=\"cuda\")\nm = M().to(device=\"cuda\")\nep = torch.export.export(m, (inp,))\n\n# Run it eagerly\nres = ep.module()(inp)\nprint(res)\n\n# Run it with torch.compile\nres = torch.compile(ep.module(), backend=\"inductor\")(inp)\nprint(res)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "``` {.python}\nimport torch._inductor\n\n# Note: these APIs are subject to change\n# Compile the exported program to a PT2 archive using ``AOTInductor``\nwith torch.no_grad():\n pt2_path = torch._inductor.aoti_compile_and_package(ep)\n\n# Load and run the .so file in Python.\n# To load and run it in a C++ environment, see:\n# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html\naoti_compiled = torch._inductor.aoti_load_package(pt2_path)\nres = aoti_compiled(inp)\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nWe introduced `torch.export`, the new PyTorch 2.X way to export single\ncomputation graphs from PyTorch programs. In particular, we demonstrate\nseveral code modifications and considerations (control flow ops,\nconstraints, etc.) that need to be made in order to export a graph.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }