备注
点击 此处 下载完整示例代码
torch.vmap¶
Created On: Oct 26, 2020 | Last Updated: Sep 01, 2021 | Last Verified: Not Verified
本教程介绍torch.vmap,PyTorch操作的自动向量化工具。torch.vmap是一个原型功能,无法处理许多用例;然而,我们希望收集相关用例以指导其设计。如果您正在考虑使用torch.vmap或者认为它在某些方面会非常有用,请通过 https://github.com/pytorch/pytorch/issues/42368 联系我们。
那么,什么是vmap?¶
vmap是一个高阶函数。它接受一个函数`func`并返回一个新的函数,将`func`映射到输入的某个维度上。这受到JAX的vmap强烈启发。
从语义上讲,vmap将”map”应用到`func`调用的PyTorch操作中,有效地向量化这些操作。
import torch
# NB: vmap is only available on nightly builds of PyTorch.
# You can download one at pytorch.org if you're interested in testing it out.
from torch import vmap
vmap的第一个用例是使处理代码中的批量维度更容易。您可以编写一个`func`函数,该函数在单个样本上运行,然后使用`vmap(func)`提升到支持处理样本批量的函数。然而,`func`受许多限制:
它必须是功能性的(不能在其中修改Python数据结构),除非是PyTorch的就地操作。
样本批量必须以张量形式提供。这意味着vmap不能直接处理变长序列。
使用`vmap`的一个例子是计算批量点积。PyTorch没有提供批量`torch.dot` API;与其在文档中找不到相关API,不如使用`vmap`构造一个新函数:
torch.dot # [D], [D] -> []
batched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]
x, y = torch.randn(2, 5), torch.randn(2, 5)
batched_dot(x, y)
`vmap`可以帮助隐藏批量维度,从而简化模型设计体验。
batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)
# Note that model doesn't work with a batch of feature vectors because
# torch.dot must take 1D tensors. It's pretty easy to rewrite this
# to use `torch.matmul` instead, but if we didn't want to do that or if
# the code is more complicated (e.g., does some advanced indexing
# shenanigins), we can simply call `vmap`. `vmap` batches over ALL
# inputs, unless otherwise specified (with the in_dims argument,
# please see the documentation for more details).
def model(feature_vec):
# Very simple linear model with activation
return feature_vec.dot(weights).relu()
examples = torch.randn(batch_size, feature_size)
result = torch.vmap(model)(examples)
expected = torch.stack([model(example) for example in examples.unbind()])
assert torch.allclose(result, expected)
`vmap`还可以帮助向量化以前难以或几乎不可能批量化的计算。这将我们引导到第二个用例:批量梯度计算。
PyTorch自动梯度引擎计算vjps(向量-雅可比积)。使用vmap,我们可以计算(批量向量)-雅可比积。
一个例子是计算一个完整的雅可比矩阵(这也可以应用于计算完整的海森矩阵)。计算某个函数f: R^N -> R^N的一个完整雅可比矩阵通常需要N次调用`autograd.grad`,每次调用计算一个雅可比行。
# Setup
N = 5
def f(x):
return x ** 2
x = torch.randn(N, requires_grad=True)
y = f(x)
basis_vectors = torch.eye(N)
# Sequential approach
jacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]
for v in basis_vectors.unbind()]
jacobian = torch.stack(jacobian_rows)
# Using `vmap`, we can vectorize the whole computation, computing the
# Jacobian in a single call to `autograd.grad`.
def get_vjp(v):
return torch.autograd.grad(y, x, v)[0]
jacobian_vmap = vmap(get_vjp)(basis_vectors)
assert torch.allclose(jacobian_vmap, jacobian)
vmap的第三个主要用例是计算每样本梯度。这是vmap原型目前无法高效处理的事情。我们还不确定计算每样本梯度的API应该是什么样,但如果您有任何想法,请在 https://github.com/pytorch/pytorch/issues/7786 上评论。
def model(sample, weight):
# do something...
return torch.dot(sample, weight)
def grad_sample(sample):
return torch.autograd.functional.vjp(lambda weight: model(sample), weight)[1]
# The following doesn't actually work in the vmap prototype. But it
# could be an API for computing per-sample-gradients.
# batch_of_samples = torch.randn(64, 5)
# vmap(grad_sample)(batch_of_samples)
脚本总运行时间: (0分钟 0.000秒)