{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using User-Defined Triton Kernels with `torch.compile`\n======================================================\n\n**Author:** [Oguz Ulgen](https://github.com/oulgen)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "User-defined Triton kernels can be used to optimize specific parts of\nyour model\\'s computation. These kernels are written in Triton\\'s\nlanguage, which is designed to make it easier to achieve peak hardware\nperformance. By using user-defined Triton kernels with `torch.compile`,\nyou can integrate these optimized computations into your PyTorch model,\npotentially achieving significant performance improvements.\n\nThis recipe demonstrates how you can use user-defined Triton kernels\nwith `torch.compile`.\n\nPrerequisites\n=============\n\nBefore starting this recipe, make sure that you have the following:\n\n- Basic understanding of `torch.compile` and Triton. See:\n  - [torch.compiler API\n    documentation](https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler)\n  - [Introduction to\n    torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)\n  - [Triton language\n    documentation](https://triton-lang.org/main/index.html)\n- PyTorch 2.3 or later (PyTorch 2.6 or later for the `torch.library.triton_op`\n  sections)\n- A GPU that supports Triton\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.utils._triton import has_triton" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basic Usage\n===========\n\nIn this example, we will use a simple vector addition kernel from the\nTriton documentation with `torch.compile`. 
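
Conceptually, the kernel below splits the input into blocks of `BLOCK_SIZE` elements. The grid launches `triton.cdiv(n_elements, BLOCK_SIZE)` program instances, and each instance derives its offsets from `tl.program_id`. The mask keeps the last, possibly partial, block from reading or writing out of bounds. The following plain-Python sketch emulates that indexing scheme one program at a time, purely as a mental model. It involves no Triton and no GPU. `cdiv` and `emulate_add_kernel` are hypothetical helpers, not part of the Triton API.

```python
# Plain-Python emulation of the indexing scheme used by add_kernel.
# This is NOT Triton: programs run one after another instead of in parallel.


def cdiv(a, b):
    # Ceiling division, the same arithmetic used to size the launch grid.
    return (a + b - 1) // b


def emulate_add_kernel(x, y, block_size):
    n_elements = len(x)
    out = [0.0] * n_elements
    num_programs = cdiv(n_elements, block_size)  # size of the 1D grid
    for pid in range(num_programs):  # pid plays the role of tl.program_id(axis=0)
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]  # like tl.arange
        for offset in offsets:
            if offset < n_elements:  # the mask: skip out-of-bounds lanes
                out[offset] = x[offset] + y[offset]
    return out, num_programs


out, num_programs = emulate_add_kernel([1.0, 2.0, 3.0, 4.0, 5.0], [10.0] * 5, block_size=4)
print(num_programs, out)  # prints 2 [11.0, 12.0, 13.0, 14.0, 15.0]
```

With five elements and `block_size=4`, the grid has two programs, and the mask disables three of the four lanes in the second one.
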
For reference, see [Triton\ndocumentation](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if not has_triton():\n    print(\"Skipping because triton is not supported on this device.\")\nelse:\n    import triton\n    from triton import language as tl\n\n    @triton.jit\n    def add_kernel(\n        in_ptr0,\n        in_ptr1,\n        out_ptr,\n        n_elements,\n        BLOCK_SIZE: \"tl.constexpr\",\n    ):\n        # Each program instance handles one BLOCK_SIZE-wide block of elements.\n        pid = tl.program_id(axis=0)\n        block_start = pid * BLOCK_SIZE\n        offsets = block_start + tl.arange(0, BLOCK_SIZE)\n        # Mask out offsets past the end of the inputs (the last block may be partial).\n        mask = offsets < n_elements\n        x = tl.load(in_ptr0 + offsets, mask=mask)\n        y = tl.load(in_ptr1 + offsets, mask=mask)\n        output = x + y\n        tl.store(out_ptr + offsets, output, mask=mask)\n\n    @torch.compile(fullgraph=True)\n    def add_fn(x, y):\n        output = torch.zeros_like(x)\n        n_elements = output.numel()\n        # Launch enough program instances to cover all n_elements.\n        grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n        add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)\n        return output\n\n    x = torch.randn(4, device=\"cuda\")\n    y = torch.randn(4, device=\"cuda\")\n    out = add_fn(x, y)\n    print(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Advanced Usage\n==============\n\nTriton\\'s autotune feature is a powerful tool that automatically\noptimizes the configuration parameters of your Triton kernels. It\nexplores a range of possible configurations and selects the one that\ndelivers the best performance for your specific use case.\n\nWhen used with `torch.compile`, `triton.autotune` can help ensure that\nyour PyTorch model is running as efficiently as possible. Here is an\nexample of using `torch.compile` and `triton.autotune`.\n\n```{=html}\n
<div><p><strong>NOTE:</strong> <code>torch.compile</code> only supports the <code>configs</code> and <code>key</code> arguments to <code>triton.autotune</code>.</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if not has_triton():\n    print(\"Skipping because triton is not supported on this device.\")\nelse:\n    import triton\n    from triton import language as tl\n\n    @triton.autotune(\n        # Candidate configurations; the autotuner benchmarks them and keeps the fastest.\n        configs=[\n            triton.Config({\"BLOCK_SIZE\": 4}, num_stages=3, num_warps=8),\n            triton.Config({\"BLOCK_SIZE\": 4}, num_stages=4, num_warps=4),\n            triton.Config({\"BLOCK_SIZE\": 2}, num_stages=3, num_warps=8),\n            triton.Config({\"BLOCK_SIZE\": 2}, num_stages=4, num_warps=4),\n        ],\n        key=[],\n    )\n    @triton.jit\n    def add_kernel_autotuned(\n        in_ptr0,\n        in_ptr1,\n        out_ptr,\n        n_elements,\n        BLOCK_SIZE: \"tl.constexpr\",\n    ):\n        pid = tl.program_id(axis=0)\n        block_start = pid * BLOCK_SIZE\n        offsets = block_start + tl.arange(0, BLOCK_SIZE)\n        mask = offsets < n_elements\n        x = tl.load(in_ptr0 + offsets, mask=mask)\n        y = tl.load(in_ptr1 + offsets, mask=mask)\n        output = x + y\n        tl.store(out_ptr + offsets, output, mask=mask)\n\n    @torch.compile(fullgraph=True)\n    def add_fn(x, y):\n        output = torch.zeros_like(x)\n        n_elements = output.numel()\n        grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n        # BLOCK_SIZE is not passed here: the autotuner supplies it from the selected config.\n        add_kernel_autotuned[grid](x, y, output, n_elements)\n        return output\n\n    x = torch.randn(4, device=\"cuda\")\n    y = torch.randn(4, device=\"cuda\")\n    out = add_fn(x, y)\n    print(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Composability\n=============\n\nUser-defined Triton kernels do not automatically support all PyTorch\nsubsystems. 
This limitation shows up in the following use cases:\n\n- Adding a CPU fallback\n- Adding a `FlopCounter` formula\n- Composing with Tensor Subclasses\n\nTo compose with additional PyTorch subsystems, use\n`torch.library.triton_op`.\n\n`triton_op` is a structured way of defining a custom operator that is\nbacked by one or more Triton kernels. Like regular custom operators\n(`torch.library.custom_op`), it lets you specify the interactions with\nPyTorch subsystems via `torch.library`. However, unlike\n`torch.library.custom_op`, which creates callables that are opaque to\n`torch.compile`, a `triton_op` is traced into by `torch.compile`, so the\ncompiler can apply optimizations to it.\n\nThe following table shows which API to use when integrating Triton\nkernels with PyTorch.\n\n| | Triton kernel (no explicit `torch.library` wrapper) | `torch.library.triton_op` | `torch.library.custom_op` |\n|---|---|---|---|\n| Supports inference | Yes | Yes | Yes |\n| Supports training | In the majority of cases | Yes | Yes |\n| Supports `torch.compile` | Yes | Yes | Yes |\n| Supports `torch.compile(fullgraph=True)` | In the majority of cases | In the majority of cases | In all cases |\n| Does `torch.compile` trace into the implementation? | Yes | Yes | No |\n| Supports AOTInductor | Yes | Yes | No |\n| Supports PyTorch subsystems like `FlopCounterMode`, CPU fallback, Tensor Subclasses | No | Yes | Yes |\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Wrapping Triton kernels with `triton_op`\n----------------------------------------\n\nUse `torch.library.triton_op` to wrap a function that may invoke one or\nmore Triton kernels. 
Use `torch.library.wrap_triton` to wrap the calls\nto the Triton kernel.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import triton\nfrom triton import language as tl\nfrom torch.library import triton_op, wrap_triton\n\n@triton_op(\"mylib::mysin\", mutates_args={})\ndef mysin(x: torch.Tensor) -> torch.Tensor:\n    out = torch.empty_like(x)\n    n_elements = x.numel()\n    # wrap_triton makes this kernel launch capturable when the op is traced.\n    wrap_triton(sin_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)\n    return out\n\n@triton.jit\ndef sin_kernel(\n    in_ptr0,\n    out_ptr,\n    n_elements,\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n    x = tl.load(in_ptr0 + offsets, mask=mask)\n    output = tl.sin(x)\n    tl.store(out_ptr + offsets, output, mask=mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can invoke the `triton_op` in either of two ways: by calling `mysin`\ndirectly or through `torch.ops.mylib.mysin.default`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = torch.randn(3, device=\"cuda\")\ny = mysin(x)\nz = torch.ops.mylib.mysin.default(x)\n\nassert torch.allclose(y, x.sin())\nassert torch.allclose(z, x.sin())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The resulting `triton_op` works with `torch.compile` and `AOTInductor`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "y = torch.compile(mysin)(x)\nassert torch.allclose(y, x.sin())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding training support\n=======================\n\nUse `register_autograd` to add an autograd formula for the `triton_op`.\nPrefer this to using `torch.autograd.Function`, which has various\ncomposability pitfalls with `torch.compile`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def backward(ctx, grad):\n    x, = ctx.saved_tensors\n    # d/dx sin(x) = cos(x)\n    return grad * x.cos()\n\ndef setup_context(ctx, inputs, output):\n    x, = inputs\n    ctx.save_for_backward(x)\n\nmysin.register_autograd(backward, setup_context=setup_context)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the backward function must be composed of operators that\nPyTorch understands. If you want the backward to call Triton kernels,\nthose kernels must be wrapped in a `triton_op` as well:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@triton.jit\ndef cos_kernel(\n    in_ptr0,\n    out_ptr,\n    n_elements,\n    BLOCK_SIZE: \"tl.constexpr\",\n):\n    pid = tl.program_id(axis=0)\n    block_start = pid * BLOCK_SIZE\n    offsets = block_start + tl.arange(0, BLOCK_SIZE)\n    mask = offsets < n_elements\n    x = tl.load(in_ptr0 + offsets, mask=mask)\n    output = tl.cos(x)\n    tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton_op(\"mylib::mycos\", mutates_args={})\ndef mycos(x: torch.Tensor) -> torch.Tensor:\n    out = torch.empty_like(x)\n    n_elements = x.numel()\n    wrap_triton(cos_kernel)[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)\n    return out\n\ndef backward(ctx, grad):\n    x, = ctx.saved_tensors\n    # Same gradient as before, now computed by the Triton-backed mycos op.\n    return grad * mycos(x)\n\ndef setup_context(ctx, inputs, output):\n    x, = inputs\n    ctx.save_for_backward(x)\n\nmysin.register_autograd(backward, setup_context=setup_context)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding a CPU Fallback\n=====================\n\nTriton kernels don't run on CPU. 
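
Why does a CPU input need extra work? Conceptually, an operator holds a default implementation plus optional per-device overrides, and each call dispatches on the device of its inputs. The default for `mysin` launches a Triton kernel, which cannot run on CPU tensors. The hypothetical plain-Python sketch below models only that dispatch idea. `DeviceDispatchedOp`, its `register_kernel` method, and `gpu_only_sin` are invented for this illustration and are not how `torch.library` is implemented.

```python
import math


class DeviceDispatchedOp:
    # Hypothetical stand-in for an operator with per-device kernels.

    def __init__(self, name, default_impl):
        self.name = name
        self.default_impl = default_impl  # used when no override matches
        self.kernels = {}  # device type -> override implementation

    def register_kernel(self, device_type):
        # Decorator that records an override for one device type.
        def decorator(fn):
            self.kernels[device_type] = fn
            return fn
        return decorator

    def __call__(self, device_type, values):
        impl = self.kernels.get(device_type, self.default_impl)
        return impl(device_type, values)


def gpu_only_sin(device_type, values):
    # Stands in for the Triton-backed body, which cannot run on CPU.
    if device_type != 'cuda':
        raise RuntimeError('Triton kernels do not run on ' + device_type)
    return [math.sin(v) for v in values]


mysin_sketch = DeviceDispatchedOp('mylib::mysin_sketch', gpu_only_sin)


@mysin_sketch.register_kernel('cpu')
def _(device_type, values):
    # CPU fallback built from ordinary (non-Triton) operations.
    return [math.sin(v) for v in values]


print(mysin_sketch('cpu', [0.0, math.pi / 2]))
```

Without the `cpu` registration, the same call would reach `gpu_only_sin` and raise. This mirrors why the real operator needs a registered fallback.
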
Use `register_kernel` to add a CPU (or\nany other device) fallback for the `triton_op`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@mysin.register_kernel(\"cpu\")\ndef _(x):\n    # CPU fallback composed of regular PyTorch operators.\n    return torch.sin(x)\n\nx = torch.randn(3)\ny = mysin(x)\nassert torch.allclose(y, x.sin())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fallback must be composed of PyTorch operators.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding a FlopCounter Formula\n============================\n\nTo specify how many flops the Triton kernel reports under PyTorch\\'s\nflop counter, use `register_flop_formula`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.utils.flop_counter import FlopCounterMode, register_flop_formula\n\n@register_flop_formula(torch.ops.mylib.mysin)\ndef _(x_shape, *args, out_shape=None, **kwargs):\n    # Count one flop per element of the input.\n    numel = 1\n    for s in x_shape:\n        numel *= s\n    return numel\n\nx = torch.randn(3, device=\"cuda\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`FlopCounterMode` requires\n[tabulate](https://pypi.org/project/tabulate/). Before running the code\nbelow, make sure that `tabulate` is installed (for example, by running\n`pip install tabulate`).\n\n    >>> with FlopCounterMode() as flop_counter:\n    ...     y = mysin(x)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Limitations\n===========\n\nAs of PyTorch 2.3, the support for user-defined Triton kernels in\n`torch.compile` includes dynamic shapes, `torch.autograd.Function`, JIT\ninductor, and AOT inductor. 
You can use these features together to build\ncomplex, high-performance models.\n\nPyTorch 2.6 added `torch.library.triton_op`, which adds support for\nuser-defined Triton kernels in tensor subclasses and other advanced\nfeatures.\n\nHowever, there are certain limitations to be aware of:\n\n- **Triton Features:** While `triton.heuristics` can be used either\n standalone or before `triton.autotune`, it cannot be used after\n `triton.autotune`. This implies that if `triton.heuristics` and\n `triton.autotune` are to be used together, `triton.heuristics` must\n be used first.\n\nConclusion\n==========\n\nIn this recipe, we explored how to utilize user-defined Triton kernels\nwith `torch.compile`. We delved into the basic usage of a simple vector\naddition kernel and advanced usage involving Triton\\'s autotune feature.\nWe also discussed the composability of user-defined Triton kernels with\nother PyTorch features and highlighted some current limitations.\n\nSee Also\n========\n\n- [Compiling the\n Optimizers](https://pytorch.org/tutorials/recipes/compiling_optimizer.html)\n- [Implementing High-Performance Transformers with Scaled Dot Product\n Attention](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }