
Jacobians, Hessians, hvp, vhp, and more: composing function transforms

Created On: Mar 15, 2023 | Last Updated: Apr 18, 2023 | Last Verified: Nov 05, 2024

Computing the Jacobian or the Hessian is useful for a number of non-traditional deep learning models. Doing so efficiently (and without annoyance) with PyTorch's regular autodiff APIs (``Tensor.backward()``, ``torch.autograd.grad``) is difficult. PyTorch's `JAX-inspired <https://github.com/google/jax>`_ `function transforms API <https://pytorch.org/docs/master/func.html>`_ provides ways of computing various higher-order autodiff quantities efficiently.

Note

This tutorial requires PyTorch 2.0.0 or later.

Computing the Jacobian

import torch
import torch.nn.functional as F
from functools import partial
_ = torch.manual_seed(0)

Let's start with a function that we'd like to compute the Jacobian of. It is a simple linear function with a non-linear activation.

def predict(weight, bias, x):
    return F.linear(x, weight, bias).tanh()

Let's add some dummy data: a weight, a bias, and a feature vector x.

D = 16
weight = torch.randn(D, D)
bias = torch.randn(D)
x = torch.randn(D)  # feature vector

Let's think of ``predict`` as a function that maps the input ``x`` from :math:`R^D \to R^D`. PyTorch Autograd computes vector-Jacobian products. In order to compute the full Jacobian of this :math:`R^D \to R^D` function, we would have to compute it row-by-row, using a different unit vector each time.

def compute_jac(xp):
    jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]
                     for vec in unit_vectors]
    return torch.stack(jacobian_rows)

xp = x.clone().requires_grad_()
unit_vectors = torch.eye(D)

jacobian = compute_jac(xp)

print(jacobian.shape)
print(jacobian[0])  # show first row
torch.Size([16, 16])
tensor([-0.5956, -0.6096, -0.1326, -0.2295,  0.4490,  0.3661, -0.1672, -1.1190,
         0.1705, -0.6683,  0.1851,  0.1630,  0.0634,  0.6547,  0.5908, -0.1308])

Instead of computing the Jacobian row-by-row, we can use PyTorch's ``torch.vmap`` function transform to get rid of the for-loop and vectorize the computation. We can't directly apply ``vmap`` to ``torch.autograd.grad``; instead, PyTorch provides a ``torch.func.vjp`` transform that composes with ``torch.vmap``:

from torch.func import vmap, vjp

_, vjp_fn = vjp(partial(predict, weight, bias), x)

ft_jacobian, = vmap(vjp_fn)(unit_vectors)

# let's confirm both methods compute the same result
assert torch.allclose(ft_jacobian, jacobian)

In a later tutorial, a composition of reverse-mode AD and ``vmap`` will give us per-sample gradients. In this tutorial, composing reverse-mode AD and ``vmap`` gives us Jacobian computation! Various compositions of ``vmap`` and autodiff transforms can give us different interesting quantities.
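
For instance, composing ``torch.func.grad`` with ``vmap`` yields per-sample gradients. Here is a minimal sketch using the ``predict`` function above; the ``per_sample_loss`` helper and the small batch are made up for illustration and are not part of the original tutorial:

from torch.func import grad

def per_sample_loss(weight, bias, x):
    # a scalar loss for one example, so that ``grad`` applies
    return predict(weight, bias, x).sum()

batched_x = torch.randn(8, D)  # a small hypothetical batch of inputs
# ``grad`` differentiates with respect to the first argument (the weight) by default;
# ``vmap`` maps over the batch dimension of ``x`` only.
per_sample_grads = vmap(grad(per_sample_loss), in_dims=(None, None, 0))(weight, bias, batched_x)
print(per_sample_grads.shape)  # torch.Size([8, 16, 16]): one weight gradient per example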

PyTorch provides ``torch.func.jacrev`` as a convenience function that performs the ``vmap-vjp`` composition to compute Jacobians. ``jacrev`` accepts an ``argnums`` argument that specifies which argument we would like to compute Jacobians with respect to.

from torch.func import jacrev

ft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)

# Confirm by running the following:
assert torch.allclose(ft_jacobian, jacobian)

Let's compare the performance of the two ways of computing the Jacobian. The function-transform version is much faster (and becomes even faster the more outputs there are).

In general, we expect that vectorization via ``vmap`` helps eliminate overhead and gives better utilization of hardware.

``vmap`` performs this magic by pushing the outer loop down into the function's primitive operations in order to obtain better performance.
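
As a small added illustration (not part of the original tutorial), vectorizing ``predict`` over a hypothetical batch with ``vmap`` matches an explicit Python loop:

# Added illustration: ``vmap`` gives the same result as an explicit Python loop,
# without the per-iteration overhead of the loop itself.
examples = torch.randn(4, D)
looped = torch.stack([predict(weight, bias, e) for e in examples])
vmapped = vmap(predict, in_dims=(None, None, 0))(weight, bias, examples)
assert torch.allclose(looped, vmapped)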

Let's write a quick function to evaluate the performance difference and handle measurements in microseconds and milliseconds:

def get_perf(first, first_descriptor, second, second_descriptor):
    """takes torch.benchmark objects and compares delta of second vs first."""
    faster = second.times[0]
    slower = first.times[0]
    gain = (slower-faster)/slower
    if gain < 0: gain *=-1
    final_gain = gain*100
    print(f" Performance delta: {final_gain:.4f} percent improvement with {second_descriptor} ")

And then run the performance comparison:

from torch.utils.benchmark import Timer

without_vmap = Timer(stmt="compute_jac(xp)", globals=globals())
with_vmap = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

no_vmap_timer = without_vmap.timeit(500)
with_vmap_timer = with_vmap.timeit(500)

print(no_vmap_timer)
print(with_vmap_timer)
<torch.utils.benchmark.utils.common.Measurement object at 0x7f203afa51b0>
compute_jac(xp)
  868.74 us
  1 measurement, 500 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f203ae3f460>
jacrev(predict, argnums=2)(weight, bias, x)
  291.01 us
  1 measurement, 500 runs , 1 thread

Let's do a relative performance comparison of the above with our ``get_perf`` function:

get_perf(no_vmap_timer, "without vmap",  with_vmap_timer, "vmap")
Performance delta: 66.5019 percent improvement with vmap

Furthermore, it's pretty easy to flip the problem around and say we want to compute Jacobians of the parameters of our model (weight, bias) instead of the input:

# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias
ft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)
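
As an added note (not from the original text), each resulting Jacobian carries the output shape followed by the shape of the parameter being differentiated; a quick check:

# Added shape check: output shape (16,) followed by each parameter's shape.
print(ft_jac_weight.shape)  # torch.Size([16, 16, 16])
print(ft_jac_bias.shape)    # torch.Size([16, 16])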

Reverse-mode Jacobian (jacrev) vs forward-mode Jacobian (jacfwd)

We offer two APIs to compute Jacobians: ``jacrev`` and ``jacfwd``:

  • ``jacrev`` uses reverse-mode AD. As you saw above, it is a composition of our ``vjp`` and ``vmap`` transforms.

  • ``jacfwd`` uses forward-mode AD. It is implemented as a composition of our ``jvp`` and ``vmap`` transforms.

``jacfwd`` and ``jacrev`` can be substituted for each other, but they have different performance characteristics.

As a general rule of thumb, if you're computing the Jacobian of an :math:`R^N \to R^M` function and there are many more outputs than inputs (for example, :math:`M > N`), then ``jacfwd`` is preferred; otherwise use ``jacrev``. There are exceptions to this rule, but a non-rigorous argument for it follows:

In reverse-mode AD, we compute the Jacobian row-by-row, while in forward-mode AD (which computes Jacobian-vector products), we compute it column-by-column. The Jacobian matrix has M rows and N columns, so if it is taller or wider one way, we may prefer the method that deals with fewer rows or columns.
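
To make the column-by-column picture concrete, here is an added sketch (not from the original tutorial, and written with a plain loop rather than the ``vmap`` that ``jacfwd`` actually uses) that builds the Jacobian one column at a time with ``jvp``:

from torch.func import jvp

def jacobian_by_columns(f, x):
    unit_vectors = torch.eye(x.numel(), dtype=x.dtype)
    # each ``jvp`` call pushes one unit tangent through f, producing one column of J
    cols = [jvp(f, (x,), (v,))[1] for v in unit_vectors]
    return torch.stack(cols, dim=-1)  # stack the columns along the last dimension

expected = jacrev(predict, argnums=2)(weight, bias, x)
assert torch.allclose(jacobian_by_columns(lambda xi: predict(weight, bias, xi), x), expected)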

from torch.func import jacrev, jacfwd

First, let's benchmark with more outputs than inputs:

Din = 32
Dout = 2048
weight = torch.randn(Dout, Din)

bias = torch.randn(Dout)
x = torch.randn(Din)

# remember the general rule about taller vs wider... here we have a taller matrix:
print(weight.shape)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
torch.Size([2048, 32])
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f2038ad3fa0>
jacfwd(predict, argnums=2)(weight, bias, x)
  1.16 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f202245e5c0>
jacrev(predict, argnums=2)(weight, bias, x)
  4.91 ms
  1 measurement, 500 runs , 1 thread

and then do a relative benchmark:

get_perf(jacfwd_timing, "jacfwd", jacrev_timing, "jacrev")
Performance delta: 325.0522 percent improvement with jacrev

and now the reverse: more inputs (N) than outputs (M):

Din = 2048
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

using_fwd = Timer(stmt="jacfwd(predict, argnums=2)(weight, bias, x)", globals=globals())
using_bwd = Timer(stmt="jacrev(predict, argnums=2)(weight, bias, x)", globals=globals())

jacfwd_timing = using_fwd.timeit(500)
jacrev_timing = using_bwd.timeit(500)

print(f'jacfwd time: {jacfwd_timing}')
print(f'jacrev time: {jacrev_timing}')
jacfwd time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f1f9f1bb640>
jacfwd(predict, argnums=2)(weight, bias, x)
  4.48 ms
  1 measurement, 500 runs , 1 thread
jacrev time: <torch.utils.benchmark.utils.common.Measurement object at 0x7f20223f34f0>
jacrev(predict, argnums=2)(weight, bias, x)
  341.04 us
  1 measurement, 500 runs , 1 thread

and a relative performance comparison:

get_perf(jacrev_timing, "jacrev", jacfwd_timing, "jacfwd")
Performance delta: 1212.1935 percent improvement with jacfwd

Hessian computation with functorch.hessian

We offer a convenience API to compute Hessians: ``torch.func.hessian``. Hessians are the Jacobian of the Jacobian (or the partial derivative of the partial derivative, also known as second-order derivatives).

This suggests that one can just compose functorch's Jacobian transforms to compute the Hessian. Indeed, under the hood, ``hessian(f)`` is simply ``jacfwd(jacrev(f))``.

Note: to boost performance, depending on your model, you may also want to use ``jacfwd(jacfwd(f))`` or ``jacrev(jacrev(f))`` instead to compute Hessians, leveraging the rule of thumb above about wider vs taller matrices.

from torch.func import hessian

# let's reduce the size in order not to overwhelm Colab. Hessians require
# significant memory:
Din = 512
Dout = 32
weight = torch.randn(Dout, Din)
bias = torch.randn(Dout)
x = torch.randn(Din)

hess_api = hessian(predict, argnums=2)(weight, bias, x)
hess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)
hess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)

Let's verify that we get the same result whether we use the hessian API or ``jacfwd(jacfwd())``.
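
The check itself isn't reproduced in the extracted text above; a minimal version using ``torch.allclose`` would be:

# Minimal verification (reconstructed; the exact cell isn't shown above)
print(torch.allclose(hess_api, hess_fwdfwd))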

True

Batch Jacobian and Batch Hessian

In the above examples we've been operating with a single feature vector. In some cases you might want to take the Jacobian of a batch of outputs with respect to a batch of inputs. That is, given a batch of inputs of shape ``(B, N)`` and a function that goes from :math:`R^N \to R^M`, we would like a Jacobian of shape ``(B, M, N)``.

The easiest way to do this is to use ``vmap``:

batch_size = 64
Din = 31
Dout = 33

weight = torch.randn(Dout, Din)
print(f"weight shape = {weight.shape}")

bias = torch.randn(Dout)

x = torch.randn(batch_size, Din)

compute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))
batch_jacobian0 = compute_batch_jacobian(weight, bias, x)
weight shape = torch.Size([33, 31])

If you instead have a function that goes from ``(B, N) -> (B, M)`` and are certain that each input produces an independent output, then it's also sometimes possible to do this without ``vmap``, by summing the outputs and then computing the Jacobian of that function:

def predict_with_output_summed(weight, bias, x):
    return predict(weight, bias, x).sum(0)

batch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)
assert torch.allclose(batch_jacobian0, batch_jacobian1)

If you instead have a function that goes from :math:`R^N \to R^M` but whose inputs are batched, you compose ``vmap`` with ``jacrev`` to compute batched Jacobians.

Finally, batch Hessians can be computed similarly. It's easiest to think about them by using ``vmap`` to batch over the Hessian computation, but in some cases the sum trick also works.

compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))

batch_hess = compute_batch_hessian(weight, bias, x)
batch_hess.shape
torch.Size([64, 33, 31, 31])

Computing Hessian-vector products

The naive way to compute a Hessian-vector product (hvp) is to materialize the full Hessian and perform a dot product with a vector. We can do better: it turns out we don't need to materialize the full Hessian to do this. We'll go through two (of many) different strategies to compute Hessian-vector products:

  • composing reverse-mode AD with reverse-mode AD

  • composing reverse-mode AD with forward-mode AD

Composing reverse-mode AD with forward-mode AD (as opposed to reverse-mode with reverse-mode) is generally the more memory-efficient way to compute an hvp, because forward-mode AD doesn't need to construct an Autograd graph and save intermediates for the backward pass:

from torch.func import jvp, grad, vjp

def hvp(f, primals, tangents):
  return jvp(grad(f), primals, tangents)[1]

Here's some sample usage.

def f(x):
  return x.sin().sum()

x = torch.randn(2048)
tangent = torch.randn(2048)

result = hvp(f, (x,), (tangent,))
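
As an added sanity check (not part of the original tutorial), we can compare the result against the naive approach of materializing the full Hessian, which is exactly what the hvp above avoids:

# Added sanity check: build the full (2048 x 2048) Hessian and contract it with the
# tangent vector. Fine at this size, but the hvp above never materializes it.
full_hessian = hessian(f)(x)  # ``hessian`` was imported from torch.func earlier
naive_hvp = full_hessian @ tangent
assert torch.allclose(result, naive_hvp, atol=1e-6)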

If PyTorch forward-mode AD does not have coverage for your operations, then we can instead compose reverse-mode AD with reverse-mode AD:

def hvp_revrev(f, primals, tangents):
  _, vjp_fn = vjp(grad(f), *primals)
  return vjp_fn(*tangents)

result_hvp_revrev = hvp_revrev(f, (x,), (tangent,))
assert torch.allclose(result, result_hvp_revrev[0])

Total running time of the script: ( 0 minutes 8.070 seconds)

Gallery generated by Sphinx-Gallery
