{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Jacobians, Hessians, hvp, vhp, and more: composing function transforms\n======================================================================\n\nComputing jacobians or hessians is useful in a number of\nnon-traditional deep learning models. It is difficult (or annoying) to\ncompute these quantities efficiently using PyTorch\\'s regular autodiff\nAPIs (`Tensor.backward()`, `torch.autograd.grad`). PyTorch\\'s\n[JAX-inspired](https://github.com/google/jax) [function transforms\nAPI](https://pytorch.org/docs/master/func.html) provides ways of\ncomputing various higher-order autodiff quantities efficiently.\n\n```{=html}\n
NOTE: This tutorial requires PyTorch 2.0.0 or later.
\n```\nComputing the Jacobian\n----------------------\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn.functional as F\nfrom functools import partial\n_ = torch.manual_seed(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s start with a function that we\\'d like to compute the jacobian of.\nThis is a simple linear function with non-linear activation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def predict(weight, bias, x):\n return F.linear(x, weight, bias).tanh()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s add some dummy data: a weight, a bias, and a feature vector x.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "D = 16\nweight = torch.randn(D, D)\nbias = torch.randn(D)\nx = torch.randn(D) # feature vector" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s think of `predict` as a function that maps the input `x` from\n$R^D \\to R^D$. PyTorch Autograd computes vector-Jacobian products. 
In\norder to compute the full Jacobian of this $R^D \\to R^D$ function, we\nwould have to compute it row-by-row by using a different unit vector\neach time.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_jac(xp):\n jacobian_rows = [torch.autograd.grad(predict(weight, bias, xp), xp, vec)[0]\n for vec in unit_vectors]\n return torch.stack(jacobian_rows)\n\nxp = x.clone().requires_grad_()\nunit_vectors = torch.eye(D)\n\njacobian = compute_jac(xp)\n\nprint(jacobian.shape)\nprint(jacobian[0]) # show first row" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Instead of computing the jacobian row-by-row, we can use PyTorch\\'s\n`torch.vmap` function transform to get rid of the for-loop and vectorize\nthe computation. We can't directly apply `vmap` to\n`torch.autograd.grad`; instead, PyTorch provides a `torch.func.vjp`\ntransform that composes with `torch.vmap`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import vmap, vjp\n\n_, vjp_fn = vjp(partial(predict, weight, bias), x)\n\nft_jacobian, = vmap(vjp_fn)(unit_vectors)\n\n# let's confirm both methods compute the same result\nassert torch.allclose(ft_jacobian, jacobian)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a later tutorial a composition of reverse-mode AD and `vmap` will\ngive us per-sample-gradients. In this tutorial, composing reverse-mode\nAD and `vmap` gives us Jacobian computation! Various compositions of\n`vmap` and autodiff transforms can give us different interesting\nquantities.\n\nPyTorch provides `torch.func.jacrev` as a convenience function that\nperforms the `vmap-vjp` composition to compute jacobians. 
`jacrev`\naccepts an `argnums` argument that says which argument we would like to\ncompute Jacobians with respect to.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jacrev\n\nft_jacobian = jacrev(predict, argnums=2)(weight, bias, x)\n\n# Confirm by running the following:\nassert torch.allclose(ft_jacobian, jacobian)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s compare the performance of the two ways to compute the jacobian.\nThe function transform version is much faster (and becomes even faster\nthe more outputs there are).\n\nIn general, we expect that vectorization via `vmap` can help eliminate\noverhead and give better utilization of your hardware.\n\n`vmap` does this magic by pushing the outer loop down into the\nfunction\\'s primitive operations in order to obtain better performance.\n\nLet\\'s write a small helper that reports the relative speedup between\ntwo benchmark measurements:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_perf(first, first_descriptor, second, second_descriptor):\n    \"\"\"Takes two torch.utils.benchmark measurements and reports the relative speedup of the faster one.\"\"\"\n    first_time = first.times[0]\n    second_time = second.times[0]\n    # Report whichever measurement is actually faster instead of assuming it is the second one.\n    if second_time <= first_time:\n        faster, slower, winner = second_time, first_time, second_descriptor\n    else:\n        faster, slower, winner = first_time, second_time, first_descriptor\n    gain = (slower - faster) / slower * 100\n    print(f\"Performance delta: {gain:.4f} percent improvement with {winner}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And then run the performance comparison:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.utils.benchmark import Timer\n\nwithout_vmap = Timer(stmt=\"compute_jac(xp)\", globals=globals())\nwith_vmap = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\nno_vmap_timer = 
without_vmap.timeit(500)\nwith_vmap_timer = with_vmap.timeit(500)\n\nprint(no_vmap_timer)\nprint(with_vmap_timer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s do a relative performance comparison of the above with our\n`get_perf` function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "get_perf(no_vmap_timer, \"without vmap\", with_vmap_timer, \"vmap\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Furthermore, it's pretty easy to flip the problem around and compute\nJacobians with respect to the parameters of our model (weight, bias)\ninstead of the input:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# note the change in input via ``argnums`` parameters of 0,1 to map to weight and bias\nft_jac_weight, ft_jac_bias = jacrev(predict, argnums=(0, 1))(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reverse-mode Jacobian (`jacrev`) vs forward-mode Jacobian (`jacfwd`)\n====================================================================\n\nWe offer two APIs to compute jacobians: `jacrev` and `jacfwd`:\n\n- `jacrev` uses reverse-mode AD. As you saw above it is a composition\n of our `vjp` and `vmap` transforms.\n- `jacfwd` uses forward-mode AD. 
It is implemented as a composition of\n our `jvp` and `vmap` transforms.\n\n`jacfwd` and `jacrev` can be substituted for each other but they have\ndifferent performance characteristics.\n\nAs a general rule of thumb, if you're computing the jacobian of an\n$R^N \\to R^M$ function, and there are many more outputs than inputs (that\nis, $M \\gg N$) then `jacfwd` is preferred, otherwise use `jacrev`.\nThere are exceptions to this rule, but a non-rigorous argument for this\nfollows:\n\nIn reverse-mode AD, we are computing the jacobian row-by-row, while in\nforward-mode AD (which computes Jacobian-vector products), we are\ncomputing it column-by-column. The Jacobian matrix has M rows and N\ncolumns, so if it is taller or wider one way we may prefer the method\nthat deals with fewer rows or columns.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jacrev, jacfwd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let\\'s benchmark with more outputs than inputs:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Din = 32\nDout = 2048\nweight = torch.randn(Dout, Din)\n\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\n# remember the general rule about taller vs wider... 
here we have a taller matrix:\nprint(weight.shape)\n\nusing_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\nusing_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\njacfwd_timing = using_fwd.timeit(500)\njacrev_timing = using_bwd.timeit(500)\n\nprint(f'jacfwd time: {jacfwd_timing}')\nprint(f'jacrev time: {jacrev_timing}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and then do a relative benchmark:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "get_perf(jacrev_timing, \"jacrev\", jacfwd_timing, \"jacfwd\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and now the reverse - more inputs (N) than outputs (M):\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "Din = 2048\nDout = 32\nweight = torch.randn(Dout, Din)\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\nusing_fwd = Timer(stmt=\"jacfwd(predict, argnums=2)(weight, bias, x)\", globals=globals())\nusing_bwd = Timer(stmt=\"jacrev(predict, argnums=2)(weight, bias, x)\", globals=globals())\n\njacfwd_timing = using_fwd.timeit(500)\njacrev_timing = using_bwd.timeit(500)\n\nprint(f'jacfwd time: {jacfwd_timing}')\nprint(f'jacrev time: {jacrev_timing}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and a relative performance comparison:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "get_perf(jacfwd_timing, \"jacfwd\", jacrev_timing, \"jacrev\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hessian computation with torch.func.hessian\n===========================================\n\nWe offer a convenience API to compute hessians: `torch.func.hessian`.\nHessians are the jacobian of the jacobian (or the partial derivative of\nthe partial derivative, aka second 
order).\n\nThis suggests that one can just compose `torch.func` jacobian transforms to\ncompute the Hessian. Indeed, under the hood, `hessian(f)` is simply\n`jacfwd(jacrev(f))`.\n\nNote: depending on your model, you may be able to boost performance by\nusing `jacfwd(jacfwd(f))` or `jacrev(jacrev(f))` instead to compute\nhessians, leveraging the rule of thumb above regarding wider vs taller\nmatrices.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import hessian\n\n# let's reduce the size in order not to overwhelm Colab. Hessians require\n# significant memory:\nDin = 512\nDout = 32\nweight = torch.randn(Dout, Din)\nbias = torch.randn(Dout)\nx = torch.randn(Din)\n\nhess_api = hessian(predict, argnums=2)(weight, bias, x)\nhess_fwdfwd = jacfwd(jacfwd(predict, argnums=2), argnums=2)(weight, bias, x)\nhess_revrev = jacrev(jacrev(predict, argnums=2), argnums=2)(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s verify we get the same result whether we use the `hessian` API or\n`jacfwd(jacfwd())`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.allclose(hess_api, hess_fwdfwd)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Batch Jacobian and Batch Hessian\n================================\n\nIn the above examples we've been operating with a single feature vector.\nIn some cases you might want to take the Jacobian of a batch of outputs\nwith respect to a batch of inputs. 
That is, given a batch of inputs of\nshape `(B, N)` and a function that goes from $R^N \\to R^M$, we would\nlike a Jacobian of shape `(B, M, N)`.\n\nThe easiest way to do this is to use `vmap`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_size = 64\nDin = 31\nDout = 33\n\nweight = torch.randn(Dout, Din)\nprint(f\"weight shape = {weight.shape}\")\n\nbias = torch.randn(Dout)\n\nx = torch.randn(batch_size, Din)\n\ncompute_batch_jacobian = vmap(jacrev(predict, argnums=2), in_dims=(None, None, 0))\nbatch_jacobian0 = compute_batch_jacobian(weight, bias, x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you have a function that goes from (B, N) -\\> (B, M) instead and are\ncertain that each input produces an independent output, then it\\'s also\nsometimes possible to do this without using `vmap` by summing the\noutputs and then computing the Jacobian of that function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def predict_with_output_summed(weight, bias, x):\n return predict(weight, bias, x).sum(0)\n\nbatch_jacobian1 = jacrev(predict_with_output_summed, argnums=2)(weight, bias, x).movedim(1, 0)\nassert torch.allclose(batch_jacobian0, batch_jacobian1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you instead have a function that goes from $R^N \\to R^M$ but inputs\nthat are batched, you compose `vmap` with `jacrev` to compute batched\njacobians, as in the `compute_batch_jacobian` example above.\n\nFinally, batch hessians can be computed similarly. 
It\\'s easiest to\nthink about them by using `vmap` to batch over hessian computation, but\nin some cases the sum trick also works.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "compute_batch_hessian = vmap(hessian(predict, argnums=2), in_dims=(None, None, 0))\n\nbatch_hess = compute_batch_hessian(weight, bias, x)\nbatch_hess.shape" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Computing Hessian-vector products\n=================================\n\nThe naive way to compute a Hessian-vector product (hvp) is to\nmaterialize the full Hessian and multiply it by a vector. We\ncan do better: it turns out we don\\'t need to materialize the full\nHessian to do this. We\\'ll go through two (of many) different strategies\nto compute Hessian-vector products:\n\n- composing reverse-mode AD with reverse-mode AD\n- composing reverse-mode AD with forward-mode AD\n\nComposing reverse-mode AD with forward-mode AD (as opposed to\nreverse-mode with reverse-mode) is generally the more memory efficient\nway to compute an hvp because forward-mode AD doesn\\'t need to construct\nan Autograd graph and save intermediates for backward:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import jvp, grad, vjp\n\ndef hvp(f, primals, tangents):\n return jvp(grad(f), primals, tangents)[1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here\\'s some sample usage.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def f(x):\n return x.sin().sum()\n\nx = torch.randn(2048)\ntangent = torch.randn(2048)\n\nresult = hvp(f, (x,), (tangent,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If PyTorch forward-AD does not have coverage for your operations, then\nwe can instead compose reverse-mode AD with reverse-mode AD:\n" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def hvp_revrev(f, primals, tangents):\n _, vjp_fn = vjp(grad(f), *primals)\n return vjp_fn(*tangents)\n\nresult_hvp_revrev = hvp_revrev(f, (x,), (tangent,))\nassert torch.allclose(result, result_hvp_revrev[0])" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }