{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using User-Defined Triton Kernels with `torch.compile`\n======================================================\n\n**Author:** [Oguz Ulgen](https://github.com/oulgen)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "User-defined Triton kernels can be used to optimize specific parts of\nyour model\\'s computation. These kernels are written in Triton\\'s\nlanguage, which is designed to make it easier to achieve peak hardware\nperformance. By using user-defined Triton kernels with `torch.compile`,\nyou can integrate these optimized computations into your PyTorch model,\npotentially achieving significant performance improvements.\n\nThis recipes demonstrates how you can use user-defined Triton kernels\nwith `torch.compile`.\n\nPrerequisites\n=============\n\nBefore starting this recipe, make sure that you have the following:\n\n- Basic understanding of `torch.compile` and Triton. See:\n - [torch.compiler API\n documentation](https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler)\n - [Introduction to\n torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)\n - [Triton language\n documentation](https://triton-lang.org/main/index.html)\n- PyTorch 2.3 or later\n- A GPU that supports Triton\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.utils._triton import has_triton" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Basic Usage\n===========\n\nIn this example, we will use a simple vector addition kernel from the\nTriton documentation with `torch.compile`. For reference, see [Triton\ndocumentation](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if not has_triton():\n print(\"Skipping because triton is not supported on this device.\")\nelse:\n import triton\n from triton import language as tl\n\n @triton.jit\n def add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n ):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\n @torch.compile(fullgraph=True)\n def add_fn(x, y):\n output = torch.zeros_like(x)\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)\n return output\n\n x = torch.randn(4, device=\"cuda\")\n y = torch.randn(4, device=\"cuda\")\n out = add_fn(x, y)\n print(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Advanced Usage\n==============\n\nTriton\\'s autotune feature is a powerful tool that automatically\noptimizes the configuration parameters of your Triton kernels. 
It\nexplores a range of possible configurations and selects the one that\ndelivers the best performance for your specific use case.\n\nWhen used with `torch.compile`, `triton.autotune` can help ensure that\nyour PyTorch model is running as efficiently as possible. Here is an\nexample of using `torch.compile` and `triton.autotune`.\n\n```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p><code>torch.compile</code> only supports configs and key arguments to <code>triton.autotune</code>.</p></div>\n```\n" ] }
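, { "cell_type": "markdown", "metadata": {}, "source": [ "As a sketch of that pattern, the cell below wraps the vector addition\nkernel from above in `triton.autotune`. The names\n(`add_kernel_autotuned`, `add_fn_autotuned`) and the specific\n`triton.Config` values here are illustrative choices rather than tuned\nnumbers: the autotuner benchmarks the listed configs, caches the\nfastest one, and re-tunes whenever the arguments named in `key`\nchange.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if not has_triton():\n    print(\"Skipping because triton is not supported on this device.\")\nelse:\n    import triton\n    from triton import language as tl\n\n    # The configs below are example choices, not tuned numbers. With\n    # key=[], tuning runs once and the chosen config is reused afterwards.\n    @triton.autotune(\n        configs=[\n            triton.Config({\"BLOCK_SIZE\": 4}, num_stages=3, num_warps=8),\n            triton.Config({\"BLOCK_SIZE\": 4}, num_stages=4, num_warps=4),\n            triton.Config({\"BLOCK_SIZE\": 2}, num_stages=3, num_warps=8),\n            triton.Config({\"BLOCK_SIZE\": 2}, num_stages=4, num_warps=4),\n        ],\n        key=[],\n    )\n    @triton.jit\n    def add_kernel_autotuned(\n        in_ptr0,\n        in_ptr1,\n        out_ptr,\n        n_elements,\n        BLOCK_SIZE: \"tl.constexpr\",\n    ):\n        pid = tl.program_id(axis=0)\n        block_start = pid * BLOCK_SIZE\n        offsets = block_start + tl.arange(0, BLOCK_SIZE)\n        mask = offsets < n_elements\n        x = tl.load(in_ptr0 + offsets, mask=mask)\n        y = tl.load(in_ptr1 + offsets, mask=mask)\n        output = x + y\n        tl.store(out_ptr + offsets, output, mask=mask)\n\n    @torch.compile(fullgraph=True)\n    def add_fn_autotuned(x, y):\n        output = torch.zeros_like(x)\n        n_elements = output.numel()\n        grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n        # BLOCK_SIZE is supplied by the autotuner, so it is not passed here.\n        add_kernel_autotuned[grid](x, y, output, n_elements)\n        return output\n\n    x = torch.randn(4, device=\"cuda\")\n    y = torch.randn(4, device=\"cuda\")\n    out = add_fn_autotuned(x, y)\n    print(f\"Vector addition of\\nX:\\t{x}\\nY:\\t{y}\\nis equal to\\n{out}\")" ] }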