Skipping Module Parameter Initialization
Introduction
When a module is created, its learnable parameters are initialized according to a default initialization scheme associated with the module type. For example, the `weight` parameter for a :class:`torch.nn.Linear` module is initialized from a `uniform(-1/sqrt(in_features), 1/sqrt(in_features))` distribution. If some other initialization scheme is desired, this has traditionally required re-initializing the parameters after module instantiation:
from torch import nn
# Initializes weight from the default distribution: uniform(-1/sqrt(10), 1/sqrt(10)).
m = nn.Linear(10, 5)
# Re-initialize weight from a different distribution.
nn.init.orthogonal_(m.weight)
In this case, the initialization done during construction is wasted computation, and it may be non-trivial and time-consuming if the `weight` parameter is large.
Skipping Initialization
It is now possible to skip parameter initialization during module construction, avoiding this wasted computation. This is easily accomplished using the `torch.nn.utils.skip_init()` function:
from torch import nn
from torch.nn.utils import skip_init
m = skip_init(nn.Linear, 10, 5)
# Example: Do custom, non-default parameter initialization.
nn.init.orthogonal_(m.weight)
This can be applied to any module that satisfies the conditions described in the :ref:`Updating` section below. Note that all modules provided by `torch.nn` satisfy these conditions and thus support skipping initialization.
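As a quick sketch of the result: `skip_init()` forwards positional and keyword arguments to the module constructor, and the returned module's parameters hold uninitialized memory until they are explicitly initialized:

import torch
from torch import nn
from torch.nn.utils import skip_init

# Keyword arguments are forwarded to the module constructor.
m = skip_init(nn.Linear, in_features=10, out_features=5)

# Parameters have the expected shapes but contain uninitialized memory,
# so they must be initialized before the module is used.
print(m.weight.shape)  # torch.Size([5, 10])
nn.init.orthogonal_(m.weight)
nn.init.zeros_(m.bias)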
Updating Modules to Support Skipping Initialization
Due to the way `torch.nn.utils.skip_init()` is implemented (see :ref:`Details` below), a module must meet two requirements to be compatible with the function. You can opt in to the parameter initialization skipping functionality for your custom modules by adhering to these requirements:
1. The module must accept a `device` kwarg in its constructor that is passed to any parameters or buffers created during construction.
2. The module must not perform any computation on parameters or buffers in its constructor except initialization (i.e. functions from `torch.nn.init`).
The following example demonstrates a module that accepts a `device` kwarg and passes it along to any created parameters, buffers, or submodules:
import torch
from torch import nn

class MyModule(torch.nn.Module):
    def __init__(self, foo, bar, device=None):
        super().__init__()

        # ==== Case 1: Module creates parameters directly. ====
        # Pass device along to any created parameters.
        self.param1 = nn.Parameter(torch.empty((foo, bar), device=device))
        self.register_parameter('param2', nn.Parameter(torch.empty(bar, device=device)))

        # To ensure support for the meta device, avoid using ops except those in
        # torch.nn.init on parameters in your module's constructor.
        with torch.no_grad():
            nn.init.kaiming_uniform_(self.param1)
            nn.init.uniform_(self.param2)

        # ==== Case 2: Module creates submodules. ====
        # Pass device along recursively. All submodules will need to support
        # it as well; this is the case for all torch.nn provided modules.
        self.fc = nn.Linear(bar, 5, device=device)

        # This also works with containers.
        self.linears = nn.Sequential(
            nn.Linear(5, 5, device=device),
            nn.Linear(5, 1, device=device)
        )

        # ==== Case 3: Module creates buffers. ====
        # Pass device along during buffer tensor creation.
        self.register_buffer('some_buffer', torch.ones(7, device=device))

        ...
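With the `device` kwarg in place, `MyModule` can be constructed via `skip_init()`. A brief usage sketch, continuing from the class above:

from torch import nn
from torch.nn.utils import skip_init

# Construction runs on the meta device, so the torch.nn.init calls in
# __init__ are no-ops and no real initialization work is performed.
m = skip_init(MyModule, 4, 3)

# All parameters, buffers, and submodule parameters come back uninitialized;
# initialize whatever the use case requires before using the module.
nn.init.orthogonal_(m.param1)
nn.init.zeros_(m.param2)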
Implementation Details
Behind the scenes, the `torch.nn.utils.skip_init()` function is implemented in terms of a two-step pattern:
# 1. Initialize module on the meta device; all torch.nn.init ops have
# no-op behavior on the meta device.
m = nn.Linear(10, 5, device='meta')
# 2. Materialize an uninitialized (empty) form of the module on the CPU device.
# The result of this is a module instance with uninitialized parameters.
m.to_empty(device='cpu')
It works by instantiating the module onto a "meta" device, which has tensor shape information but does not allocate any storage. The `torch.nn.init` ops are specially implemented for this meta device so that they behave as no-ops. As a result, the parameter initialization logic is essentially skipped.
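This meta-device behavior can be observed directly; a small sketch:

import torch
from torch import nn

# Step 1: construct on the meta device; tensors carry shape and dtype
# metadata but no storage.
m = nn.Linear(10, 5, device='meta')
print(m.weight.device)  # meta
print(m.weight.shape)   # torch.Size([5, 10])

# torch.nn.init ops are no-ops on meta tensors, so this does no real work.
nn.init.uniform_(m.weight)

# Step 2: to_empty() allocates uninitialized storage on the target device
# (it modifies the module in place and also returns it).
m = m.to_empty(device='cpu')
print(m.weight.device)  # cpu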
Note that this pattern only works properly for modules that correctly support a `device` kwarg during construction, as described in :ref:`Updating`.
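Conversely, a module whose constructor does not accept `device` cannot be used with `skip_init()`; the function inspects the constructor signature and raises a `RuntimeError` in that case. A sketch with a hypothetical `NoDeviceModule`:

import torch
from torch import nn
from torch.nn.utils import skip_init

# Hypothetical module that does not accept a device kwarg.
class NoDeviceModule(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(size))

try:
    skip_init(NoDeviceModule, 10)
except RuntimeError as e:
    print(e)  # skip_init() rejects modules without a device kwarg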