Shortcuts

torch.export教程

Created On: Oct 02, 2023 | Last Updated: Jan 27, 2025 | Last Verified: Nov 05, 2024

作者: William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

警告

``torch.export``及其相关功能处于原型状态,可能会发生向后兼容性破坏的变化。本教程提供了PyTorch 2.5中``torch.export``使用的快照。

torch.export() 是PyTorch 2.X提出的一种将PyTorch模型导出为标准模型表示的方法,目的是在不同的(即无需Python的)环境中运行。官方文档可以在`这里 <https://pytorch.org/docs/main/export.html>`__找到。

在本教程中,您将学习如何使用:func:torch.export`从PyTorch程序中提取``ExportedProgram``(即单图表示)。我们还详细说明了一些可能需要进行的修改,以便让您的模型兼容``torch.export`

目录

基本用法

``torch.export``通过跟踪目标函数生成单图表示,给定示例输入。``torch.export.export()``是``torch.export``的主要入口。

在本教程中,``torch.export``和``torch.export.export()``在实际应用中几乎是同义的,不过``torch.export``通常指PyTorch 2.X的导出过程,而``torch.export.export()``通常指实际的函数调用。

``torch.export.export()``的函数签名为:

export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
) -> ExportedProgram

torch.export.export()``通过调用``mod(*args, **kwargs)``跟踪张量计算图,并将其封装在一个``ExportedProgram``中,该程序可以序列化或稍后使用不同的输入执行。要执行``ExportedProgram,我们可以调用``.module()``来返回一个可调用的``torch.nn.Module``,就像原始程序一样。我们将在教程中详细说明``dynamic_shapes``参数。

import torch
from torch.export import export

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(100, 10)

    def forward(self, x, y):
        return torch.nn.functional.relu(self.lin(x + y), inplace=True)

mod = MyModule()
exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
print(type(exported_mod))
print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

让我们回顾一下``ExportedProgram``的一些感兴趣的属性。

``graph``属性是从我们导出的函数中跟踪的一个`FX图 <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__,即所有PyTorch操作的计算图。FX图是“ATen IR”,意味着它仅包含“ATen级”操作。

``graph_signature``属性提供了导出图中输入和输出节点的更详细描述,描述了哪些是参数、缓冲区、用户输入或用户输出。

``range_constraints``属性将在后面介绍。

print(exported_mod)

有关更多详细信息,请参阅``torch.export`` 文档

图中断

尽管``torch.export``与``torch.compile``共享组件,但``torch.export``的关键限制,特别是与``torch.compile``相比,是它不支持图中断。这是因为处理图中断需要使用默认的Python评估解释不支持的操作,这与导出的使用场景不兼容。因此,为了使您的模型代码与``torch.export``兼容,您需要修改代码以消除图中断。

在以下情况下需要图中断:

  • 数据相关的控制流

class Bad1(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return torch.sin(x)
        return torch.cos(x)

import traceback as tb
try:
    export(Bad1(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 使用``.data``访问张量数据

class Bad2(torch.nn.Module):
    def forward(self, x):
        x.data[0, 0] = 3
        return x

try:
    export(Bad2(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()
  • 调用不支持的函数(如许多内置函数)

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

try:
    export(Bad3(), (torch.randn(3, 3),))
except Exception:
    tb.print_exc()

非严格导出

为了跟踪程序,``torch.export``默认使用TorchDynamo,一个字节码分析引擎,来符号化分析Python代码,并基于结果构建图。这种分析允许``torch.export``提供更强的安全性保证,但并非所有Python代码都受支持,导致这些图中断。

为了解决这个问题,在PyTorch 2.3版本中,我们引入了一种新的非严格模式导出方式,我们通过Python解释器以即时模式的方式准确执行程序来跟踪程序,从而跳过不支持的Python特性。可以通过添加``strict=False``标志来实现这一点。

回顾一些导致图中断的先例:

  • 调用不支持的函数(如许多内置函数)会跟踪

调用``id(x)``的情况,但是在这种情况下,``id(x)``在图中被专用化为一个常量整数。这是因为``id(x)``不是一个张量操作,因此操作未被记录在图中。

class Bad3(torch.nn.Module):
    def forward(self, x):
        x = x + 1
        return x + id(x)

bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
print(bad3_nonstrict)
print(bad3_nonstrict.module()(torch.ones(3, 3)))

但是,仍然存在一些需要重写原始模块的功能:

控制流操作

``torch.export``实际上支持数据相关的控制流。但这些需要使用控制流操作来表达。例如,我们可以使用``cond``操作修复上述控制流示例,如下所示:

class Bad1Fixed(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)
        def false_fn(x):
            return torch.cos(x)
        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
print(exported_bad1_fixed)
print(exported_bad1_fixed.module()(torch.ones(3, 3)))
print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

``cond``存在一些需要注意的限制:

  • 谓词(即``x.sum() > 0``)必须是布尔值或单个元素的张量。

  • 操作数(即``[x]``)必须是张量。

  • 分支函数(即``true_fn``和``false_fn``)的签名必须与操作数匹配,并且它们都必须返回一个具有相同元数据(例如``dtype``、``shape``等)的单个张量。

  • 分支函数不能更改输入或全局变量。

  • 分支函数不能访问闭包变量,除了函数在方法作用域内定义时的``self``。

有关``cond``的更多详细信息,请查看`cond文档 <https://pytorch.org/docs/main/cond.html>`__。

我们还可以使用``map``,它将一个函数应用于第一个张量参数的第一个维度。

from torch._higher_order_ops.map import map as torch_map

class MapModule(torch.nn.Module):
    def forward(self, xs, y, z):
        def body(x, y, z):
            return x + y + z

        return torch_map(body, xs, y, z)

inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
exported_map_example = export(MapModule(), inps)
print(exported_map_example)
print(exported_map_example.module()(*inps))

其他控制流操作包括``while_loop``、associative_scan``和``scan。有关每个操作的更多文档,请参阅`此页面 <https://github.com/pytorch/pytorch/tree/main/torch/_higher_order_ops>`__。

约束/动态形状

本部分介绍导出的程序的动态行为和表示。动态行为是针对特定模型的,因此在本教程的大部分内容中,我们将重点介绍这个特定的玩具模型(带有结果张量形状的注释):

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [6, 5]
        x: torch.Tensor,  # [4]
        y: torch.Tensor,  # [8, 4]
        z: torch.Tensor,  # [32]
    ):
        x0 = x + y  # [8, 4]
        x1 = self.l(w)  # [6, 3]
        x2 = x0.flatten()  # [32]
        x3 = x2 + z  # [32]
        return x1, x3

默认情况下,torch.export 会生成静态程序。其结果是,在运行时,即使输入形状在即时模式下是有效的,程序也无法处理不同形状的输入。

w = torch.randn(6, 5)
x = torch.randn(4)
y = torch.randn(8, 4)
z = torch.randn(32)
model = DynamicModel()
ep = export(model, (w, x, y, z))
model(w, x, torch.randn(3, 4), torch.randn(12))
try:
    ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
except Exception:
    tb.print_exc()

基本概念:符号和守护

为了支持动态性,export() 提供了一个 dynamic_shapes 参数。处理动态形状的最简单方法是使用 Dim.AUTO 并查看返回的程序。动态行为是针对每个输入维度指定的;对于每个输入,我们可以指定一个值元组:

from torch.export.dynamic_shapes import Dim

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

在我们查看生成的程序之前,让我们了解指定 dynamic_shapes 的意义,以及它与导出的交互方式。对于每个指定了 Dim 对象的输入维度,将分配一个符号 allocated,范围为 [2, inf]``(为什么不是 ``[0, inf][1, inf]?我们将在 0/1 专门化部分 later 解释)。

导出然后运行模型追踪,查看模型执行的每个操作。每个单独的操作都可以发出所谓的“守护”;基本上是程序有效所需的布尔条件。当守护涉及为输入维度分配的符号时,程序包含关于有效输入形状的限制;即程序的动态行为。符号形状子系统负责接受所有发出的守护并生成符合这些守护的最终程序表示。在我们看到 ExportedProgram 中的这种“最终表示”之前,让我们看看我们正在追踪的玩具模型发出的守护。

在这里,每个前向输入张量都被注释为在追踪开始时分配的符号:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(
        self,
        w: torch.Tensor,  # [s0, s1]
        x: torch.Tensor,  # [s2]
        y: torch.Tensor,  # [s3, s4]
        z: torch.Tensor,  # [s5]
    ):
        x0 = x + y  # guard: s2 == s4
        x1 = self.l(w)  # guard: s1 == 5
        x2 = x0.flatten()  # no guard added here
        x3 = x2 + z  # guard: s3 * s4 == s5
        return x1, x3

让我们了解每个操作和发出的守护:

  • x0 = x + y: 这是带广播的逐元素相加,因为 x 是一维张量而 y 是二维张量。xy 的最后一个维度上进行广播,发出守护 s2 == s4

  • x1 = self.l(w): 调用 nn.Linear() 进行矩阵乘法,并使用模型参数。在导出时,参数、缓冲区和常量被视为程序状态,这被视为静态,因此这是动态输入(w: [s0, s1])和静态的张量之间的矩阵乘法。发出守护 s1 == 5

  • x2 = x0.flatten(): 实际上,这个调用未发出任何守护!(至少没有与输入形状相关的)

  • x3 = x2 + z: x2 在展平后具有形状 [s3*s4],此逐元素相加发出 s3 * s4 == s5

将所有这些守护写下来并总结几乎就像数学证明,而符号形状子系统试图实现这一点!总结一下,我们可以得出以下输入形状是程序有效时的必须:

  • w: [s0, 5]

  • x: [s2]

  • y: [s3, s2]

  • z: [s2*s3]

当我们最后打印导出的程序以查看结果时,这些形状为对应的输入标注:

print(ep)

另一项需要注意的功能是上面的 range_constraints 字段,其中包含每个符号的有效范围。目前这不是很有趣,因为此导出调用未发出任何与符号界限相关的守护,每个基础符号都有通用界限,但稍后会涉及到。

到目前为止,由于我们一直在导出这个玩具模型,这种体验并不代表调试动态形状守护和问题通常有多么困难。在大多数情况下,不明显哪些守护被发出,也不清楚是哪些操作和用户代码部分负责。对于这个玩具模型,我们明确指出了确切的代码行,守护相当直观。

在更复杂的情况下,第一步是始终启用详细日志记录。这可通过环境变量 TORCH_LOGS="+dynamic" 或交互方式用 torch._logging.set_logs(dynamic=10) 来完成:

torch._logging.set_logs(dynamic=10)
ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

即使是这个简单的玩具模型也会输出大量内容。这里的日志行前后已被截断以忽略不必要的信息,但通过查看日志,我们可以看到与上面描述相关的行;例如符号的分配:

"""
create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
runtime_assert True == True [statically known]
create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in <lambda>)
"""

create_symbol 行显示了何时分配了新符号,日志还标识了分配给它们的张量变量名称和维度。在其他行中,我们还可以看到发出的守护:

"""
runtime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"
runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"
runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"
"""

[guard added] 消息旁边,我们还可以看到负责的用户代码行 - 幸运的是,这里模型足够简单。在许多实际案例中,情况并不那么直接:高级 torch 操作可能有复杂的伪内核实现或操作分解,这使得守护的发出位置和内容更加复杂。在这种情况下,深入调研和调查的最佳方式是遵循日志的建议,并重新运行环境变量 TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="...",进一步归因感兴趣的守护。

Dim.AUTO 只是与 dynamic_shapes 交互的可选项之一;当前有另外两种选项:Dim.DYNAMICDim.STATICDim.STATIC 简单地标记维度为静态,而 Dim.DYNAMICDim.AUTO 在所有方面相似,唯一区别是当专门化为常量时会引发错误;这旨在保持动态性。看看当在动态标记的维度上发出静态守护时会发生什么:

dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()

静态守护并不总是模型固有的;它们也可能来自用户规范。实际上,一个常见的导致形状专门化的陷阱是当用户为等效维度指定了冲突的标记;一个是动态,另一个是静态。当对于 x.shape[0]y.shape[1] 出现这种情况时会引发相同的错误类型:

dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO)
dynamic_shapes["x"] = (Dim.STATIC,)
dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC)
try:
    export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()

在这里,您可能会问为什么导出会“专门化”,即为什么我们通过静态路径解决这种静态/动态冲突。答案是由于上面描述的符号形状系统,即符号和守护。当 x.shape[0] 被标记为静态时,我们不分配符号,并将此形状编译为具体的整数 4。为 y.shape[1] 分配了符号,因此我们最终发出守护 s3 == 4,导致专门化。

导出的一个功能是在追踪期间,诸如 assert、torch._check()if/else 条件等语句也会发出守护。看看我们如何在现有模型中加入这些语句:

class DynamicModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l = torch.nn.Linear(5, 3)

    def forward(self, w, x, y, z):
        assert w.shape[0] <= 512
        torch._check(x.shape[0] >= 4)
        if w.shape[0] == x.shape[0] + 2:
            x0 = x + y
            x1 = self.l(w)
            x2 = x0.flatten()
            x3 = x2 + z
            return x1, x3
        else:
            return w

dynamic_shapes = {
    "w": (Dim.AUTO, Dim.AUTO),
    "x": (Dim.AUTO,),
    "y": (Dim.AUTO, Dim.AUTO),
    "z": (Dim.AUTO,),
}
try:
    ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
except Exception:
    tb.print_exc()

其中每条语句都发出了一个额外的守护,并且导出的程序显示了更改;s0 被替换为 s2 + 2,并且 s2 现在包含上限和下限,反映在 range_constraints 中。

对于 if/else 条件,您可能会问为什么选择了 True 分支,以及为什么没有发出 w.shape[0] != x.shape[0] + 2 的守护。答案是导出是由追踪提供的样本输入引导的,并专门化选中的分支。如果提供的不同样本输入形状未通过 if 条件,导出将追踪并发出对应于 else 分支的守护。此外,您可能会问为什么我们仅追踪了 if 分支,以及是否可以在程序中保持控制流并同时保留两个分支。为此,请参阅在 “控制流操作” 部分中重写模型代码。

0/1 专门化

既然我们谈到了守护和专门化,是时候讨论我们之前提到的 0/1 专门化问题了。关键问题是导出会专门化值为 0 或 1 的样本输入维度,因为这些形状在追踪时具有不适用于其他形状的特性。例如,大小为 1 的张量可以广播,而其他大小则会失败;大小为 0 ……这意味着您应该在希望程序硬编码这些形状时指定 0/1 样本输入,而在希望动态行为时指定非 0/1 样本输入。看看我们导出这个线性层时在运行时会发生什么:

ep = export(
    torch.nn.Linear(4, 3),
    (torch.randn(1, 4),),
    dynamic_shapes={
        "input": (Dim.AUTO, Dim.STATIC),
    },
)
try:
    ep.module()(torch.randn(2, 4))
except Exception:
    tb.print_exc()

命名维度

到目前为止,我们只讨论了 3 种指定动态形状的方法:Dim.AUTODim.DYNAMICDim.STATIC。它们的吸引力在于用户体验的低摩擦性;模型追踪期间发出的所有守护均被遵守,以及导出会自动解决动态行为(例如最小/最大范围、关系以及静态/动态维度)。动态形状子系统本质上充当“发现”过程,总结这些守护并展示导出认为程序的总体动态行为。此设计的缺点出现在用户对这些模型的动态行为有更强的期望或看法时 - 可能用户对某些维度的动态性特别希望,因此绝不能专门化,或者只是希望通过检测对原模型代码进行更改来捕获动态行为的变化,或者可能是底层的分解或元内核。这些变化不会被检测到,并且 export() 调用很可能成功,除非有测试检查生成的 ExportedProgram 表示。

对于这些情况,我们建议采用指定动态形状的“传统”方法,长期使用导出的用户可能熟悉这种方式:命名的 Dims

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

这种动态形状的风格允许用户指定为输入维度分配哪些符号、这些符号的最小/最大边界,并对生成的 ExportedProgram 的动态行为施加限制。如果模型追踪发出的守护与给定的关系或静态/动态规格冲突,则会引发 ConstraintViolation 错误。例如,在上述规范中,以下被断言:

  • x.shape[0] 的范围为 [4, 256],并与 y.shape[0] 相关:y.shape[0] == 2 * x.shape[0]

  • x.shape[1] 是静态的。

  • y.shape[1] 的范围为 [2, 512],并且与任何其他维度无关。

在这个设计中,我们允许通过单变量线性表达式指定维度之间的关系:A * dim + B 可以为任意维度指定。这使用户能够为动态维度指定更复杂的约束,例如整数可整除性:

dx = Dim("dx", min=4, max=512)
dynamic_shapes = {
    "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
}

约束违规,建议修复

在使用这种规格定义样式时(Dim.AUTO 被引入之前),一个常见的问题是规格定义常常与模型跟踪生成的不匹配。这会导致 ConstraintViolation 错误并导出建议修复。例如,在以下模型和规格中,模型本质上要求 xy 的维度 0 相等,并且维度 1 必须是静态的。

class Foo(torch.nn.Module):
    def forward(self, x, y):
        w = x + y
        return w + torch.ones(4)

dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
try:
    ep = export(
        Foo(),
        (torch.randn(6, 4), torch.randn(6, 4)),
        dynamic_shapes={
            "x": (dx, d1),
            "y": (dy, d1),
        },
    )
except Exception:
    tb.print_exc()

使用建议的修复方案的期望是,用户可以交互地将修复的改变复制粘贴到动态形状规范中,并随后成功导出。

最后,这里有一些关于规格选项的有用信息:

  • None 是一种静态行为的良好选项: - dynamic_shapes=None`(默认)会将整个模型导出为静态。 - 在输入级别指定 `None 会将所有张量维度导出为静态,并且对于非张量输入也是必需的。 - 在维度级别指定 None 会专门化该维度,但这已被弃用,建议使用 Dim.STATIC

  • 为每个维度指定整数值也会产生静态行为,并会额外检查提供的样本输入是否与规格匹配。

这些选项与输入和动态形状规范以下列方式结合:

inputs = (
    torch.randn(4, 4),
    torch.randn(3, 3),
    16,
    False,
)
dynamic_shapes = {
    "tensor_0": (Dim.AUTO, None),
    "tensor_1": None,
    "int_val": None,
    "bool_val": None,
}

数据依赖的错误

在尝试导出模型时,您可能遇到过类似“无法对数据依赖表达式进行保护”或“无法从数据依赖表达式中提取专门化的整数”这样的错误。出现这些错误的原因是 torch.export() 使用 FakeTensors 来编译程序,FakeTensors 符号化地表示其真实的张量对应物。尽管它们具有等效的符号属性(例如大小、步幅、数据类型),但它们在不包含任何数据值这一点上不相同。虽然这避免了不必要的内存使用和昂贵的计算,但也意味着导出可能无法直接编译依赖于数据值的用户代码部分。简单来说,如果编译器需要一个具体的、数据依赖的值才能继续,它会报错并抱怨该值不可用。

数据依赖值出现在许多地方,常见的来源是诸如 item()tolist()torch.unbind() 之类的调用,这些调用从张量中提取标量值。这些值在导出的程序中是如何表示的?在 约束/动态形状 部分中,我们讨论了为动态输入维度分配符号。同样地,在这里我们为程序中出现的每个数据依赖值分配符号。重要的区别在于,这些是“无支持”的符号,与为输入维度分配的“有支持”符号相对。“有支持/无支持” 法术指的是是否存在“提示”:一个具体的值支持符号,这可以指导编译器如何进行。

在输入形状符号情况下(有支持符号),这些提示就是提供的样本输入形状,这解释了为什么控制流分支是由样本输入属性确定的。对于数据依赖值,符号是在跟踪过程中从 FakeTensor “数据”中获取的,因此编译器不知道这些符号将取何值(提示)。

让我们看看这些在导出的程序中是如何表现的:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.tolist()
        return b + [a]

inps = (
    torch.tensor(1),
    torch.tensor([2, 3]),
)
ep = export(Foo(), inps)
print(ep)

结果是分配并返回了三个无支持符号(注意它们用“u”作为前缀,而不是输入形状/有支持符号的通常“s”前缀):一个用于 item() 调用,两个用于 tolist() 方法中 y 的每个元素。从范围约束字段可以看出,这些符号的范围是 [-int_oo, int_oo],而不是分配给输入形状符号的默认范围 [0, int_oo],因为我们无法获得关于这些值的信息——它们不表示大小,因此不一定是正值。

保护机制,torch._check()

但上述情况容易导出,因为这些符号的具体值并未在任何编译器决策中使用;所有相关的只是返回值是无支持符号。本节讨论的数据依赖错误是以下情况之类的,当遇到 数据依赖保护 时:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

这里我们实际上需要“提示”或 a 的具体值,以便编译器决定是否跟踪并作为输出返回 y + 2 或者 y * 5。因为我们使用 FakeTensors 进行跟踪,所以我们不知道 a // 2 >= 5 实际计算结果如何,因此导出失败并给出错误信息“无法对数据依赖表达式 u0 // 2 >= 5 (未提示) 进行保护”。

那么我们该如何导出这个简单模型呢?与 torch.compile() 不同,导出需要完整的图编译,不能只是进行图分段。以下是一些基本选项:

  1. 手动专门化:通过选择分支代码进行跟踪,可以介入解决问题,要么移除控制流代码仅保留特定分支,要么使用 torch.compiler.is_compiling() 在编译时保护被跟踪的代码。

  2. torch.cond():可以用 torch.cond() 重写控制流代码,从而避免专门化为某个分支。

尽管这些选项有效,但它们有其局限性。选项 1 有时需要对模型代码进行大规模侵入性重写以实现专门化,而 torch.cond() 并不是处理数据依赖错误的综合系统。正如所见,也有不涉及控制流的数据依赖错误。

一般推荐的方法是从 torch._check() 调用开始。虽然这些调用的表面看起来只是断言语句,但实际上它们是一个告知编译器有关符号属性系统。在运行时,torch._check() 调用确实会作为一个断言执行,而在编译时被追踪时,检查表达式会被发送到符号形状子系统进行推理,并且任何从表达式为真而推导出来的符号属性都将被存储为属性(如果系统足够智能以推断这些属性)。因此,即使无支持符号没有提示,如果我们能通过 torch._check() 调用传递对这些符号普遍为真的属性,也可能不需要重写模型代码就能绕过数据依赖保护。

例如,在上述模型中,插入 torch._check(a >= 10) 将告知编译器 y + 2 始终可以返回,而 torch._check(a == 4) 告诉它返回 y * 5。看看当我们重新导出这个模型时会发生什么。

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 10)
        torch._check(a <= 60)
        if a // 2 >= 5:
            return y + 2
        else:
            return y * 5

inps = (
    torch.tensor(32),
    torch.randn(4),
)
ep = export(Foo(), inps)
print(ep)

导出成功,请注意在范围约束字段中 u0 的范围是 [10, 60]

那么 torch._check() 调用实际上传达了什么信息呢?随着符号形状子系统变得更智能,这些信息会有所不同,但从根本上讲,这些通常包括:

  1. 等同于非数据依赖表达式:传达等式的 torch._check() 调用,例如 u0 == s0 + 4u0 == 5

  2. 范围细化:提供符号上界或下界的调用,例如上面提到的。

  3. 对更复杂表达式的一些基本推理:插入 torch._check(a < 4) 通常会告诉编译器 a >= 4 为假。对像 torch._check(a ** 2 - 3 * a <= 10) 这样复杂表达式的检查通常能绕过相同的保护。

如前所述,torch._check() 调用在数据依赖控制流之外也有适用性。例如,这里有一个模型,其中 torch._check() 的插入起作用,而手动专门化和 torch.cond() 不起作用:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps)
except Exception:
    tb.print_exc()

这是一个需要插入 torch._check() 来防止操作失败的场景。导出调用会失败,并出现错误信息“无法对数据依赖表达式 -u0 > 60 进行保护”,这表明编译器不知道这是否是一个有效的索引操作——x 的值是否超出 y 的边界。在这里,手动专门化过于繁琐,torch.cond() 没有用武之地。相反,告知编译器 u0 的范围就足够了:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        torch._check(a >= 0)
        torch._check(a < y.shape[0])
        return y[a]

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps)
print(ep)

专门化的值

当程序尝试在跟踪过程中提取具体数据依赖的整数或浮点值时,会发生另一类数据依赖错误。这种错误看起来像“无法从数据依赖表达式中提取专门化的整数”,类似于前一种错误类别——如果这些错误出现在尝试求具体整数/浮点值时,当尝试求具体布尔值会产生数据依赖保护错误。

此错误通常发生在显式或隐式 int() 转换对数据依赖表达式进行操作时。例如,该列表解析中有一个 range() 调用,其隐式地对列表大小执行了一个 int() 转换:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = torch.cat([y for y in range(a)], dim=0)
        return b + int(a)

inps = (
    torch.tensor(32),
    torch.randn(60),
)
try:
    export(Foo(), inps, strict=False)
except Exception:
    tb.print_exc()

对于这些错误,您可以采取的一些基本选项包括:

  1. 避免不必要的 int() 转换调用,例如返回语句中的 int(a)

  2. 使用 torch._check() 调用;不幸的是,在此情况下您可能只能专门化(使用 torch._check(a == 60))。

  3. 在更高层次上重写有问题的代码。例如,列表解析语义上是一个 repeat() 操作,它不涉及 int() 转换。以下重写避免了数据依赖错误:

class Foo(torch.nn.Module):
    def forward(self, x, y):
        a = x.item()
        b = y.unsqueeze(0).repeat(a, 1)
        return b + a

inps = (
    torch.tensor(32),
    torch.randn(60),
)
ep = export(Foo(), inps, strict=False)
print(ep)

数据依赖错误可能会更复杂,您的工具集有更多选项来处理它们:如 torch._check_is_size()guard_size_oblivious() 或真实张量跟踪等入门工具。有关更深入的指南,请参考 导出编程模型处理 GuardOnDataDependentSymNode 错误

自定义操作

torch.export 可以导出带有自定义操作符的 PyTorch 程序。有关如何使用 C++ 或 Python 编写自定义操作符,请参考 此页面

以下是一个在 Python 中注册自定义操作符以供 torch.export 使用的示例。重要的是要注意,自定义操作符必须具有 FakeTensor 内核

@torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
def custom_op(x: torch.Tensor) -> torch.Tensor:
    print("custom_op called!")
    return torch.relu(x)

@custom_op.register_fake
def custom_op_meta(x):
    # Returns an empty tensor with the same shape as the expected output
    return torch.empty_like(x)

这是使用自定义操作符导出程序的示例。

class CustomOpExample(torch.nn.Module):
    def forward(self, x):
        x = torch.sin(x)
        x = torch.ops.my_custom_library.custom_op(x)
        x = torch.cos(x)
        return x

exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
print(exported_custom_op_example)
print(exported_custom_op_example.module()(torch.randn(3, 3)))

注意,在``ExportedProgram``中,自定义操作符包含在图中。

IR/分解

``torch.export``生成的图仅包含`ATen操作符 <https://pytorch.org/cppdocs/#aten>`__,这是 PyTorch 中的基本计算单位。由于 ATen 操作符超过 3000 个,导出提供了一种基于某些特性缩小图中使用操作符集合的方法,从而创建不同的 IR。

默认情况下,导出生成最通用的 IR,其中包含所有 ATen 操作符,包括功能性和非功能性操作符。功能性操作符是不包含任何输入突变或别名的操作符。您可以在 此处 找到所有 ATen 操作符的列表,并可以通过检查 op._schema.is_mutable 来确定操作符是否是功能性的,例如:

print(torch.ops.aten.add.Tensor._schema.is_mutable)
print(torch.ops.aten.add_.Tensor._schema.is_mutable)

此通用 IR 可用于在 PyTorch Autograd 中进行自适应训练。此 IR 可通过 API torch.export.export_for_training 更明确地获得,该 API 是在 PyTorch 2.5 中引入的,但调用 torch.export.export 应在 PyTorch 2.6 中生成相同的图。

class DecompExample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph)

我们可以通过 API run_decompositions 将已导出的程序降低到仅包含功能性 ATen 操作符的操作符集合中,该 API 会将 ATen 操作符分解为分解表中指定的操作符,并使图功能化。通过指定一个空集,我们仅执行功能化,不进行任何额外的分解。这将生成一个包含大约 2000 操作符(而不是上面提到的 3000 个操作符)的 IR,非常适合推理场景。

ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph)

正如我们所看到的,先前的可变操作符 torch.ops.aten.add_.default 现在已替换为 torch.ops.aten.add.default, 一个功能性操作符。

我们还可以将已导出的程序进一步降低到仅包含 Core ATen 操作符集 的操作符集合,该集合仅包含大约 180 个操作符。此 IR 对于不想重新实现所有 ATen 操作符的后端来说是最佳的。

from torch.export import default_decompositions

core_aten_decomp_table = default_decompositions()
core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
print(core_aten_ep.graph)

我们现在看到 torch.ops.aten.conv2d.default 已被分解成 torch.ops.aten.convolution.default。这是因为 convolution 是一个更“核心”的操作符,因为像 conv1dconv2d 这样的操作可以用相同的操作符实现。

我们还可以指定自己的分解行为:

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph)

注意,与 torch.ops.aten.conv2d.default 被分解为 torch.ops.aten.convolution.default 不同,它现在被分解为 torch.ops.aten.convolution.defaulttorch.ops.aten.mul.Tensor,这与我们的自定义分解规则匹配。

ExportDB

torch.export 只会从一个 PyTorch 程序中导出单个计算图。由于这种要求,有些 Python 或 PyTorch 的功能可能与 torch.export 不兼容,这需要用户重写部分模型代码。我们在前面的教程中已经看到了一些示例,例如使用 cond 重写 if 条件语句。

ExportDB 是记录 torch.export 所支持和不支持的 Python/PyTorch 功能的标准参考。它基本上是一系列程序示例,每个示例都代表一种特定 Python/PyTorch 功能及其与 torch.export 的交互。示例还可以按类别标记,以便更容易搜索。

例如,让我们使用 ExportDB 来更好地理解 cond 操作符中的谓词是如何工作的。我们可以查看名为 cond_predicate 的示例,它有一个 torch.cond 标签。示例代码如:

def cond_predicate(x):
    """
    The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
    - ``torch.Tensor`` with a single element
    - boolean expression
    NOTE: If the `pred` is test on a dim with batch size < 2, it will be specialized.
    """
    pred = x.dim() > 2 and x.shape[2] > 10
    return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])

更一般地说,在以下情况之一发生时,ExportDB 可以用作参考:

  1. 在尝试 torch.export 之前,您已经知道您的模型使用了一些复杂的 Python/PyTorch 功能,并希望了解 torch.export 是否支持该功能。

  2. 在尝试 torch.export 时出现了失败,并且不清楚如何解决。

ExportDB 不是详尽无遗的,但旨在覆盖典型 PyTorch 代码中发现的所有用例。如果有重要的 Python/PyTorch 功能需要添加到 ExportDB 或由 torch.export 支持,请随时联系我们。

运行导出的程序

由于 torch.export 只是一个图捕获机制,直接调用由 torch.export 生成的工件将在急性方式下等价于运行急性模块。为了优化导出程序的执行,我们可以将此导出的工件传递给诸如 Inductor 这样的后端,通过 torch.compileAOTInductor 或者 TensorRT

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        x = self.linear(x)
        return x

inp = torch.randn(2, 3, device="cuda")
m = M().to(device="cuda")
ep = torch.export.export(m, (inp,))

# Run it eagerly
res = ep.module()(inp)
print(res)

# Run it with torch.compile
res = torch.compile(ep.module(), backend="inductor")(inp)
print(res)
import torch._inductor

# Note: these APIs are subject to change
# Compile the exported program to a PT2 archive using ``AOTInductor``
with torch.no_grad():
    pt2_path = torch._inductor.aoti_compile_and_package(ep)

# Load and run the .so file in Python.
# To load and run it in a C++ environment, see:
# https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
res = aoti_compiled(inp)

结论

我们介绍了 torch.export,这是 PyTorch 2.X 中用于从 PyTorch 程序中导出单个计算图的新方法。特别是,我们展示了为导出图需要进行的一些代码修改和注意事项(控制流操作符、约束等)。

**脚本的总运行时间:**(0分钟0.000秒)

通过Sphinx-Gallery生成的图集

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源