{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Dynamic Compilation Control with `torch.compiler.set_stance`\n============================================================\n\n**Author:** [William Wen](https://github.com/williamwen42)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`torch.compiler.set_stance` is a `torch.compiler` API that enables you\nto change the behavior of `torch.compile` across different calls to your\nmodel without having to reapply `torch.compile` to your model.\n\nThis recipe provides some examples of how to use\n`torch.compiler.set_stance`.\n\nPrerequisites\n=============\n\n- `torch >= 2.6`\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Description\n===========\n\n`torch.compiler.set_stance` can be used as a decorator, context manager,\nor raw function to change the behavior of `torch.compile` across\ndifferent calls to your model.\n\nIn the examples below, the `\"force_eager\"` stance ignores all\n`torch.compile` directives.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n\n\n@torch.compile\ndef foo(x):\n    if torch.compiler.is_compiling():\n        # torch.compile is active\n        return x + 1\n    else:\n        # torch.compile is not active\n        return x - 1\n\n\ninp = torch.zeros(3)\n\nprint(foo(inp))  # compiled, prints 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample decorator usage\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compiler.set_stance(\"force_eager\")\ndef bar(x):\n    # force disable the compiler\n    return foo(x)\n\n\nprint(bar(inp))  # not compiled, 
prints -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample context manager usage\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.compiler.set_stance(\"force_eager\"):\n    print(foo(inp))  # not compiled, prints -1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample raw function usage\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.compiler.set_stance(\"force_eager\")\nprint(foo(inp))  # not compiled, prints -1\ntorch.compiler.set_stance(\"default\")\n\nprint(foo(inp))  # compiled, prints 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `torch.compile` stance can only be changed **outside** of any\n`torch.compile` region. Attempting to change it inside a compiled\nregion results in an error.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile\ndef baz(x):\n    # error! the stance is changed inside a compiled region\n    with torch.compiler.set_stance(\"force_eager\"):\n        return x + 1\n\n\ntry:\n    baz(inp)\nexcept Exception as e:\n    print(e)\n\n\n@torch.compiler.set_stance(\"force_eager\")\ndef inner(x):\n    return x + 1\n\n\n@torch.compile\ndef outer(x):\n    # error! inner() changes the stance inside a compiled region\n    return inner(x)\n\n\ntry:\n    outer(inp)\nexcept Exception as e:\n    print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Other stances include:\n\n- `\"default\"`: The default stance, used for normal compilation.\n- `\"eager_on_recompile\"`: Run code eagerly when a recompile is\n  necessary. If there is cached compiled code valid for the input,\n  it will still be used.\n- `\"fail_on_recompile\"`: Raise an error when recompiling a\n  function.\n\nSee the `torch.compiler.set_stance` [doc\npage](https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance)\nfor more stances and options. 
More stances and options may be added in\nthe future.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Examples\n========\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Preventing recompilation\n========================\n\nSome models are not expected to recompile at all; for example, your\ninputs may always have the same shape. Since recompilations can be\nexpensive, we may wish to error out whenever a recompile is attempted so\nthat we can detect and fix the cause. The `\"fail_on_recompile\"`\nstance can be used for this.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile\ndef my_big_model(x):\n    return torch.relu(x)\n\n\n# first compilation\nmy_big_model(torch.randn(3))\n\nwith torch.compiler.set_stance(\"fail_on_recompile\"):\n    my_big_model(torch.randn(3))  # no recompilation - OK\n    try:\n        my_big_model(torch.randn(4))  # new shape requires recompilation - error\n    except Exception as e:\n        print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If erroring out is too disruptive, we can use `\"eager_on_recompile\"`\ninstead, which will cause `torch.compile` to fall back to eager instead\nof erroring out. 
This may be useful if we don't expect recompilations\nto happen frequently, but when one is required, we'd rather pay the\ncost of running eagerly than the cost of recompilation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile\ndef my_huge_model(x):\n    if torch.compiler.is_compiling():\n        return x + 1\n    else:\n        return x - 1\n\n\n# first compilation\nprint(my_huge_model(torch.zeros(3)))  # compiled, prints 1\n\nwith torch.compiler.set_stance(\"eager_on_recompile\"):\n    print(my_huge_model(torch.zeros(3)))  # cached compiled code, prints 1\n    print(my_huge_model(torch.zeros(4)))  # would recompile, runs eagerly, prints -1\n    print(my_huge_model(torch.zeros(3)))  # cached compiled code, prints 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Measuring performance gains\n===========================\n\n`torch.compiler.set_stance` can be used to compare eager vs. compiled\nperformance without having to define a separate eager model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Returns the result of running `fn()` and the time it took for `fn()` to run,\n# in seconds. 
We use CUDA events and synchronization for the most accurate\n# measurements, so this cell requires a CUDA-capable GPU.\ndef timed(fn):\n    start = torch.cuda.Event(enable_timing=True)\n    end = torch.cuda.Event(enable_timing=True)\n    start.record()\n    result = fn()\n    end.record()\n    torch.cuda.synchronize()\n    # elapsed_time() returns milliseconds; convert to seconds\n    return result, start.elapsed_time(end) / 1000\n\n\n@torch.compile\ndef my_gigantic_model(x, y):\n    x = x @ y\n    x = x @ y\n    x = x @ y\n    return x\n\n\n# CUDA events time GPU work, so the inputs must live on the GPU\ninps = torch.randn(5, 5, device=\"cuda\"), torch.randn(5, 5, device=\"cuda\")\n\nwith torch.compiler.set_stance(\"force_eager\"):\n    # eager warmups, so one-time CUDA initialization is not timed\n    for _ in range(3):\n        my_gigantic_model(*inps)\n    print(\"eager:\", timed(lambda: my_gigantic_model(*inps))[1])\n\n# compiled warmups (the first call triggers compilation)\nfor _ in range(3):\n    my_gigantic_model(*inps)\n\nprint(\"compiled:\", timed(lambda: my_gigantic_model(*inps))[1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Crashing sooner\n===============\n\nRunning one eager iteration with the `\"force_eager\"` stance before the\nfirst compiled iteration can help us catch errors unrelated to\n`torch.compile` before attempting a potentially very long compile.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile\ndef my_humongous_model(x):\n    # deliberate bug: torch.sin takes a single input tensor\n    return torch.sin(x, x)\n\n\ntry:\n    with torch.compiler.set_stance(\"force_eager\"):\n        print(my_humongous_model(torch.randn(3)))\n    # this call to the compiled model won't run\n    print(my_humongous_model(torch.randn(3)))\nexcept Exception as e:\n    print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this recipe, we have learned how to use the\n`torch.compiler.set_stance` API to modify the behavior of\n`torch.compile` across different calls to a model without needing to\nreapply it. 
The recipe demonstrates using `torch.compiler.set_stance` as\na decorator, context manager, or raw function to control compilation\nstances like `force_eager`, `default`, `eager_on_recompile`, and\n`fail_on_recompile`.\n\nFor more information, see the [torch.compiler.set_stance API\ndocumentation](https://pytorch.org/docs/main/generated/torch.compiler.set_stance.html#torch.compiler.set_stance).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }