{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Custom Python Operators {#python-custom-ops-tutorial}\n=======================\n\n```{=html}\n


\n```\nPyTorch offers a large library of operators that work on Tensors (e.g.\n`torch.add`, `torch.sum`, etc.). However, you might wish to use a new\ncustomized operator with PyTorch, perhaps written by a third-party\nlibrary. This tutorial shows how to wrap Python functions so that they\nbehave like PyTorch native operators. Reasons why you may wish to create\na custom operator in PyTorch include:\n\n- Treating an arbitrary Python function as an opaque callable with\n respect to `torch.compile` (that is, preventing `torch.compile` from\n tracing into the function).\n- Adding training support to an arbitrary Python function.\n\nUse `torch.library.custom_op`{.interpreted-text role=\"func\"} to create\nPython custom operators. Use the C++ `TORCH_LIBRARY` APIs to create C++\ncustom operators (these work in Python-less environments). See the\n[Custom Operators Landing\nPage](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html)\nfor more details.\n\nPlease note that if your operation can be expressed as a composition of\nexisting PyTorch operators, then there is usually no need to use the\ncustom operator API \\-- everything (for example `torch.compile`,\ntraining support) should just work.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Example: Wrapping PIL\\'s crop into a custom operator\n---------------------------------------------------\n\nLet\\'s say that we are using PIL\\'s `crop` operation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torchvision.transforms.functional import to_pil_image, pil_to_tensor\nimport PIL.Image\nimport IPython\nimport matplotlib.pyplot as plt\n\ndef crop(pic, box):\n img = to_pil_image(pic.cpu())\n cropped_img = img.crop(box)\n return pil_to_tensor(cropped_img).to(pic.device) / 255.\n\ndef display(img):\n plt.imshow(img.numpy().transpose((1, 2, 0)))\n\nimg = 
torch.ones(3, 64, 64)\nimg *= torch.linspace(0, 1, steps=64) * torch.linspace(0, 1, steps=64).unsqueeze(-1)\ndisplay(img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "cropped_img = crop(img, (10, 10, 50, 50))\ndisplay(cropped_img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`crop` is not handled effectively out-of-the-box by `torch.compile`:\n`torch.compile` induces a [\\\"graph\nbreak\\\"](https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks)\non functions it is unable to handle and graph breaks are bad for\nperformance. The following code demonstrates this by raising an error\n(`torch.compile` with `fullgraph=True` raises an error if a graph break\noccurs).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile(fullgraph=True)\ndef f(img):\n return crop(img, (10, 10, 50, 50))\n\n# The following raises an error. Uncomment the line to see it.\n# cropped_img = f(img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In order to black-box `crop` for use with `torch.compile`, we need to do\ntwo things:\n\n1. wrap the function into a PyTorch custom operator.\n2. 
add a \\\"`FakeTensor` kernel\\\" (aka \\\"meta kernel\\\") to the operator.\n Given some `FakeTensor` inputs (dummy Tensors that don\\'t have\n storage), this function should return dummy Tensors of your choice\n with the correct Tensor metadata (shape/strides/`dtype`/device).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import Sequence\n\n# Use torch.library.custom_op to define a new custom operator.\n# If your operator mutates any input Tensors, their names must be specified\n# in the ``mutates_args`` argument.\n@torch.library.custom_op(\"mylib::crop\", mutates_args=())\ndef crop(pic: torch.Tensor, box: Sequence[int]) -> torch.Tensor:\n img = to_pil_image(pic.cpu())\n cropped_img = img.crop(box)\n return (pil_to_tensor(cropped_img) / 255.).to(pic.device, pic.dtype)\n\n# Use register_fake to add a ``FakeTensor`` kernel for the operator\n@crop.register_fake\ndef _(pic, box):\n channels = pic.shape[0]\n x0, y0, x1, y1 = box\n result = pic.new_empty(y1 - y0, x1 - x0, channels).permute(2, 0, 1)\n # The result should have the same metadata (shape/strides/``dtype``/device)\n # as running the ``crop`` function above.\n return result" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After this, `crop` works without graph breaks:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile(fullgraph=True)\ndef f(img):\n return crop(img, (10, 10, 50, 50))\n\ncropped_img = f(img)\ndisplay(img)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "display(cropped_img)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding training support for crop\n================================\n\nUse `torch.library.register_autograd` to add training support for an\noperator. 
Prefer this over directly using `torch.autograd.Function`;\nsome compositions of `autograd.Function` with PyTorch operator\nregistration APIs can lead to (and have led to) silent incorrectness when\ncomposed with `torch.compile`.\n\nIf you don\\'t need training support, there is no need to use\n`torch.library.register_autograd`. If you end up training with a\n`custom_op` that doesn\\'t have an autograd registration, we\\'ll raise an\nerror message.\n\nThe gradient formula for `crop` is essentially `PIL.paste` (we\\'ll leave\nthe derivation as an exercise to the reader). Let\\'s first wrap `paste`\ninto a custom operator:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.library.custom_op(\"mylib::paste\", mutates_args=())\ndef paste(im1: torch.Tensor, im2: torch.Tensor, coord: Sequence[int]) -> torch.Tensor:\n assert im1.device == im2.device\n assert im1.dtype == im2.dtype\n im1_pil = to_pil_image(im1.cpu())\n im2_pil = to_pil_image(im2.cpu())\n PIL.Image.Image.paste(im1_pil, im2_pil, coord)\n return (pil_to_tensor(im1_pil) / 255.).to(im1.device, im1.dtype)\n\n@paste.register_fake\ndef _(im1, im2, coord):\n assert im1.device == im2.device\n assert im1.dtype == im2.dtype\n return torch.empty_like(im1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And now let\\'s use `register_autograd` to specify the gradient formula\nfor `crop`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def backward(ctx, grad_output):\n grad_input = grad_output.new_zeros(ctx.pic_shape)\n grad_input = paste(grad_input, grad_output, ctx.coords)\n return grad_input, None\n\ndef setup_context(ctx, inputs, output):\n pic, box = inputs\n ctx.coords = box[:2]\n ctx.pic_shape = pic.shape\n\ncrop.register_autograd(backward, setup_context=setup_context)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that the backward must 
be a composition of PyTorch-understood\noperators, which is why we wrapped paste into a custom operator instead\nof directly using PIL\\'s paste.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "img = img.requires_grad_()\nresult = crop(img, (10, 10, 50, 50))\nresult.sum().backward()\ndisplay(img.grad)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This is the correct gradient, with 1s (white) in the cropped region and\n0s (black) in the unused region.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Testing Python Custom operators\n===============================\n\nUse `torch.library.opcheck` to test that the custom operator was\nregistered correctly. This does not test that the gradients are\nmathematically correct; please write separate tests for that (either\nmanual ones or `torch.autograd.gradcheck`).\n\nTo use `opcheck`, pass it a set of example inputs to test against. If\nyour operator supports training, then the examples should include\nTensors that require grad. If your operator supports multiple devices,\nthen the examples should include Tensors from each device.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "examples = [\n [torch.randn(3, 64, 64), [0, 0, 10, 10]],\n [torch.randn(3, 91, 91, requires_grad=True), [10, 0, 20, 10]],\n [torch.randn(3, 60, 60, dtype=torch.double), [3, 4, 32, 20]],\n [torch.randn(3, 512, 512, requires_grad=True, dtype=torch.double), [3, 4, 32, 45]],\n]\n\nfor example in examples:\n torch.library.opcheck(crop, example)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Mutable Python Custom operators\n===============================\n\nYou can also wrap a Python function that mutates its inputs into a\ncustom operator. 
Functions that mutate inputs are common because that is\nhow many low-level kernels are written; for example, a kernel that\ncomputes `sin` may take in the input and an output tensor and write\n`input.sin()` to the output tensor.\n\nWe\\'ll use `numpy.sin` to demonstrate an example of a mutable Python\ncustom operator.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\n@torch.library.custom_op(\"mylib::numpy_sin\", mutates_args={\"output\"}, device_types=\"cpu\")\ndef numpy_sin(input: torch.Tensor, output: torch.Tensor) -> None:\n assert input.device == output.device\n assert input.device.type == \"cpu\"\n input_np = input.numpy()\n output_np = output.numpy()\n np.sin(input_np, out=output_np)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because the operator doesn\\'t return anything, there is no need to\nregister a `FakeTensor` kernel (meta kernel) to get it to work with\n`torch.compile`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "@torch.compile(fullgraph=True)\ndef f(x):\n out = torch.empty(3)\n numpy_sin(x, out)\n return out\n\nx = torch.randn(3)\ny = f(x)\nassert torch.allclose(y, x.sin())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And here\\'s an `opcheck` run telling us that we did indeed register the\noperator correctly. 
`opcheck` would error out if we forgot to add the\noutput to `mutates_args`, for example.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "example_inputs = [\n [torch.randn(3), torch.empty(3)],\n [torch.randn(0, 3), torch.empty(0, 3)],\n [torch.randn(1, 2, 3, 4, dtype=torch.double), torch.empty(1, 2, 3, 4, dtype=torch.double)],\n]\n\nfor example in example_inputs:\n torch.library.opcheck(numpy_sin, example)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we learned how to use `torch.library.custom_op` to\ncreate a custom operator in Python that works with PyTorch subsystems\nsuch as `torch.compile` and autograd.\n\nThis tutorial provides a basic introduction to custom operators. For\nmore detailed information, see:\n\n- [the torch.library\n documentation](https://pytorch.org/docs/stable/library.html)\n- [the Custom Operators\n Manual](https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html#the-custom-operators-manual)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }