Shortcuts

展示 torch.export 流程的演示,以及常见挑战和解决方案

作者: Ankith Gunapal, Jordi Ramon, Marcos Carranza

torch.export 教程简介 中,我们学习了如何使用 torch.export 。这个教程是先前的扩展教程,探索了如何使用代码导出流行的模型,并解决了 torch.export 中可能会遇到的常见问题。

在本教程中,您将学习如何针对以下应用场景导出模型:

选择这四个模型是为了展示 torch.export 的独特功能,以及在实际实施中遇到的某些注意事项和问题。

前提条件

  • PyTorch 2.4 或更高版本

  • torch.export 和 PyTorch Eager 推理的基础理解。

torch.export 的关键要求: 无图断裂

torch.compile 通过使用 JIT 编译 PyTorch 代码为优化的内核加速 PyTorch 代码。它使用 TorchDynamo 优化给定的模型并创建一个优化的图,然后通过 API 指定的后端将其加载到硬件中。当 TorchDynamo 遇到不支持的 Python 功能时,它会中断计算图,让默认 Python 解释器处理不支持的代码,然后恢复图的捕获。此图的中断被称为 图断裂

torch.exporttorch.compile 的一个关键区别是 torch.export 不支持图断裂,这意味着您导出的整个模型或模型的一部分需要是单个图。这是因为处理图断裂涉及使用默认 Python 评价解释不支持的操作,这与 torch.export 的设计不兼容。您可以在此 链接 中阅读有关不同 PyTorch 框架之间差异的详细信息。

您可以使用以下命令标识程序中的图断裂:

TORCH_LOGS="graph_breaks" python <file_name>.py

您需要修改程序以消除图断裂。一旦解决,您就可以准备导出模型。PyTorch 在流行的 HuggingFace 和 TIMM 模型上运行 每夜基准测试 以分析 torch.compile。这其中的大多数模型都没有图断裂。

本教程中的模型没有图断裂,但使用 torch.export 时会失败。

视频分类

MViT 是基于 多尺度视觉Transformer 的模型类别。此模型已使用 Kinetics-400 数据集 进行视频分类训练。搭载相关数据集,该模型可用于游戏中的动作识别。

以下代码通过设置 batch_size=2 对 MViT 进行追踪导出,并检查 ExportedProgram 是否可以在 batch_size=4 下运行。

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb

model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)
# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
exported_program = torch.export.export(
    model,
    (input_frames,),
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

错误: 静态批大小

    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 2, but got 4

默认情况下,导出流程将假定所有输入形状是静态的,因此如果您使用不同于跟踪时输入形状的输入形状运行程序,将会遇到错误。

解决方案

为解决错误,我们指定输入的第一个维度(batch_size)为动态,指定 batch_size 的预期范围。在下面显示的修正示例中,我们指定预期的 batch_size 范围为 1 至 16。请注意, min=2 不是一个错误,其原因可以在 0/1 专业化问题 中找到解释。有关 torch.export 动态形状的详细描述可在导出教程中找到。以下代码演示了如何在批大小动态化的情况下导出 mViT:

import numpy as np
import torch
from torchvision.models.video import MViT_V1_B_Weights, mvit_v1_b
import traceback as tb


model = mvit_v1_b(weights=MViT_V1_B_Weights.DEFAULT)

# Create a batch of 2 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(2,16, 224, 224, 3)

# Transpose to get [1, 3, num_clips, height, width].
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))

# Export the model.
batch_dim = torch.export.Dim("batch", min=2, max=16)
exported_program = torch.export.export(
    model,
    (input_frames,),
    # Specify the first dimension of the input x as dynamic
    dynamic_shapes={"x": {0: batch_dim}},
)

# Create a batch of 4 videos, each with 16 frames of shape 224x224x3.
input_frames = torch.randn(4,16, 224, 224, 3)
input_frames = np.transpose(input_frames, (0, 4, 1, 2, 3))
try:
    exported_program.module()(input_frames)
except Exception:
    tb.print_exc()

自动语音识别

自动语音识别 (ASR) 是机器学习的一种应用,旨在将口语语言转录成文本。Whisper 是 OpenAI 的基于 Transformer 的编码器-解码器模型,训练了 68 万小时的标记数据,用于 ASR 和语音翻译。以下代码尝试导出 ASR 的 whisper-tiny 模型。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,))

错误: 使用 TorchDynamo 的严格追踪

torch._dynamo.exc.InternalTorchDynamoError: AttributeError: 'DynamicCache' object has no attribute 'key_cache'

默认情况下,torch.export 使用 TorchDynamo 对您的代码进行追踪,这是一种字节代码分析引擎,它以符号方式分析您的代码并构建图。这种分析提供了更强的安全性保证,但并不支持所有 Python 代码。当我们使用默认严格模式导出 whisper-tiny 模型时,它通常会由于不支持的功能而在 Dynamo 中返回错误。要了解为什么 Dynamo 出错,您可以参考此 GitHub 问题

解决方案

为了解决上述错误,“torch.export”支持“非严格(non_strict)”模式,在这种模式下程序通过Python解释器进行追踪,相当于PyTorch的即时执行(eager execution)。唯一不同的是,所有“Tensor”对象将被替换为“ProxyTensors”,后者将在图中记录它们的所有操作。通过设置“strict=False”,我们可以导出程序。

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset

# load model
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# dummy inputs for exporting the model
input_features = torch.randn(1,80, 3000)
attention_mask = torch.ones(1, 3000)
decoder_input_ids = torch.tensor([[1, 1, 1 , 1]]) * model.config.decoder_start_token_id

model.eval()

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(input_features, attention_mask, decoder_input_ids,), strict=False)

图像描述生成

**图像描述生成**是一种用语言定义图像内容的任务。在游戏场景中,图像描述生成可以通过动态生成场景中各种游戏对象的文本描述来增强游戏体验,从而为玩家提供额外的细节。BLIP 是一个由SalesForce研究团队发布的流行图像描述生成模型。以下代码尝试使用“batch_size=1”导出BLIP。

import torch
from models.blip import blip_decoder

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
image_size = 384
image = torch.randn(1, 3,384,384).to(device)
caption_input = ""

model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth'
model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)

exported_program: torch.export.ExportedProgram= torch.export.export(model, args=(image,caption_input,), strict=False)

错误:无法修改冻结存储的张量

在导出模型时可能会失败,因为模型实现可能包含“torch.export”尚未支持的某些Python操作。这些错误中的一些可能有解决方法。BLIP是一个案例,其中原始模型出现错误,这可以通过对代码进行一些小修改来解决。“torch.export”列出了在 ExportDB 中支持和不支持的常见操作,并展示了如何修改代码使其符合导出要求。

File "/BLIP/models/blip.py", line 112, in forward
    text.input_ids[:,0] = self.tokenizer.bos_token_id
  File "/anaconda3/envs/export/lib/python3.10/site-packages/torch/_subclasses/functional_tensor.py", line 545, in __torch_dispatch__
    outs_unwrapped = func._op_dk(
RuntimeError: cannot mutate tensors with frozen storage

解决方案

克隆 模型张量,以解决导出时的错误。

text.input_ids = text.input_ids.clone() # clone the tensor
text.input_ids[:,0] = self.tokenizer.bos_token_id

备注

此限制已在PyTorch 2.7的夜间版本中放松。在PyTorch 2.7中应该可以直接使用。

可提示的图像分割

**图像分割**是一种计算机视觉技术,根据数字图像的特性将其划分为不同的像素组或段。万物分割模型(SAM) 提出了可提示的图像分割,它能够根据提示预测对象的掩码。而 SAM 2 是第一个统一的模型,可用于图像和视频中的对象分割。SAM2ImagePredictor 类为模型提供了简单的接口用于提示。模型可接受点和框提示,以及预测的上一轮迭代生成的掩码。鉴于SAM2在对象跟踪方面提供了强大的零样本性能,它可以用来跟踪场景中的游戏对象。

SAM2ImagePredictor 的预测方法中,张量操作发生在 _predict 方法中。因此,我们尝试以下方式导出。

ep = torch.export.export(
    self._predict,
    args=(unnorm_coords, labels, unnorm_box, mask_input, multimask_output),
    kwargs={"return_logits": return_logits},
    strict=False,
)

错误:模型不是类型“torch.nn.Module”

“torch.export”要求模块类型是“torch.nn.Module”。然而,我们尝试导出的模块是一个类方法,因此出错。

Traceback (most recent call last):
  File "/sam2/image_predict.py", line 20, in <module>
    masks, scores, _ = predictor.predict(
  File "/sam2/sam2/sam2_image_predictor.py", line 312, in predict
    ep = torch.export.export(
  File "python3.10/site-packages/torch/export/__init__.py", line 359, in export
    raise ValueError(
ValueError: Expected `mod` to be an instance of `torch.nn.Module`, got <class 'method'>.

解决方案

我们编写一个辅助类,该类继承自“torch.nn.Module”并在类的“forward”方法中调用“_predict”方法。完整代码可以在 这里 找到。

class ExportHelper(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(_, *args, **kwargs):
        return self._predict(*args, **kwargs)

 model_to_export = ExportHelper()
 ep = torch.export.export(
      model_to_export,
      args=(unnorm_coords, labels, unnorm_box, mask_input,  multimask_output),
      kwargs={"return_logits": return_logits},
      strict=False,
      )

结论

在本教程中,我们学习了如何通过正确的配置和简单的代码修改使用“torch.export”导出流行使用场景下的模型。一旦可以导出模型,就可以在硬件上将“ExportedProgram”降级使用,例如在服务器上使用 AOTInductor,或在边缘设备上使用 ExecuTorch。要了解更多关于“AOTInductor”(AOTI)的内容,请参考 AOTI教程。要了解更多关于“ExecuTorch”的内容,请参考 ExecuTorch教程

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源