{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Introduction to `torch.compile`\n===============================\n\n**Author:** William Wen\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`torch.compile` is the latest method to speed up your PyTorch code!\n`torch.compile` makes PyTorch code run faster by JIT-compiling PyTorch\ncode into optimized kernels, all while requiring minimal code changes.\n\nIn this tutorial, we cover basic `torch.compile` usage, and demonstrate\nthe advantages of `torch.compile` over previous PyTorch compiler\nsolutions, such as\n[TorchScript](https://pytorch.org/docs/stable/jit.html) and [FX\nTracing](https://pytorch.org/docs/stable/fx.html#torch.fx.symbolic_trace).\n\n**Contents**\n\n::: {.contents local=\"\"}\n:::\n\n**Required pip Dependencies**\n\n- `torch >= 2.0`\n- `torchvision`\n- `numpy`\n- `scipy`\n- `tabulate`\n\n**System Requirements** - A C++ compiler, such as `g++` - Python\ndevelopment package (`python-devel`/`python-dev`)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this\ntutorial in order to reproduce the speedup numbers shown below and\ndocumented elsewhere.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport warnings\n\ngpu_ok = False\nif torch.cuda.is_available():\n device_cap = torch.cuda.get_device_capability()\n if device_cap in ((7, 0), (8, 0), (9, 0)):\n gpu_ok = True\n\nif not gpu_ok:\n warnings.warn(\n \"GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower \"\n \"than expected.\"\n )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basic Usage\n===========\n\n`torch.compile` is included in the latest PyTorch. Running TorchInductor\non GPU requires Triton, which is included with the PyTorch 2.0 nightly\nbinary. If Triton is still missing, try installing `torchtriton` via pip\n(`pip install torchtriton --extra-index-url \"https://download.pytorch.org/whl/nightly/cu117\"`\nfor CUDA 11.7).\n\nArbitrary Python functions can be optimized by passing the callable to\n`torch.compile`. 
We can then call the returned optimized function in\nplace of the original function.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def foo(x, y):\n a = torch.sin(x)\n b = torch.cos(y)\n return a + b\nopt_foo1 = torch.compile(foo)\nprint(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Alternatively, we can decorate the function.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "t1 = torch.randn(10, 10)\nt2 = torch.randn(10, 10)\n\n@torch.compile\ndef opt_foo2(x, y):\n a = torch.sin(x)\n b = torch.cos(y)\n return a + b\nprint(opt_foo2(t1, t2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also optimize `torch.nn.Module` instances.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "t = torch.randn(10, 100)\n\nclass MyModule(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.lin = torch.nn.Linear(100, 10)\n\n def forward(self, x):\n return torch.nn.functional.relu(self.lin(x))\n\nmod = MyModule()\nmod.compile()\nprint(mod(t))\n## or:\n# opt_mod = torch.compile(mod)\n# print(opt_mod(t))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "torch.compile and Nested Calls\n==============================\n\nNested function calls within the decorated function will also be\ncompiled.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def nested_function(x):\n return torch.sin(x)\n\n@torch.compile\ndef outer_function(x, y):\n a = nested_function(x)\n b = torch.cos(y)\n return a + b\n\nprint(outer_function(t1, t2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the same fashion, when compiling a module, all sub-modules and methods\nwithin it that are not in a skip list are also compiled.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class OuterModule(torch.nn.Module):\n def __init__(self):\n super().__init__()\n self.inner_module = MyModule()\n self.outer_lin = torch.nn.Linear(10, 2)\n\n def forward(self, x):\n x = self.inner_module(x)\n return torch.nn.functional.relu(self.outer_lin(x))\n\nouter_mod = OuterModule()\nouter_mod.compile()\nprint(outer_mod(t))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can also disable some functions from being compiled by using\n`torch.compiler.disable`. Suppose you want to disable tracing for\njust the `complex_function` function, but want tracing to continue in\n`complex_conjugate`. In this case, you can use the\n`torch.compiler.disable(recursive=False)` option. 
Otherwise, the default\nis `recursive=True`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def complex_conjugate(z):\n return torch.conj(z)\n\n@torch.compiler.disable(recursive=False)\ndef complex_function(real, imag):\n # Assuming this function causes problems in the compilation\n z = torch.complex(real, imag)\n return complex_conjugate(z)\n\ndef outer_function():\n real = torch.tensor([2, 3], dtype=torch.float32)\n imag = torch.tensor([4, 5], dtype=torch.float32)\n z = complex_function(real, imag)\n return torch.abs(z)\n\n# Try to compile the outer_function\ntry:\n opt_outer_function = torch.compile(outer_function)\n print(opt_outer_function())\nexcept Exception as e:\n print(\"Compilation of outer_function failed:\", e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Best Practices and Recommendations\n==================================\n\nBehavior of `torch.compile` with Nested Modules and Function Calls\n\nWhen you use `torch.compile`, the compiler will try to recursively\ncompile every function call inside the target function or module that\nis not in a skip list (such as built-ins and some functions in the\ntorch.\\* namespace).\n\n**Best Practices:**\n\n1\\. **Top-Level Compilation:** One approach is to compile at the highest\nlevel possible (i.e., when the top-level module is initialized/called)\nand selectively disable compilation when encountering excessive graph\nbreaks or errors. If there are still many compile issues, compile\nindividual subcomponents instead.\n\n2\\. **Modular Testing:** Test individual functions and modules with\n`torch.compile` before integrating them into larger models to isolate\npotential issues.\n\n3\\. **Disable Compilation Selectively:** If certain functions or\nsub-modules cannot be handled by `torch.compile`, use the\n`torch.compiler.disable` context manager to recursively\nexclude them from compilation.\n\n4\\. **Compile Leaf Functions First:** In complex models with multiple\nnested functions and modules, start by compiling the leaf functions or\nmodules first. For more information, see [TorchDynamo APIs for\nfine-grained\ntracing](https://pytorch.org/docs/stable/torch.compiler_fine_grain_apis.html).\n\n5\\. **Prefer `mod.compile()` over `torch.compile(mod)`:** Avoids\n`_orig_` prefix issues in `state_dict`.\n\n6\\. **Use `fullgraph=True` to catch graph breaks:** Helps ensure\nend-to-end compilation, maximizing speedup and compatibility with\n`torch.export`.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Demonstrating Speedups\n======================\n\nLet\\'s now demonstrate that using `torch.compile` can speed up real\nmodels. We will compare standard eager mode and `torch.compile` by\nevaluating and training a `torchvision` model on random data.\n\nBefore we start, we need to define some utility functions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Returns the result of running `fn()` and the time it took for `fn()` to run,\n# in seconds. 
We use CUDA events and synchronization for the most accurate\n# measurements.\ndef timed(fn):\n start = torch.cuda.Event(enable_timing=True)\n end = torch.cuda.Event(enable_timing=True)\n start.record()\n result = fn()\n end.record()\n torch.cuda.synchronize()\n return result, start.elapsed_time(end) / 1000\n\n# Generates random input and target data for the model, where `b` is\n# the batch size.\ndef generate_data(b):\n return (\n torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),\n torch.randint(1000, (b,)).cuda(),\n )\n\nN_ITERS = 10\n\nfrom torchvision.models import densenet121\ndef init_model():\n return densenet121().to(torch.float32).cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let\\'s compare inference.\n\nNote that in the call to `torch.compile`, we have the additional `mode`\nargument, which we will discuss below.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = init_model()\n\n# Reset since we are using a different mode.\nimport torch._dynamo\ntorch._dynamo.reset()\n\nmodel_opt = torch.compile(model, mode=\"reduce-overhead\")\n\ninp = generate_data(16)[0]\nwith torch.no_grad():\n print(\"eager:\", timed(lambda: model(inp))[1])\n print(\"compile:\", timed(lambda: model_opt(inp))[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Notice that `torch.compile` takes a lot longer to complete compared to\neager. This is because `torch.compile` compiles the model into optimized\nkernels as it executes. In our example, the structure of the model\ndoesn\\'t change, and so recompilation is not needed. So if we run our\noptimized model several more times, we should see a significant\nimprovement compared to eager.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "eager_times = []\nfor i in range(N_ITERS):\n inp = generate_data(16)[0]\n with torch.no_grad():\n _, eager_time = timed(lambda: model(inp))\n eager_times.append(eager_time)\n print(f\"eager eval time {i}: {eager_time}\")\n\nprint(\"~\" * 10)\n\ncompile_times = []\nfor i in range(N_ITERS):\n inp = generate_data(16)[0]\n with torch.no_grad():\n _, compile_time = timed(lambda: model_opt(inp))\n compile_times.append(compile_time)\n print(f\"compile eval time {i}: {compile_time}\")\nprint(\"~\" * 10)\n\nimport numpy as np\neager_med = np.median(eager_times)\ncompile_med = np.median(compile_times)\nspeedup = eager_med / compile_med\nassert(speedup > 1)\nprint(f\"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x\")\nprint(\"~\" * 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And indeed, we can see that running our model with `torch.compile`\nresults in a significant speedup. Speedup mainly comes from reducing\nPython overhead and GPU read/writes, and so the observed speedup may\nvary depending on factors such as model architecture and batch size. For example,\nif a model\\'s architecture is simple and the amount of data is large,\nthen the bottleneck would be GPU compute and the observed speedup may be\nless significant.\n\nYou may also see different speedup results depending on the chosen\n`mode` argument. The `\"reduce-overhead\"` mode uses CUDA graphs to\nfurther reduce the overhead of Python. For your own models, you may need\nto experiment with different modes to maximize speedup. 
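As a minimal sketch (not one of the benchmark cells in this tutorial), one way to experiment is to sweep over the documented mode strings and time a warmed-up run of each, reusing the `timed`, `init_model`, and `generate_data` helpers defined above; the exact timings, and which mode wins, will depend on your model, batch size, and GPU:\n\n```python\ninp = generate_data(16)[0]\nfor mode in (\"default\", \"reduce-overhead\", \"max-autotune\"):\n    torch._dynamo.reset()  # discard previously compiled artifacts\n    compiled_model = torch.compile(init_model(), mode=mode)\n    with torch.no_grad():\n        for _ in range(3):  # compile and warm up (e.g. CUDA graph capture)\n            compiled_model(inp)\n        print(mode, timed(lambda: compiled_model(inp))[1])\n```\n\n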
You can read\nmore about modes\n[here](https://pytorch.org/get-started/pytorch-2.0/#user-experience).\n\nYou might also notice that the second time we run our model with\n`torch.compile` is significantly slower than the other runs, although it\nis much faster than the first run. This is because the\n`\"reduce-overhead\"` mode runs a few warm-up iterations for CUDA graphs.\n\nFor general PyTorch benchmarking, you can try using\n`torch.utils.benchmark` instead of the `timed` function we defined\nabove. We wrote our own timing function in this tutorial to show\n`torch.compile`\\'s compilation latency.\n\nNow, let\\'s compare training.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = init_model()\nopt = torch.optim.Adam(model.parameters())\n\ndef train(mod, data):\n opt.zero_grad(True)\n pred = mod(data[0])\n loss = torch.nn.CrossEntropyLoss()(pred, data[1])\n loss.backward()\n opt.step()\n\neager_times = []\nfor i in range(N_ITERS):\n inp = generate_data(16)\n _, eager_time = timed(lambda: train(model, inp))\n eager_times.append(eager_time)\n print(f\"eager train time {i}: {eager_time}\")\nprint(\"~\" * 10)\n\nmodel = init_model()\nopt = torch.optim.Adam(model.parameters())\ntrain_opt = torch.compile(train, mode=\"reduce-overhead\")\n\ncompile_times = []\nfor i in range(N_ITERS):\n inp = generate_data(16)\n _, compile_time = timed(lambda: train_opt(model, inp))\n compile_times.append(compile_time)\n print(f\"compile train time {i}: {compile_time}\")\nprint(\"~\" * 10)\n\neager_med = np.median(eager_times)\ncompile_med = np.median(compile_times)\nspeedup = eager_med / compile_med\nassert(speedup > 1)\nprint(f\"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x\")\nprint(\"~\" * 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we can see that `torch.compile` takes longer in the first\niteration, as it must compile the model, but in subsequent iterations,\nwe see significant speedups compared to eager.\n\nWe remark that the speedup numbers presented in this tutorial are for\ndemonstration purposes only. Official speedup values can be seen at the\n[TorchInductor performance\ndashboard](https://hud.pytorch.org/benchmark/compilers).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Comparison to TorchScript and FX Tracing\n========================================\n\nWe have seen that `torch.compile` can speed up PyTorch code. Why else\nshould we use `torch.compile` over existing PyTorch compiler solutions,\nsuch as TorchScript or FX Tracing? Primarily, the advantage of\n`torch.compile` lies in its ability to handle arbitrary Python code with\nminimal changes to existing code.\n\nOne case that `torch.compile` can handle that other compiler solutions\nstruggle with is data-dependent control flow (the `if x.sum() < 0:` line\nbelow).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def f1(x, y):\n if x.sum() < 0:\n return -y\n return y\n\n# Test that `fn1` and `fn2` return the same result, given\n# the same arguments `args`. 
Typically, `fn1` will be an eager function\n# while `fn2` will be a compiled function (torch.compile, TorchScript, or FX graph).\ndef test_fns(fn1, fn2, args):\n out1 = fn1(*args)\n out2 = fn2(*args)\n return torch.allclose(out1, out2)\n\ninp1 = torch.randn(5, 5)\ninp2 = torch.randn(5, 5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchScript tracing `f1` results in silently incorrect results, since\nonly the actual control flow path is traced.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "traced_f1 = torch.jit.trace(f1, (inp1, inp2))\nprint(\"traced 1, 1:\", test_fns(f1, traced_f1, (inp1, inp2)))\nprint(\"traced 1, 2:\", test_fns(f1, traced_f1, (-inp1, inp2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "FX tracing `f1` results in an error due to the presence of\ndata-dependent control flow.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import traceback as tb\ntry:\n torch.fx.symbolic_trace(f1)\nexcept:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If we provide a value for `x` as we try to FX trace `f1`, then we run\ninto the same problem as TorchScript tracing, as the data-dependent\ncontrol flow is removed in the traced function.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fx_f1 = torch.fx.symbolic_trace(f1, concrete_args={\"x\": inp1})\nprint(\"fx 1, 1:\", test_fns(f1, fx_f1, (inp1, inp2)))\nprint(\"fx 1, 2:\", test_fns(f1, fx_f1, (-inp1, inp2)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can see that `torch.compile` correctly handles data-dependent\ncontrol flow.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Reset since we are using a different mode.\ntorch._dynamo.reset()\n\ncompile_f1 = torch.compile(f1)\nprint(\"compile 1, 1:\", test_fns(f1, compile_f1, (inp1, inp2)))\nprint(\"compile 1, 2:\", test_fns(f1, compile_f1, (-inp1, inp2)))\nprint(\"~\" * 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchScript scripting can handle data-dependent control flow, but this\nsolution comes with its own set of problems. 
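As a quick check (a sketch that reuses `test_fns`, `inp1`, and `inp2` from above), scripting `f1` does preserve the data-dependent branch:\n\n```python\nscript_f1 = torch.jit.script(f1)\nprint(\"script 1, 1:\", test_fns(f1, script_f1, (inp1, inp2)))   # expected: True\nprint(\"script 1, 2:\", test_fns(f1, script_f1, (-inp1, inp2)))  # expected: True\n```\n\n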
In particular, TorchScript\nscripting can require major code changes and will raise errors when\nunsupported Python is used.\n\nIn the example below, we omit TorchScript type annotations and\nreceive a TorchScript error because the input type for argument `y`, an\n`int`, does not match the default argument type, `torch.Tensor`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def f2(x, y):\n return x + y\n\ninp1 = torch.randn(5, 5)\ninp2 = 3\n\nscript_f2 = torch.jit.script(f2)\ntry:\n script_f2(inp1, inp2)\nexcept:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "However, `torch.compile` is easily able to handle `f2`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "compile_f2 = torch.compile(f2)\nprint(\"compile 2:\", test_fns(f2, compile_f2, (inp1, inp2)))\nprint(\"~\" * 10)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Another case that `torch.compile` handles well compared to previous\ncompiler solutions is the use of non-PyTorch functions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import scipy\ndef f3(x):\n x = x * 2\n x = scipy.fft.dct(x.numpy())\n x = torch.from_numpy(x)\n x = x * 2\n return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchScript tracing treats results from non-PyTorch function calls as\nconstants, and so our results can be silently wrong.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "inp1 = torch.randn(5, 5)\ninp2 = torch.randn(5, 5)\ntraced_f3 = torch.jit.trace(f3, (inp1,))\nprint(\"traced 3:\", test_fns(f3, traced_f3, (inp2,)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchScript scripting and FX tracing disallow non-PyTorch function\ncalls.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "try:\n torch.jit.script(f3)\nexcept:\n tb.print_exc()\n\ntry:\n torch.fx.symbolic_trace(f3)\nexcept:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In comparison, `torch.compile` is easily able to handle the non-PyTorch\nfunction call.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "compile_f3 = torch.compile(f3)\nprint(\"compile 3:\", test_fns(f3, compile_f3, (inp2,)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchDynamo and FX Graphs\n=========================\n\nOne important component of `torch.compile` is TorchDynamo. TorchDynamo\nis responsible for JIT compiling arbitrary Python code into [FX\ngraphs](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph), which\ncan then be further optimized. TorchDynamo extracts FX graphs by\nanalyzing Python bytecode during runtime and detecting calls to PyTorch\noperations.\n\nNormally, TorchInductor, another component of `torch.compile`, further\ncompiles the FX graphs into optimized kernels, but TorchDynamo allows\nfor different backends to be used. 
In order to inspect the FX graphs\nthat TorchDynamo outputs, let us create a custom backend that outputs\nthe FX graph and simply returns the graph\\'s unoptimized forward method.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import List\ndef custom_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):\n print(\"custom backend called with FX graph:\")\n gm.graph.print_tabular()\n return gm.forward\n\n# Reset since we are using a different backend.\ntorch._dynamo.reset()\n\nopt_model = torch.compile(init_model(), backend=custom_backend)\nopt_model(generate_data(16)[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using our custom backend, we can now see how TorchDynamo is able to\nhandle data-dependent control flow. Consider the function below, where\nthe line `if b.sum() < 0` is the source of data-dependent control flow.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def bar(a, b):\n x = a / (torch.abs(a) + 1)\n if b.sum() < 0:\n b = b * -1\n return x * b\n\nopt_bar = torch.compile(bar, backend=custom_backend)\ninp1 = torch.randn(10)\ninp2 = torch.randn(10)\nopt_bar(inp1, inp2)\nopt_bar(inp1, -inp2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The output reveals that TorchDynamo extracted 3 different FX graphs\ncorresponding to the following code (order may differ from the output\nabove):\n\n1. `x = a / (torch.abs(a) + 1)`\n2. `b = b * -1; return x * b`\n3. `return x * b`\n\nWhen TorchDynamo encounters unsupported Python features, such as\ndata-dependent control flow, it breaks the computation graph, lets the\ndefault Python interpreter handle the unsupported code, then resumes\ncapturing the graph.\n\nLet\\'s investigate by example how TorchDynamo would step through `bar`.\nIf `b.sum() < 0`, then TorchDynamo would run graph 1, let Python\ndetermine the result of the conditional, then run graph 2. On the other\nhand, if `not b.sum() < 0`, then TorchDynamo would run graph 1, let\nPython determine the result of the conditional, then run graph 3.\n\nThis highlights a major difference between TorchDynamo and previous\nPyTorch compiler solutions. When encountering unsupported Python\nfeatures, previous solutions either raise an error or silently fail.\nTorchDynamo, on the other hand, will break the computation graph.\n\nWe can see where TorchDynamo breaks the graph by using\n`torch._dynamo.explain`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Reset since we are using a different backend.\ntorch._dynamo.reset()\nexplain_output = torch._dynamo.explain(bar)(torch.randn(10), torch.randn(10))\nprint(explain_output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to maximize speedup, graph breaks should be limited. 
We can\nforce TorchDynamo to raise an error upon the first graph break\nencountered by using `fullgraph=True`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "opt_bar = torch.compile(bar, fullgraph=True)\ntry:\n opt_bar(torch.randn(10), torch.randn(10))\nexcept:\n tb.print_exc()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And below, we demonstrate that TorchDynamo does not break the graph on\nthe model we used above for demonstrating speedups.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "opt_model = torch.compile(init_model(), fullgraph=True)\nprint(opt_model(generate_data(16)[0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can use `torch.export` (from PyTorch 2.1+) to extract a single,\nexportable FX graph from the input PyTorch program. The exported graph\nis intended to be run in different (i.e., Python-less) environments. One\nimportant restriction is that `torch.export` does not support graph\nbreaks. Please check [this\ntutorial](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html)\nfor more details on `torch.export`.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we introduced `torch.compile` by covering basic usage,\ndemonstrating speedups over eager mode, comparing to previous PyTorch\ncompiler solutions, and briefly investigating TorchDynamo and its\ninteractions with FX graphs. We hope that you will give `torch.compile`\na try!\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }