.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "beginner/onnx/export_control_flow_model_to_onnx_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_beginner_onnx_export_control_flow_model_to_onnx_tutorial.py:

Introduction to ONNX ||
Exporting a PyTorch model to ONNX ||
Extending the ONNX exporter operator support ||
**Export a model with control flow to ONNX**

Export a model with control flow to ONNX
========================================

**Author**: Xavier Dupré

.. GENERATED FROM PYTHON SOURCE LINES 16-37

Overview
--------

This tutorial demonstrates how to handle control flow logic while exporting
a PyTorch model to ONNX. It highlights the challenges of exporting
conditional statements directly and provides solutions to circumvent them.

Conditional logic cannot be exported to ONNX unless it is refactored
to use :func:`torch.cond`. Let's start with a simple model whose forward pass
contains a data-dependent ``if`` test.

What you will learn:

- How to refactor the model to use :func:`torch.cond` for exporting.
- How to export a model with control flow logic to ONNX.
- How to optimize the exported model using the ONNX optimizer.

Prerequisites
~~~~~~~~~~~~~

* ``torch >= 2.6``

.. GENERATED FROM PYTHON SOURCE LINES 37-41

.. code-block:: default

    import torch

.. GENERATED FROM PYTHON SOURCE LINES 42-53

Define the Models
-----------------

Two models are defined:

* ``ForwardWithControlFlowTest``: a model whose ``forward`` method contains an
  if-else conditional.
* ``ModelWithControlFlowTest``: a model that incorporates
  ``ForwardWithControlFlowTest`` as the last layer of a simple MLP.

The models are tested with a random input tensor to confirm they execute as expected.

.. GENERATED FROM PYTHON SOURCE LINES 53-78

.. code-block:: default

    class ForwardWithControlFlowTest(torch.nn.Module):
        def forward(self, x):
            # Data-dependent Python branch: this is what the exporter cannot capture.
            if x.sum():
                return x * 2
            return -x


    class ModelWithControlFlowTest(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = torch.nn.Sequential(
                torch.nn.Linear(3, 2),
                torch.nn.Linear(2, 1),
                ForwardWithControlFlowTest(),
            )

        def forward(self, x):
            out = self.mlp(x)
            return out


    model = ModelWithControlFlowTest()

.. GENERATED FROM PYTHON SOURCE LINES 79-89

Exporting the Model: First Attempt
----------------------------------

Exporting this model using ``torch.export.export`` fails because the control
flow logic in the forward pass depends on tensor data, which creates a graph
break that the exporter cannot handle. This behavior is expected, as conditional
logic not written using :func:`torch.cond` is unsupported.

A try-except block is used to capture the expected failure during the export process.
If the export unexpectedly succeeds, an ``AssertionError`` is raised.

.. GENERATED FROM PYTHON SOURCE LINES 89-99

.. code-block:: default

    x = torch.randn(3)
    model(x)

    try:
        torch.export.export(model, (x,), strict=False)
    except Exception as e:
        # Expected path: print the exporter's error message.
        print(e)
    else:
        # Raised outside the ``try`` so that it is not swallowed by ``except``.
        raise AssertionError("This export should fail unless PyTorch now supports this model.")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
        # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
        linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None

        # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
        ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
        item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None

    Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)). (Size-like symbols: none)

    Caused by: (_export/non_strict_utils.py:683 in __torch_function__)
    For more information, run with TORCH_LOGS="dynamic"
    For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
    If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
    For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

    For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

    The following call raised this error:
      File "/data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
        if x.sum():

.. GENERATED FROM PYTHON SOURCE LINES 100-107

Using :func:`torch.onnx.export` with JIT Tracing
------------------------------------------------

When the model is exported using :func:`torch.onnx.export` with the ``dynamo=True``
argument, the exporter first tries the ``torch.export``-based strategies and, when
they fail, falls back to JIT tracing. This fallback allows the model to export,
but the resulting ONNX graph may not faithfully represent the original model logic
due to the limitations of tracing: in the output below, only the ``x * 2`` branch
taken for the sample input is recorded.

.. GENERATED FROM PYTHON SOURCE LINES 107-113

.. code-block:: default

    onnx_program = torch.onnx.export(model, (x,), dynamo=True)
    print(onnx_program.model)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
    def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
        # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
        linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1); arg4_1 = arg0_1 = arg1_1 = None
        linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1); linear = arg2_1 = arg3_1 = None

        # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1); linear_1 = None
        ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0); sum_1 = None
        item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne); ne = item = None

    [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
    [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...
class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3][1]cpu"): l_x_ = L_x_ # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x) l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_); l_x_ = None l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0); l__self___mlp_0 = None # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum(): sum_1: "f32[][]cpu" = l__self___mlp_1.sum(); l__self___mlp_1 = sum_1 = None class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3][1]cpu"): l_x_ = L_x_ # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x) l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_); l_x_ = None l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0); l__self___mlp_0 = None # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum(): sum_1: "f32[][]cpu" = l__self___mlp_1.sum(); l__self___mlp_1 = sum_1 = None class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3][1]cpu"): l_x_ = L_x_ # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x) l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_); l_x_ = None l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0); l__self___mlp_0 = None # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum(): sum_1: "f32[][]cpu" = l__self___mlp_1.sum(); l__self___mlp_1 = sum_1 = None [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... 
❌ [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs! [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅ [torch.onnx] Run decomposition... /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/export/_unlift.py:81: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/fx/graph.py:1772: UserWarning: Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target [torch.onnx] Run decomposition... ✅ [torch.onnx] Translate the graph into ONNX... [torch.onnx] Translate the graph into ONNX... 
✅ < ir_version=10, opset_imports={'': 18}, producer_name='pytorch', producer_version='2.7.0+cu126', domain=None, model_version=None, > graph( name=main_graph, inputs=( %"input_1" ), outputs=( %"mul" ), initializers=( %"model.mlp.0.bias"{TorchTensor(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='model.mlp.0.bias')}, %"model.mlp.1.bias"{TorchTensor(Parameter containing: tensor([-0.3741], requires_grad=True), name='model.mlp.1.bias')} ), ) { 0 | # node_Constant_8 %"val_0"{Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} 1 | # node_MatMul_1 %"val_1" ⬅️ ::MatMul(%"input_1", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]}) 2 | # node_Add_2 %"linear" ⬅️ ::Add(%"val_1", %"model.mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]}) 3 | # node_Constant_9 %"val_2"{Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ ::Constant() {value=Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} 4 | # node_MatMul_4 %"val_3" ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]}) 5 | # node_Add_5 %"linear_1" ⬅️ ::Add(%"val_3", %"model.mlp.1.bias"{[-0.37414515018463135]}) 6 | # node_Constant_10 %"convert_element_type_default"{Tensor(array(2., dtype=float32), name='convert_element_type_default')} ⬅️ ::Constant() {value=Tensor(array(2., dtype=float32), name='convert_element_type_default')} 7 | # node_Mul_7 %"mul" ⬅️ ::Mul(%"linear_1", %"convert_element_type_default"{2.0}) return %"mul" } .. 
GENERATED FROM PYTHON SOURCE LINES 114-126

Suggested Patch: Refactoring with :func:`torch.cond`
----------------------------------------------------

To make the control flow exportable, we replace the ``forward`` method of
``ForwardWithControlFlowTest`` with a refactored version that uses
:func:`torch.cond`.

Details of the refactoring:

* Two helper functions (``identity2`` and ``neg``) represent the branches of the
  conditional logic.
* :func:`torch.cond` is used to specify the condition and the two branches along
  with the input arguments.
* The updated forward method is then dynamically assigned to the
  ``ForwardWithControlFlowTest`` instance within the model. A list of submodules
  is printed to confirm the replacement.

Note that the predicate ``x.sum() > 0`` is stricter than the original truthiness
test ``if x.sum():``, which is true for any nonzero sum; the two versions differ
when the sum is negative.

.. GENERATED FROM PYTHON SOURCE LINES 126-143

.. code-block:: default

    def new_forward(x):
        # Branch taken when the predicate is true: double the input.
        def identity2(x):
            return x * 2

        # Branch taken when the predicate is false: negate the input.
        def neg(x):
            return -x

        return torch.cond(x.sum() > 0, identity2, neg, (x,))


    print("the list of submodules")
    for name, mod in model.named_modules():
        print(name, type(mod))
        if isinstance(mod, ForwardWithControlFlowTest):
            # Patch the instance so the exporter sees the torch.cond version.
            mod.forward = new_forward

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    the list of submodules
    mlp
    mlp.0
    mlp.1
    mlp.2

.. GENERATED FROM PYTHON SOURCE LINES 144-145

Let's see what the FX graph looks like.

.. GENERATED FROM PYTHON SOURCE LINES 145-148

.. code-block:: default

    print(torch.export.export(model, (x,), strict=False))

.. rst-class:: sphx-glr-script-out

..
code-block:: none ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"): # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias) linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias); x = p_mlp_0_weight = p_mlp_0_bias = None linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias); linear = p_mlp_1_weight = p_mlp_1_bias = None # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py:240 in forward, code: input = module(input) sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1) gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None # File: .30:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]); l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None true_graph_0 = self.true_graph_0 false_graph_0 = self.false_graph_0 cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [linear_1]); gt = true_graph_0 = false_graph_0 = linear_1 = None getitem: "f32[1]" = cond[0]; cond = None return (getitem,) class true_graph_0(torch.nn.Module): def forward(self, linear_1: "f32[1]"): # File: .25:6 in forward, code: mul = l_args_3_0__1.mul(2); l_args_3_0__1 = None mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2); linear_1 = None return (mul,) class false_graph_0(torch.nn.Module): def forward(self, linear_1: "f32[1]"): # File: .26:6 in forward, code: neg = l_args_3_0__1.neg(); l_args_3_0__1 = None neg: "f32[1]" = torch.ops.aten.neg.default(linear_1); linear_1 = None return (neg,) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='p_mlp_0_weight'), target='mlp.0.weight', persistent=None), InputSpec(kind=, 
arg=TensorArgument(name='p_mlp_0_bias'), target='mlp.0.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_mlp_1_weight'), target='mlp.1.weight', persistent=None), InputSpec(kind=, arg=TensorArgument(name='p_mlp_1_bias'), target='mlp.1.bias', persistent=None), InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='getitem'), target=None)]) Range constraints: {} .. GENERATED FROM PYTHON SOURCE LINES 149-150 Let's export again. .. GENERATED FROM PYTHON SOURCE LINES 150-155 .. code-block:: default onnx_program = torch.onnx.export(model, (x,), dynamo=True) print(onnx_program.model) .. rst-class:: sphx-glr-script-out .. code-block:: none [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... [torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅ [torch.onnx] Run decomposition... [torch.onnx] Run decomposition... ✅ [torch.onnx] Translate the graph into ONNX... [torch.onnx] Translate the graph into ONNX... 
✅ < ir_version=10, opset_imports={'': 18}, producer_name='pytorch', producer_version='2.7.0+cu126', domain=None, model_version=None, > graph( name=main_graph, inputs=( %"x" ), outputs=( %"getitem" ), initializers=( %"mlp.0.bias"{TorchTensor(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='mlp.0.bias')}, %"mlp.1.bias"{TorchTensor(Parameter containing: tensor([-0.3741], requires_grad=True), name='mlp.1.bias')} ), ) { 0 | # node_Constant_11 %"val_0"{Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} 1 | # node_MatMul_1 %"val_1" ⬅️ ::MatMul(%"x", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]}) 2 | # node_Add_2 %"linear" ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]}) 3 | # node_Constant_12 %"val_2"{Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ ::Constant() {value=Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} 4 | # node_MatMul_4 %"val_3" ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]}) 5 | # node_Add_5 %"linear_1" ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.37414515018463135]}) 6 | # node_ReduceSum_6 %"sum_1" ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False} 7 | # node_Constant_13 %"scalar_tensor_default"{Tensor(array(0., dtype=float32), name='scalar_tensor_default')} ⬅️ ::Constant() {value=Tensor(array(0., dtype=float32), name='scalar_tensor_default')} 8 | # node_Greater_9 %"gt" ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0}) 9 | # node_If_10 %"getitem" ⬅️ ::If(%"gt") {then_branch= graph( name=true_graph_0, inputs=( ), outputs=( %"mul_true_graph_0" ), ) { 0 | # node_Constant_1 
%"scalar_tensor_default_2"{Tensor(array(2., dtype=float32), name='scalar_tensor_default_2')} ⬅️ ::Constant() {value=Tensor(array(2., dtype=float32), name='scalar_tensor_default_2')} 1 | # node_Mul_2 %"mul_true_graph_0" ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0}) return %"mul_true_graph_0" }, else_branch= graph( name=false_graph_0, inputs=( ), outputs=( %"neg_false_graph_0" ), ) { 0 | # node_Neg_0 %"neg_false_graph_0" ⬅️ ::Neg(%"linear_1") return %"neg_false_graph_0" }} return %"getitem" } .. GENERATED FROM PYTHON SOURCE LINES 156-157 We can optimize the model and get rid of the model local functions created to capture the control flow branches. .. GENERATED FROM PYTHON SOURCE LINES 157-161 .. code-block:: default onnx_program.optimize() print(onnx_program.model) .. rst-class:: sphx-glr-script-out .. code-block:: none < ir_version=10, opset_imports={'': 18}, producer_name='pytorch', producer_version='2.7.0+cu126', domain=None, model_version=None, > graph( name=main_graph, inputs=( %"x" ), outputs=( %"getitem" ), initializers=( %"mlp.0.bias"{TorchTensor(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='mlp.0.bias')}, %"mlp.1.bias"{TorchTensor(Parameter containing: tensor([-0.3741], requires_grad=True), name='mlp.1.bias')} ), ) { 0 | # node_Constant_11 %"val_0"{Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor(array([[-0.4352824 , 0.31417835], [-0.15039097, -0.30015165], [-0.3928429 , 0.11529753]], dtype=float32), name='val_0')} 1 | # node_MatMul_1 %"val_1" ⬅️ ::MatMul(%"x", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]}) 2 | # node_Add_2 %"linear" ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]}) 3 | # node_Constant_12 %"val_2"{Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ 
::Constant() {value=Tensor(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} 4 | # node_MatMul_4 %"val_3" ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]}) 5 | # node_Add_5 %"linear_1" ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.37414515018463135]}) 6 | # node_ReduceSum_6 %"sum_1" ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False} 7 | # node_Constant_13 %"scalar_tensor_default"{Tensor(array(0., dtype=float32), name='scalar_tensor_default')} ⬅️ ::Constant() {value=Tensor(array(0., dtype=float32), name='scalar_tensor_default')} 8 | # node_Greater_9 %"gt" ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0}) 9 | # node_If_10 %"getitem" ⬅️ ::If(%"gt") {then_branch= graph( name=true_graph_0, inputs=( ), outputs=( %"mul_true_graph_0" ), ) { 0 | # node_Constant_1 %"scalar_tensor_default_2"{Tensor(array(2., dtype=float32), name='scalar_tensor_default_2')} ⬅️ ::Constant() {value=Tensor(array(2., dtype=float32), name='scalar_tensor_default_2')} 1 | # node_Mul_2 %"mul_true_graph_0" ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0}) return %"mul_true_graph_0" }, else_branch= graph( name=false_graph_0, inputs=( ), outputs=( %"neg_false_graph_0" ), ) { 0 | # node_Neg_0 %"neg_false_graph_0" ⬅️ ::Neg(%"linear_1") return %"neg_false_graph_0" }} return %"getitem" } .. GENERATED FROM PYTHON SOURCE LINES 162-185 Conclusion ---------- This tutorial demonstrates the challenges of exporting models with conditional logic to ONNX and presents a practical solution using :func:`torch.cond`. While the default exporters may fail or produce imperfect graphs, refactoring the model's logic ensures compatibility and generates a faithful ONNX representation. By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows. 
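The faithfulness of the refactored logic can be spot-checked before involving ONNX at all, by running the graph captured by ``torch.export`` on inputs that exercise each branch. The sketch below is a minimal illustration under stated assumptions rather than part of the original tutorial: ``BranchOnSum`` is a hypothetical stand-in for the patched ``ForwardWithControlFlowTest``, the input values are arbitrary, and ``torch >= 2.6`` is assumed. It checks the exported program only; comparing the ONNX model itself would additionally require an ONNX runtime, which this tutorial does not cover.

.. code-block:: default

    import torch


    class BranchOnSum(torch.nn.Module):
        """Hypothetical stand-in for the refactored conditional module."""

        def forward(self, x):
            def double(x):
                return x * 2

            def negate(x):
                return -x

            # Both branches are captured; the predicate is evaluated at run time.
            return torch.cond(x.sum() > 0, double, negate, (x,))


    example = torch.ones(3)
    exported = torch.export.export(BranchOnSum(), (example,), strict=False).module()

    positive = torch.tensor([1.0, 2.0, 3.0])     # sum > 0: the "double" branch runs
    negative = torch.tensor([-1.0, -2.0, -3.0])  # sum < 0: the "negate" branch runs

    torch.testing.assert_close(exported(positive), positive * 2)
    torch.testing.assert_close(exported(negative), -negative)

If both assertions pass, the exported graph selects its branch from the input data instead of freezing the branch seen during export, which is the behavior the JIT-traced graph shown earlier lacks.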
Further reading
---------------

The list below refers to tutorials that range from basic examples to advanced scenarios,
not necessarily in the order they are listed.
Feel free to jump directly to the topics that interest you, or
sit tight and have fun going through all of them to learn all there is about the ONNX exporter.

.. include:: /beginner_source/onnx/onnx_toc.txt

.. toctree::
   :hidden:

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 2.749 seconds)