{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction to ONNX](intro_onnx.html) \\|\\| [Exporting a PyTorch model\nto ONNX](export_simple_model_to_onnx_tutorial.html) \\|\\| **Extending the\nONNX exporter operator support** \\|\\| [Export a model with control flow\nto ONNX](export_control_flow_model_to_onnx_tutorial.html)\n\nExtending the ONNX Exporter Operator Support\n============================================\n\n**Authors:** [Ti-Tai Wang](titaiwang@microsoft.com), [Justin\nChu](justinchu@microsoft.com)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nThis tutorial describes how you can create ONNX implementation for\nunsupported PyTorch operators or replace existing implementation with\nyour own.\n\nWe will cover three scenarios that require extending the ONNX\nexporter\\'s operator support:\n\n- Overriding the implementation of an existing PyTorch operator\n- Using custom ONNX operators\n- Supporting a custom PyTorch operator\n\nWhat you will learn:\n\n- How to override or add support for PyTorch operators in ONNX.\n- How to integrate custom ONNX operators for specialized runtimes.\n- How to implement and translate custom PyTorch operators to ONNX.\n\nPrerequisites\n-------------\n\nBefore starting this tutorial, make sure you have completed the\nfollowing prerequisites:\n\n- `torch >= 2.6`\n- The target PyTorch operator\n- Completed the [ONNX Script\n tutorial](https://github.com/microsoft/onnxscript/blob/main/docs/tutorial/index.md)\n before proceeding\n- The implementation of the operator using [ONNX\n Script](https://github.com/microsoft/onnxscript)\n\nOverriding the implementation of an existing PyTorch operator\n=============================================================\n\nAlthough the ONNX exporter team does their best efforts to support all\nPyTorch operators, some of them might not be supported yet. In this\nsection, we will demonstrate how you can add unsupported PyTorch\noperators to the ONNX Registry.\n\n```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

The steps to implement unsupported PyTorch operators are the same as those for replacing the implementation of an existingPyTorch operator with a custom one.Because we don't actually have an unsupported PyTorch operator to use in this tutorial, we are going to leveragethis and replace the implementation of torch.ops.aten.add.Tensor with a custom implementation the same way we wouldif the operator was not implemented by the ONNX exporter.

\n```\n```{=html}\n
\n```\nWhen a model cannot be exported to ONNX due to an unsupported operator,\nthe ONNX exporter will show an error message similar to:\n\n``` {.python}\nNo decompositions registered for [...]\n```\n\nThe error message indicates that the unsupported PyTorch operator is\n`torch.ops.aten.add.Tensor`. The operator is of type\n``, and this operator is what we will use\nas the target to register our custom implementation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport onnxscript\n\n# Opset 18 is the standard supported version as of PyTorch 2.6\nfrom onnxscript import opset18 as op\n\n\n# Create a model that uses the operator torch.ops.aten.add.Tensor\nclass Model(torch.nn.Module):\n def forward(self, input_x, input_y):\n return torch.ops.aten.add.Tensor(input_x, input_y)\n\n\n# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.\n# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml\n# All attributes must be annotated with type hints.\ndef custom_aten_add(self, other, alpha: float = 1.0):\n if alpha != 1.0:\n alpha = op.CastLike(alpha, other)\n other = op.Mul(other, alpha)\n # To distinguish the custom implementation from the builtin one, we switch the order of the inputs\n return op.Add(other, self)\n\n\nx = torch.tensor([1.0])\ny = torch.tensor([2.0])\n\n# Then we provide the custom implementation to the ONNX exporter as a ``custom_translation_table``.\nonnx_program = torch.onnx.export(\n Model().eval(),\n (x, y),\n dynamo=True,\n custom_translation_table={\n torch.ops.aten.add.Tensor: custom_aten_add,\n },\n)\n# Optimize the ONNX graph to remove redundant nodes\nonnx_program.optimize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let\\'s inspect the model and verify the model is using the custom\nimplementation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(onnx_program.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The translation is using our custom implementation: In node\n`node_Add_0`, `input_y` now comes first, and `input_x` comes second.\n\nWe can use ONNX Runtime to run the model and verify the results by\ncalling the `torch.onnx.ONNXProgram`{.interpreted-text role=\"class\"}\ndirectly on the input tensors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "result = onnx_program(x, y)[0]\ntorch.testing.assert_close(result, torch.tensor([3.0]))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using custom ONNX operators\n===========================\n\nIn this case, we create a model with standard PyTorch operators, but the\nruntime (such as Microsoft\\'s ONNX Runtime) can provide a custom\nimplementation for that kernel, effectively replacing the existing\nimplementation.\n\nIn the following example, we use the `com.microsoft.Gelu` operator\nprovided by ONNX Runtime, which is not the same `Gelu` from ONNX spec.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class GeluModel(torch.nn.Module):\n def forward(self, input_x):\n return torch.ops.aten.gelu(input_x)\n\n\n# Create a namespace for the custom operator using ONNX Script\n# ``com.microsoft`` is an official ONNX Runtime namespace\nmicrosoft_op = 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Using custom ONNX operators\n===========================\n\nIn this case, we create a model with standard PyTorch operators, but the\nruntime (such as Microsoft\\'s ONNX Runtime) can provide a custom\nimplementation for that kernel, effectively replacing the existing\nimplementation.\n\nIn the following example, we use the `com.microsoft.Gelu` operator\nprovided by ONNX Runtime, which is not the same as the `Gelu` operator\nin the ONNX spec.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class GeluModel(torch.nn.Module):\n    def forward(self, input_x):\n        return torch.ops.aten.gelu(input_x)\n\n\n# Create a namespace for the custom operator using ONNX Script\n# ``com.microsoft`` is an official ONNX Runtime namespace\nmicrosoft_op = onnxscript.values.Opset(domain=\"com.microsoft\", version=1)\n\n# NOTE: The function signature (including parameter names) must match the signature of the unsupported PyTorch operator.\n# https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml\n# NOTE: All attributes must be annotated with type hints.\n# The function must be scripted using the ``@onnxscript.script()`` decorator when\n# using operators from custom domains. This may be improved in future versions.\nfrom onnxscript import FLOAT\n\n\n@onnxscript.script(microsoft_op)\ndef custom_aten_gelu(self: FLOAT, approximate: str = \"none\") -> FLOAT:\n    return microsoft_op.Gelu(self)\n\n\nonnx_program = torch.onnx.export(\n    GeluModel().eval(),\n    (x,),\n    dynamo=True,\n    custom_translation_table={\n        torch.ops.aten.gelu.default: custom_aten_gelu,\n    },\n)\n\n# Optimize the ONNX graph to remove redundant nodes\nonnx_program.optimize()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s inspect the model and verify that it uses op\\_type `Gelu` from\nnamespace `com.microsoft`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(onnx_program.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Similar to the previous example, we can use ONNX Runtime to run the\nmodel and verify the results.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "result = onnx_program(x)[0]\ntorch.testing.assert_close(result, torch.ops.aten.gelu(x))" ] },
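{ "cell_type": "markdown", "metadata": {}, "source": [ "Because `com.microsoft` is a custom domain, the exported model can only\nbe executed by runtimes that implement that domain, such as ONNX Runtime.\nAs a quick sanity check, the sketch below lists the opset imports of the\nexported model to confirm that the custom domain is declared, using the\n`model_proto` property of the\n`torch.onnx.ONNXProgram`{.interpreted-text role=\"class\"}.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Sanity-check sketch: the custom ``com.microsoft`` domain should appear in\n# the opset imports of the exported model so that a runtime such as\n# ONNX Runtime knows how to resolve the ``Gelu`` node.\nfor opset in onnx_program.model_proto.opset_import:\n    print(opset.domain or \"<default ONNX domain>\", opset.version)" ] },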
{ "cell_type": "markdown", "metadata": {}, "source": [ "Supporting a custom PyTorch operator\n====================================\n\nIn this case, the operator is one that the user has implemented and\nregistered with PyTorch.\n\nIn the following example, we would like to use a custom operator that\ntakes one tensor input and returns one output. The operator adds the\ninput to itself and returns the rounded result.\n\nFirst, we assume the custom operator is implemented and registered with\n`torch.library.custom_op()`. You can refer to [Creating new custom ops in\nPython](https://pytorch.org/docs/stable/library.html#torch.library.custom_op)\nfor a detailed guide on how to create custom operators.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Define and use the operator in PyTorch\n@torch.library.custom_op(\"mylibrary::add_and_round_op\", mutates_args=())\ndef add_and_round_op(input: torch.Tensor) -> torch.Tensor:\n    return torch.round(input + input)\n\n\n@add_and_round_op.register_fake\ndef _add_and_round_op_fake(tensor_x):\n    return torch.empty_like(tensor_x)\n\n\nclass AddAndRoundModel(torch.nn.Module):\n    def forward(self, input):\n        return add_and_round_op(input)\n\n\n# Implement the custom operator in ONNX using ONNX Script\ndef onnx_add_and_round(input):\n    return op.Round(op.Add(input, input))\n\n\nonnx_program = torch.onnx.export(\n    AddAndRoundModel().eval(),\n    (x,),\n    dynamo=True,\n    custom_translation_table={\n        torch.ops.mylibrary.add_and_round_op.default: onnx_add_and_round,\n    },\n)\n\n# Optimize the ONNX graph to remove redundant nodes\nonnx_program.optimize()\nprint(onnx_program)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The translation uses our custom implementation to translate the\n`torch.ops.mylibrary.add_and_round_op.default` operator in the\n`torch.export.ExportedProgram`{.interpreted-text role=\"class\"} to the\nONNX operators `Add` and `Round`.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, we verify the results.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "result = onnx_program(x)[0]\ntorch.testing.assert_close(result, add_and_round_op(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nCongratulations! In this tutorial, we explored the\n`custom_translation_table` option and discovered how to create custom\nimplementations for unsupported or existing PyTorch operators using ONNX\nScript.\n\nFinally, we leveraged ONNX Runtime to execute the model and compare the\nresults with PyTorch, providing us with a comprehensive understanding of\nhandling unsupported operators in the ONNX ecosystem.\n\nFurther reading\n===============\n\nThe list below refers to tutorials that range from basic examples to\nadvanced scenarios, not necessarily in the order they are listed. Feel\nfree to jump directly to specific topics of your interest or sit tight\nand have fun going through all of them to learn all there is about the\nONNX exporter.\n\n::: {.toctree hidden=\"\"}\n:::\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }