{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(beta) Using TORCH\\_LOGS Python API with torch.compile\n======================================================\n\n**Author:** [Michael Lazos](https://github.com/mlazos)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import logging" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial introduces the `TORCH_LOGS` environment variable, as well\nas the Python API, and demonstrates how to use them to observe the\nphases of `torch.compile`.\n\n```{=html}\n
NOTE: This tutorial requires PyTorch 2.2.0 or later.
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Setup\n=====\n\nIn this example, we\\'ll set up a simple Python function which performs\nan elementwise addition and observe the compilation process with the `TORCH_LOGS`\nPython API.\n\n```{=html}\n
NOTE: There is also an environment variable, TORCH_LOGS, which can be used to change logging settings at the command line. The equivalent environment variable setting is shown for each example.
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n\n# Exit cleanly if CUDA is unavailable or the GPU does not support torch.compile\nif not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):\n    print(\"Skipping because torch.compile is not supported on this device.\")\nelse:\n    @torch.compile()\n    def fn(x, y):\n        z = x + y\n        return z + 2\n\n    inputs = (torch.ones(2, 2, device=\"cuda\"), torch.zeros(2, 2, device=\"cuda\"))\n\n    # Print a separator and reset Dynamo between examples so that\n    # each call to fn recompiles and emits fresh logs\n    def separator(name):\n        print(f\"==================={name}=========================\")\n        torch._dynamo.reset()\n\n    separator(\"Dynamo Tracing\")\n    # View dynamo tracing\n    # TORCH_LOGS=\"+dynamo\"\n    torch._logging.set_logs(dynamo=logging.DEBUG)\n    fn(*inputs)\n\n    separator(\"Traced Graph\")\n    # View traced graph\n    # TORCH_LOGS=\"graph\"\n    torch._logging.set_logs(graph=True)\n    fn(*inputs)\n\n    separator(\"Fusion Decisions\")\n    # View fusion decisions\n    # TORCH_LOGS=\"fusion\"\n    torch._logging.set_logs(fusion=True)\n    fn(*inputs)\n\n    separator(\"Output Code\")\n    # View output code generated by inductor\n    # TORCH_LOGS=\"output_code\"\n    torch._logging.set_logs(output_code=True)\n    fn(*inputs)\n\n    separator(\"\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial we introduced the TORCH\\_LOGS environment variable and\nPython API by experimenting with a small number of the available logging\noptions. 
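
The correspondence between the `torch._logging.set_logs` calls used above and their `TORCH_LOGS` equivalents can be sketched as a small conversion function. `to_torch_logs` is a hypothetical helper written for illustration only, not a PyTorch API. It assumes the documented `TORCH_LOGS` conventions: a `+` prefix selects `logging.DEBUG`, a `-` prefix selects `logging.ERROR`, a bare component name selects `logging.INFO`, and an artifact such as `graph` or `output_code` is enabled by listing its name.

```python
import logging

# Hypothetical helper, not part of PyTorch: builds the TORCH_LOGS string that
# corresponds to keyword arguments one might pass to torch._logging.set_logs.
# Components map to a logging level; artifacts (graph, fusion, ...) map to True.
_LEVEL_PREFIX = {
    logging.DEBUG: '+',  # +name selects DEBUG
    logging.INFO: '',    # a bare name selects INFO
    logging.ERROR: '-',  # -name selects ERROR
}


def to_torch_logs(**settings):
    parts = []
    for name, value in settings.items():
        if value is True:
            # An enabled artifact is listed by name only
            parts.append(name)
        elif value is False or value is None:
            # A disabled artifact or unset component contributes nothing
            continue
        elif value in _LEVEL_PREFIX:
            parts.append(_LEVEL_PREFIX[value] + name)
        else:
            # e.g. logging.WARNING (the default) has no TORCH_LOGS prefix
            raise ValueError(f'no TORCH_LOGS prefix for level {value!r} on {name}')
    return ','.join(parts)


# The settings used in the tutorial's examples
print(to_torch_logs(dynamo=logging.DEBUG))  # prints +dynamo
print(to_torch_logs(graph=True))            # prints graph
print(to_torch_logs(output_code=True))      # prints output_code
```

The printed strings can be exported before launching a script, for example `TORCH_LOGS=+dynamo python my_script.py`, where `my_script.py` is a placeholder name.
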
To view descriptions of all available options, run any Python\nscript that imports torch with TORCH\\_LOGS set to \\\"help\\\".\n\nAlternatively, you can view the [torch.\\_logging\ndocumentation](https://pytorch.org/docs/main/logging.html) to see\ndescriptions of all available logging options.\n\nFor more information on torch.compile, see the [torch.compile\ntutorial](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }