{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(beta) Utilizing Torch Function modes with torch.compile\n========================================================\n\n**Author:** [Michael Lazos](https://github.com/mlazos)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This recipe covers how to use a key torch extensibility point, torch\nfunction modes, in tandem with `torch.compile` to override the behavior\nof torch operators, also known as **ops**, at trace time, with no\nruntime overhead.\n\n```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p>This recipe requires PyTorch 2.7.0 or later.</p></div>
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Rewriting a torch op (torch.add -\\> torch.mul)\n==============================================\n\nFor this example, we\\'ll use torch function modes to rewrite\noccurrences of addition with multiplication instead. This type of\noverride is common when a backend has a custom implementation that\nshould be dispatched for a given op.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n\n# exit cleanly if we are on a device that doesn't support ``torch.compile``\nif torch.cuda.get_device_capability() < (7, 0):\n    print(\"Exiting because torch.compile is not supported on this device.\")\n    import sys\n    sys.exit(0)\n\nfrom torch.overrides import BaseTorchFunctionMode\n\n# Define our mode. Note: ``BaseTorchFunctionMode``\n# implements the actual invocation of func(..)\nclass AddToMultiplyMode(BaseTorchFunctionMode):\n    def __torch_function__(self, func, types, args=(), kwargs=None):\n        if func == torch.Tensor.add:\n            func = torch.mul\n\n        return super().__torch_function__(func, types, args, kwargs)\n\n@torch.compile()\ndef test_fn(x, y):\n    return x + y * x  # Note: infix operators map to torch.Tensor.* methods\n\nx = torch.rand(2, 2)\ny = torch.rand_like(x)\n\nwith AddToMultiplyMode():\n    z = test_fn(x, y)\n\nassert torch.allclose(z, x * y * x)\n\n# The mode can also be used within the compiled region, like this:\n\n@torch.compile()\ndef test_fn(x, y):\n    with AddToMultiplyMode():\n        return x + y * x  # Note: infix operators map to torch.Tensor.* methods\n\nx = torch.rand(2, 2)\ny = torch.rand_like(x)\nz = test_fn(x, y)\n\nassert torch.allclose(z, x * y * x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this recipe, we demonstrated how to override the behavior of `torch.*`\noperators using torch function modes from within `torch.compile`. This\nlets users take advantage of the extensibility of torch function modes\nwithout the runtime overhead of a `__torch_function__` call on every op\ninvocation.\n\n- See [Extending Torch API with\n    Modes](https://pytorch.org/docs/stable/notes/extending.html#extending-all-torch-api-with-modes)\n    for other examples and background on Torch Function modes.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }