
Introduction to ONNX || Exporting a PyTorch Model to ONNX || Extending the ONNX exporter operator support || Export a model with control flow to ONNX

Export a model with control flow to ONNX

Author: Xavier Dupré

Overview

This tutorial demonstrates how to handle control flow logic while exporting a PyTorch model to ONNX. It highlights the challenges of exporting conditional statements directly and provides solutions to circumvent them.

Such control flow cannot be exported to ONNX unless it is refactored to use torch.cond(), previewed in the sketch below. Let's start with a simple model implementing a test.
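torch.cond(pred, true_fn, false_fn, operands) evaluates true_fn(*operands) when the boolean predicate pred is true and false_fn(*operands) otherwise; both branches must accept the same operands and return outputs with matching structure. A minimal standalone sketch, assuming torch >= 2.6 as listed in the prerequisites (the helper names double and negate are illustrative, not part of the tutorial's model):

import torch

def double(x):
    return x * 2

def negate(x):
    return -x

x = torch.randn(3)
# Unlike a plain Python if statement, torch.cond selects a branch from a
# boolean scalar tensor predicate and can be traced by torch.export.
y = torch.cond(x.sum() > 0, double, negate, (x,))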

What you will learn:

  • How to refactor the model to use torch.cond() for export.

  • How to export a model with control flow logic to ONNX.

  • How to optimize the exported model using the ONNX optimizer.

Prerequisites

  • torch >= 2.6

import torch

Defining the Models

Two models are defined:

ForwardWithControlFlowTest: a model whose forward method contains an if-else conditional.

ModelWithControlFlowTest: a model that embeds ForwardWithControlFlowTest inside a simple MLP. The models are tested with a random input tensor to confirm they execute as expected.

class ForwardWithControlFlowTest(torch.nn.Module):
    def forward(self, x):
        if x.sum():
            return x * 2
        return -x


class ModelWithControlFlowTest(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(3, 2),
            torch.nn.Linear(2, 1),
            ForwardWithControlFlowTest(),
        )

    def forward(self, x):
        out = self.mlp(x)
        return out


model = ModelWithControlFlowTest()

Exporting the Model: First Attempt

Exporting this model with torch.export.export fails because the control flow logic in the forward pass creates a graph break that the exporter cannot handle. This behavior is expected: conditional logic that is not written with torch.cond() is unsupported.

A try-except block captures the expected failure during export; if the export unexpectedly succeeds, an AssertionError is raised.

x = torch.randn(3)
model(x)

try:
    torch.export.export(model, (x,), strict=False)
    raise AssertionError("This export should fail unless PyTorch now supports this model.")
except Exception as e:
    print(e)
def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
     # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

Could not guard on data-dependent expression Eq(u0, 1) (unhinted: Eq(u0, 1)).  (Size-like symbols: none)

Caused by: (_export/non_strict_utils.py:683 in __torch_function__)
For more information, run with TORCH_LOGS="dynamic"
For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0"
If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1
For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing

For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1

The following call raised this error:
  File "/data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py", line 56, in forward
    if x.sum():

Using torch.onnx.export() with JIT Tracing

When the model is exported with torch.onnx.export() and the dynamo=True argument, the exporter falls back to JIT tracing when torch.export fails. This fallback allows the model to be exported, but the resulting ONNX graph may not faithfully represent the original model logic, owing to the limitations of tracing.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...




def forward(self, arg0_1: "f32[2, 3]", arg1_1: "f32[2]", arg2_1: "f32[1, 2]", arg3_1: "f32[1]", arg4_1: "f32[3]"):
     # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
    linear: "f32[2]" = torch.ops.aten.linear.default(arg4_1, arg0_1, arg1_1);  arg4_1 = arg0_1 = arg1_1 = None
    linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, arg2_1, arg3_1);  linear = arg2_1 = arg3_1 = None

     # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
    sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1);  linear_1 = None
    ne: "b8[]" = torch.ops.aten.ne.Scalar(sum_1, 0);  sum_1 = None
    item: "Sym(Eq(u0, 1))" = torch.ops.aten.item.default(ne);  ne = item = None

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`...

class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
        l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_);  l_x_ = None
        l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0);  l__self___mlp_0 = None

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[][]cpu" = l__self___mlp_1.sum();  l__self___mlp_1 = sum_1 = None


class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
        l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_);  l_x_ = None
        l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0);  l__self___mlp_0 = None

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[][]cpu" = l__self___mlp_1.sum();  l__self___mlp_1 = sum_1 = None


class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[3][1]cpu"):
        l_x_ = L_x_

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:71 in forward, code: out = self.mlp(x)
        l__self___mlp_0: "f32[2][1]cpu" = self.L__self___mlp_0(l_x_);  l_x_ = None
        l__self___mlp_1: "f32[1][1]cpu" = self.L__self___mlp_1(l__self___mlp_0);  l__self___mlp_0 = None

         # File: /data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56 in forward, code: if x.sum():
        sum_1: "f32[][]cpu" = l__self___mlp_1.sum();  l__self___mlp_1 = sum_1 = None

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export`... ❌
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script...
/data1/lin/pytorch-tutorials/beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py:56: TracerWarning:

Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with Torch Script... ✅
[torch.onnx] Run decomposition...
/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/export/_unlift.py:81: UserWarning:

Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer

/data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/fx/graph.py:1772: UserWarning:

Node lifted_tensor_6 target lifted_tensor_6 lifted_tensor_6 of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target

[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.7.0+cu126',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"input_1"<FLOAT,[3]>
    ),
    outputs=(
        %"mul"<FLOAT,[1]>
    ),
    initializers=(
        %"model.mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='model.mlp.0.bias')},
        %"model.mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3741], requires_grad=True), name='model.mlp.1.bias')}
    ),
) {
    0 |  # node_Constant_8
         %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')}
    1 |  # node_MatMul_1
         %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"input_1", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]})
    2 |  # node_Add_2
         %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"model.mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]})
    3 |  # node_Constant_9
         %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')}
    4 |  # node_MatMul_4
         %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]})
    5 |  # node_Add_5
         %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"model.mlp.1.bias"{[-0.37414515018463135]})
    6 |  # node_Constant_10
         %"convert_element_type_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='convert_element_type_default')} ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='convert_element_type_default')}
    7 |  # node_Mul_7
         %"mul"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"convert_element_type_default"{2.0})
    return %"mul"<FLOAT,[1]>
}

Suggested Patch: Refactoring with torch.cond()

To make the control flow exportable, the forward method of ForwardWithControlFlowTest is replaced with a refactored version that uses torch.cond().

Details of the refactoring:

  • Two helper functions (identity2 and neg) represent the branches of the conditional logic.

  • torch.cond() specifies the condition, the two branches, and the input arguments.

  • The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.

def new_forward(x):
    def identity2(x):
        return x * 2

    def neg(x):
        return -x

    return torch.cond(x.sum() > 0, identity2, neg, (x,))


print("the list of submodules")
for name, mod in model.named_modules():
    print(name, type(mod))
    if isinstance(mod, ForwardWithControlFlowTest):
        mod.forward = new_forward
the list of submodules
 <class '__main__.ModelWithControlFlowTest'>
mlp <class 'torch.nn.modules.container.Sequential'>
mlp.0 <class 'torch.nn.modules.linear.Linear'>
mlp.1 <class 'torch.nn.modules.linear.Linear'>
mlp.2 <class '__main__.ForwardWithControlFlowTest'>

Let's see what the FX graph looks like.

print(torch.export.export(model, (x,), strict=False))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_mlp_0_weight: "f32[2, 3]", p_mlp_0_bias: "f32[2]", p_mlp_1_weight: "f32[1, 2]", p_mlp_1_bias: "f32[1]", x: "f32[3]"):
             # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/linear.py:125 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[2]" = torch.ops.aten.linear.default(x, p_mlp_0_weight, p_mlp_0_bias);  x = p_mlp_0_weight = p_mlp_0_bias = None
            linear_1: "f32[1]" = torch.ops.aten.linear.default(linear, p_mlp_1_weight, p_mlp_1_bias);  linear = p_mlp_1_weight = p_mlp_1_bias = None

             # File: /data1/lin/pytorch-tutorials/.venv/lib/python3.10/site-packages/torch/nn/modules/container.py:240 in forward, code: input = module(input)
            sum_1: "f32[]" = torch.ops.aten.sum.default(linear_1)
            gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None

             # File: <eval_with_key>.30:9 in forward, code: cond = torch.ops.higher_order.cond(l_args_0_, cond_true_0, cond_false_0, [l_args_3_0_]);  l_args_0_ = cond_true_0 = cond_false_0 = l_args_3_0_ = None
            true_graph_0 = self.true_graph_0
            false_graph_0 = self.false_graph_0
            cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [linear_1]);  gt = true_graph_0 = false_graph_0 = linear_1 = None
            getitem: "f32[1]" = cond[0];  cond = None
            return (getitem,)

        class true_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.25:6 in forward, code: mul = l_args_3_0__1.mul(2);  l_args_3_0__1 = None
                mul: "f32[1]" = torch.ops.aten.mul.Tensor(linear_1, 2);  linear_1 = None
                return (mul,)

        class false_graph_0(torch.nn.Module):
            def forward(self, linear_1: "f32[1]"):
                 # File: <eval_with_key>.26:6 in forward, code: neg = l_args_3_0__1.neg();  l_args_3_0__1 = None
                neg: "f32[1]" = torch.ops.aten.neg.default(linear_1);  linear_1 = None
                return (neg,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_weight'), target='mlp.0.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_0_bias'), target='mlp.0.bias', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_weight'), target='mlp.1.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_mlp_1_bias'), target='mlp.1.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
Range constraints: {}

Let's try exporting again.

onnx_program = torch.onnx.export(model, (x,), dynamo=True)
print(onnx_program.model)
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`...
[torch.onnx] Obtain model graph for `ModelWithControlFlowTest([...]` with `torch.export.export(..., strict=False)`... ✅
[torch.onnx] Run decomposition...
[torch.onnx] Run decomposition... ✅
[torch.onnx] Translate the graph into ONNX...
[torch.onnx] Translate the graph into ONNX... ✅
<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.7.0+cu126',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3741], requires_grad=True), name='mlp.1.bias')}
    ),
) {
     0 |  # node_Constant_11
          %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')}
     1 |  # node_MatMul_1
          %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]})
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]})
     3 |  # node_Constant_12
          %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')}
     4 |  # node_MatMul_4
          %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]})
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.37414515018463135]})
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
     7 |  # node_Constant_13
          %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')} ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
     8 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
     9 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Constant_1
                       %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')} ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
                  1 |  # node_Mul_2
                       %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                  return %"mul_true_graph_0"<FLOAT,[1]>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Neg_0
                       %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                  return %"neg_false_graph_0"<FLOAT,[1]>
              }}
    return %"getitem"<FLOAT,[1]>
}

We can optimize the model and get rid of the model-local functions created to capture the control flow branches.
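The optimized graph shown below is obtained by optimizing the exported program in place and printing it again; a minimal sketch of that step, using the ONNXProgram.optimize() API available in recent torch releases:

onnx_program.optimize()
print(onnx_program.model)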

<
    ir_version=10,
    opset_imports={'': 18},
    producer_name='pytorch',
    producer_version='2.7.0+cu126',
    domain=None,
    model_version=None,
>
graph(
    name=main_graph,
    inputs=(
        %"x"<FLOAT,[3]>
    ),
    outputs=(
        %"getitem"<FLOAT,[1]>
    ),
    initializers=(
        %"mlp.0.bias"<FLOAT,[2]>{TorchTensor<FLOAT,[2]>(Parameter containing: tensor([0.3437, 0.5336], requires_grad=True), name='mlp.0.bias')},
        %"mlp.1.bias"<FLOAT,[1]>{TorchTensor<FLOAT,[1]>(Parameter containing: tensor([-0.3741], requires_grad=True), name='mlp.1.bias')}
    ),
) {
     0 |  # node_Constant_11
          %"val_0"<FLOAT,[3,2]>{Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')} ⬅️ ::Constant() {value=Tensor<FLOAT,[3,2]>(array([[-0.4352824 ,  0.31417835], [-0.15039097, -0.30015165], [-0.3928429 ,  0.11529753]], dtype=float32), name='val_0')}
     1 |  # node_MatMul_1
          %"val_1"<FLOAT,[2]> ⬅️ ::MatMul(%"x", %"val_0"{[[-0.4352824091911316, 0.31417834758758545], [-0.15039096772670746, -0.30015164613723755], [-0.3928428888320923, 0.11529753357172012]]})
     2 |  # node_Add_2
          %"linear"<FLOAT,[2]> ⬅️ ::Add(%"val_1", %"mlp.0.bias"{[0.3437083661556244, 0.5336393117904663]})
     3 |  # node_Constant_12
          %"val_2"<FLOAT,[2,1]>{Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')} ⬅️ ::Constant() {value=Tensor<FLOAT,[2,1]>(array([[-0.65905833], [-0.66939986]], dtype=float32), name='val_2')}
     4 |  # node_MatMul_4
          %"val_3"<FLOAT,[1]> ⬅️ ::MatMul(%"linear", %"val_2"{[[-0.6590583324432373], [-0.6693998575210571]]})
     5 |  # node_Add_5
          %"linear_1"<FLOAT,[1]> ⬅️ ::Add(%"val_3", %"mlp.1.bias"{[-0.37414515018463135]})
     6 |  # node_ReduceSum_6
          %"sum_1"<FLOAT,[]> ⬅️ ::ReduceSum(%"linear_1") {noop_with_empty_axes=0, keepdims=False}
     7 |  # node_Constant_13
          %"scalar_tensor_default"<FLOAT,[]>{Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')} ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(0., dtype=float32), name='scalar_tensor_default')}
     8 |  # node_Greater_9
          %"gt"<BOOL,[]> ⬅️ ::Greater(%"sum_1", %"scalar_tensor_default"{0.0})
     9 |  # node_If_10
          %"getitem"<FLOAT,[1]> ⬅️ ::If(%"gt") {then_branch=
              graph(
                  name=true_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"mul_true_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Constant_1
                       %"scalar_tensor_default_2"<FLOAT,[]>{Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')} ⬅️ ::Constant() {value=Tensor<FLOAT,[]>(array(2., dtype=float32), name='scalar_tensor_default_2')}
                  1 |  # node_Mul_2
                       %"mul_true_graph_0"<FLOAT,[1]> ⬅️ ::Mul(%"linear_1", %"scalar_tensor_default_2"{2.0})
                  return %"mul_true_graph_0"<FLOAT,[1]>
              }, else_branch=
              graph(
                  name=false_graph_0,
                  inputs=(

                  ),
                  outputs=(
                      %"neg_false_graph_0"<FLOAT,[1]>
                  ),
              ) {
                  0 |  # node_Neg_0
                       %"neg_false_graph_0"<FLOAT,[1]> ⬅️ ::Neg(%"linear_1")
                  return %"neg_false_graph_0"<FLOAT,[1]>
              }}
    return %"getitem"<FLOAT,[1]>
}
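To sanity-check the exported model, we can run it with ONNX Runtime and compare the result against PyTorch. The sketch below is illustrative rather than part of the original tutorial: it assumes onnxruntime is installed, uses a hypothetical file name, and feeds the input name x declared in the graph above.

import onnxruntime

# Serialize the exported program to disk (hypothetical file name).
onnx_program.save("model_with_control_flow.onnx")

session = onnxruntime.InferenceSession(
    "model_with_control_flow.onnx", providers=["CPUExecutionProvider"]
)

# The exported graph declares a single input named "x".
(onnx_out,) = session.run(None, {"x": x.numpy()})
print("torch:", model(x))
print("onnx :", onnx_out)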

Conclusion

This tutorial demonstrated the challenges of exporting models with conditional logic to ONNX and presented a practical solution using torch.cond(). While the default exporters may fail or produce imperfect graphs, refactoring the model's logic ensures compatibility and generates a faithful ONNX representation.

By understanding these techniques, we can overcome common pitfalls when working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.

Further reading

The list below covers tutorials ranging from basic examples to advanced scenarios, not necessarily in the order listed. Feel free to jump directly to specific topics of interest, or work through all of them to learn everything about the ONNX exporter.

Total running time of the script: (0 minutes 2.749 seconds)
