{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction to ONNX](intro_onnx.html) \\|\\| [Exporting a PyTorch model\nto ONNX](export_simple_model_to_onnx_tutorial.html) \\|\\| [Extending the\nONNX exporter operator support](onnx_registry_tutorial.html) \\|\\|\n**Export a model with control flow to ONNX**\n\nExport a model with control flow to ONNX\n========================================\n\n**Author**: [Xavier Dupr\u00e9](https://github.com/xadupre)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nThis tutorial demonstrates how to handle control flow logic while\nexporting a PyTorch model to ONNX. It highlights the challenges of\nexporting conditional statements directly and provides solutions to\ncircumvent them.\n\nConditional logic cannot be exported to ONNX unless it is refactored to\nuse `torch.cond`{.interpreted-text role=\"func\"}. Let\'s start with a\nsimple model that implements a conditional test.\n\nWhat you will learn:\n\n- How to refactor the model to use `torch.cond`{.interpreted-text\n role=\"func\"} for exporting.\n- How to export a model with control flow logic to ONNX.\n- How to optimize the exported model using the ONNX optimizer.\n\nPrerequisites\n-------------\n\n- `torch >= 2.6`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the Models\n=================\n\nTwo models are defined:\n\n`ForwardWithControlFlowTest`: A model with a forward method containing\nan if-else conditional.\n\n`ModelWithControlFlowTest`: A model that incorporates\n`ForwardWithControlFlowTest` as part of a simple MLP. 
The models are\ntested with a random input tensor to confirm they execute as expected.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ForwardWithControlFlowTest(torch.nn.Module):\n    def forward(self, x):\n        if x.sum():\n            return x * 2\n        return -x\n\n\nclass ModelWithControlFlowTest(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.mlp = torch.nn.Sequential(\n            torch.nn.Linear(3, 2),\n            torch.nn.Linear(2, 1),\n            ForwardWithControlFlowTest(),\n        )\n\n    def forward(self, x):\n        out = self.mlp(x)\n        return out\n\n\nmodel = ModelWithControlFlowTest()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Exporting the Model: First Attempt\n==================================\n\nExporting this model using `torch.export.export` fails because the control\nflow logic in the forward pass creates a graph break that the exporter\ncannot handle. This behavior is expected, as conditional logic not\nwritten using `torch.cond`{.interpreted-text role=\"func\"} is\nunsupported.\n\nA try-except block is used to capture the expected failure during the\nexport process. If the export unexpectedly succeeds, an `AssertionError`\nis raised.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x = torch.randn(3)\nmodel(x)\n\ntry:\n    torch.export.export(model, (x,), strict=False)\n    raise AssertionError(\"This export should fail unless PyTorch now supports this model.\")\nexcept Exception as e:\n    print(e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using `torch.onnx.export`{.interpreted-text role=\"func\"} with JIT\nTracing\n------------------------------------------------------------------\n\nWhen exporting the model using `torch.onnx.export`{.interpreted-text\nrole=\"func\"} with the `dynamo=True` argument, the exporter defaults to\nusing JIT tracing. 
This fallback allows the model to export, but the\nresulting ONNX graph may not faithfully represent the original model\nlogic due to the limitations of tracing.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "onnx_program = torch.onnx.export(model, (x,), dynamo=True)\nprint(onnx_program.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suggested Patch: Refactoring with `torch.cond`{.interpreted-text\nrole=\"func\"}\n------------------------------------------------------------------\n\nTo make the control flow exportable, the tutorial demonstrates replacing\nthe forward method in `ForwardWithControlFlowTest` with a refactored\nversion that uses `torch.cond`{.interpreted-text role=\"func\"}.\n\nDetails of the refactoring:\n\n- Two helper functions (`identity2` and `neg`) represent the branches\n  of the conditional logic.\n- `torch.cond`{.interpreted-text role=\"func\"} is used to specify the\n  condition and the two branches along with the input arguments.\n- The updated forward method is then dynamically assigned to the\n  `ForwardWithControlFlowTest` instance within the model. 
A list of submodules is printed to confirm the\nreplacement.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def new_forward(x):\n    def identity2(x):\n        return x * 2\n\n    def neg(x):\n        return -x\n\n    return torch.cond(x.sum() > 0, identity2, neg, (x,))\n\n\nprint(\"the list of submodules\")\nfor name, mod in model.named_modules():\n    print(name, type(mod))\n    if isinstance(mod, ForwardWithControlFlowTest):\n        mod.forward = new_forward" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\'s see what the FX graph looks like.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(torch.export.export(model, (x,), strict=False))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\'s export again.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "onnx_program = torch.onnx.export(model, (x,), dynamo=True)\nprint(onnx_program.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can optimize the model and get rid of the local functions created in\nthe ONNX model to capture the control flow branches.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "onnx_program.optimize()\nprint(onnx_program.model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nThis tutorial demonstrates the challenges of exporting models with\nconditional logic to ONNX and presents a practical solution using\n`torch.cond`{.interpreted-text role=\"func\"}. 
While the default exporters\nmay fail or produce imperfect graphs, refactoring the model\'s logic\nensures compatibility and generates a faithful ONNX representation.\n\nBy understanding these techniques, we can overcome common pitfalls when\nworking with control flow in PyTorch models and ensure smooth\nintegration with ONNX workflows.\n\nFurther reading\n===============\n\nThe list below refers to tutorials that range from basic examples to\nadvanced scenarios, not necessarily in the order they are listed. Feel\nfree to jump directly to specific topics of your interest or sit tight\nand have fun going through all of them to learn all there is about the\nONNX exporter.\n\n::: {.toctree hidden=\"\"}\n:::\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }