Shortcuts

自动加载树外扩展

Created On: Oct 10, 2024 | Last Updated: Oct 10, 2024 | Last Verified: Oct 10, 2024

作者: `Yuanhao Ji`__

扩展自动加载机制使PyTorch可以在没有显式导入语句的情况下自动加载树外后端扩展。此功能为用户带来了便利,使他们能够遵循熟悉的PyTorch设备编程模型,而无需显式加载或导入特定于设备的扩展。此外,这可以在树外设备上零代码变更轻松采用现有PyTorch应用程序。有关详情,请参阅 [RFC] 自动加载设备扩展

What you will learn
  • 如何在PyTorch中使用树外扩展自动加载

  • 查看使用Intel Gaudi HPU、华为Ascend NPU的示例

Prerequisites
  • PyTorch v2.5或更新版本

备注

此功能默认启用,可以通过使用 export TORCH_DEVICE_BACKEND_AUTOLOAD=0 禁用。如果遇到类似“无法加载后端扩展”的错误,此错误与PyTorch无关,应禁用此功能并向树外扩展维护者求助。

如何将此机制应用于树外扩展?

例如,假设你有一个名为``foo``的后端,以及对应的包名为``torch_foo``。请确保你的包兼容PyTorch 2.5或更高版本,并在其``__init__.py``文件中包含以下代码片段:

def _autoload():
    print("Check things are working with `torch.foo.is_available()`.")

然后,你需要做的唯一事情就是在你的Python包中定义一个入口点:

setup(
    name="torch_foo",
    version="1.0",
    entry_points={
        "torch.backends": [
            "torch_foo = torch_foo:_autoload",
        ],
    }
)

现在你可以通过简单地添加``import torch``语句来导入``torch_foo``模块,而无需添加``import torch_foo``:

>>> import torch
Check things are working with `torch.foo.is_available()`.
>>> torch.foo.is_available()
True

在某些情况下,你可能会遇到循环导入问题。下面的示例演示了如何解决这些问题。

示例

在这个示例中,我们将使用Intel Gaudi HPU和华为Ascend NPU来确定如何使用PyTorch的自动加载功能将你的扩展集成到PyTorch中。

`habana_frameworks.torch`_是一个Python包,它允许用户通过使用PyTorch的``HPU``设备键在Intel Gaudi上运行PyTorch程序。

``habana_frameworks.torch``是``habana_frameworks``的一个子模块,我们在``habana_frameworks/setup.py``中向``__autoload()``添加一个入口点:

setup(
    name="habana_frameworks",
    version="2.5",
+   entry_points={
+       'torch.backends': [
+           "device_backend = habana_frameworks:__autoload",
+       ],
+   }
)

在``habana_frameworks/init.py``中,我们使用一个全局变量来跟踪我们的模块是否已被加载:

import os

is_loaded = False  # A member variable of habana_frameworks module to track if our module has been imported

def __autoload():
    # This is an entrypoint for pytorch autoload mechanism
    # If the following condition is true, that means our backend has already been loaded, either explicitly
    # or by the autoload mechanism and importing it again should be skipped to avoid circular imports
    global is_loaded
    if is_loaded:
        return
    import habana_frameworks.torch

在``habana_frameworks/torch/init.py``中,我们通过更新全局变量的状态来防止循环导入:

import os

# This is to prevent torch autoload mechanism from causing circular imports
import habana_frameworks

habana_frameworks.is_loaded = True

torch_npu``使用户能够在华为Ascend NPU上运行PyTorch程序,它使用``PrivateUse1``设备键,并向终端用户公开设备名为``npu

我们在`torch_npu/setup.py`_中定义了一个入口点:

setup(
    name="torch_npu",
    version="2.5",
+   entry_points={
+       'torch.backends': [
+           'torch_npu = torch_npu:_autoload',
+       ],
+   }
)

不同于``habana_frameworks``,``torch_npu``使用环境变量``TORCH_DEVICE_BACKEND_AUTOLOAD``来控制自动加载过程。例如,我们可以将其设置为``0``以禁用自动加载以防止循环导入:

# Disable autoloading before running 'import torch'
os.environ['TORCH_DEVICE_BACKEND_AUTOLOAD'] = '0'

import torch

工作原理

自动加载实现

自动加载是基于Python的`Entrypoints <https://packaging.python.org/en/latest/specifications/entry-points/>`_机制实现的。我们在``torch/__init__.py``中发现并加载所有由扩展定义的特定入口点。

如上所示,安装``torch_foo``后,在加载你定义的入口点时,你的Python模块可以被导入,然后你可以在调用时做一些必要的操作。

请参阅此拉取请求中的实现:[RFC] Add support for device extension autoloading

总结

在本教程中,我们学习了PyTorch中的扩展自动加载机制,该机制自动加载后端扩展,无需添加额外的导入语句。我们还学习了如何通过定义一个入口点将此机制应用于扩展,并如何防止循环导入。此外,我们还回顾了如何在Intel Gaudi HPU和华为Ascend NPU上使用自动加载机制的示例。

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源