{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "torch.vmap\n==========\n\nThis tutorial introduces torch.vmap, an autovectorizer for PyTorch\noperations. torch.vmap is a prototype feature and cannot handle a number\nof use cases; however, we would like to gather use cases for it to\ninform the design. If you are considering using torch.vmap or think it\nwould be really cool for something, please contact us at\n.\n\nSo, what is vmap?\n-----------------\n\nvmap is a higher-order function. It accepts a function\n[func]{.title-ref} and returns a new function that maps\n[func]{.title-ref} over some dimension of the inputs. It is highly\ninspired by JAX\\'s vmap.\n\nSemantically, vmap pushes the \\\"map\\\" into PyTorch operations called by\n[func]{.title-ref}, effectively vectorizing those operations.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\n# NB: vmap is only available on nightly builds of PyTorch.\n# You can download one at pytorch.org if you're interested in testing it out.\nfrom torch import vmap" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first use case for vmap is making it easier to handle batch\ndimensions in your code. One can write a function [func]{.title-ref}\nthat runs on examples and then lift it to a function that can take\nbatches of examples with [vmap(func)]{.title-ref}. [func]{.title-ref}\nhowever is subject to many restrictions:\n\n- it must be functional (one cannot mutate a Python data structure\n inside of it), with the exception of in-place PyTorch operations.\n- batches of examples must be provided as Tensors. This means that\n vmap doesn\\'t handle variable-length sequences out of the box.\n\nOne example of using [vmap]{.title-ref} is to compute batched dot\nproducts. PyTorch doesn\\'t provide a batched [torch.dot]{.title-ref}\nAPI; instead of unsuccessfully rummaging through docs, use\n[vmap]{.title-ref} to construct a new function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.dot # [D], [D] -> []\nbatched_dot = torch.vmap(torch.dot) # [N, D], [N, D] -> [N]\nx, y = torch.randn(2, 5), torch.randn(2, 5)\nbatched_dot(x, y)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[vmap]{.title-ref} can be helpful in hiding batch dimensions, leading to\na simpler model authoring experience.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_size, feature_size = 3, 5\nweights = torch.randn(feature_size, requires_grad=True)\n\n# Note that model doesn't work with a batch of feature vectors because\n# torch.dot must take 1D tensors. It's pretty easy to rewrite this\n# to use `torch.matmul` instead, but if we didn't want to do that or if\n# the code is more complicated (e.g., does some advanced indexing\n# shenanigins), we can simply call `vmap`. 
{ "cell_type": "markdown", "metadata": {}, "source": [ "[vmap]{.title-ref} can also help vectorize computations that were\npreviously difficult or impossible to batch. This brings us to our second\nuse case: batched gradient computation.\n\nThe PyTorch autograd engine computes vjps (vector-Jacobian products).\nUsing vmap, we can compute (batched vector)-Jacobian products.\n\nOne example of this is computing a full Jacobian matrix (this can also\nbe applied to computing a full Hessian matrix). Computing a full\nJacobian matrix for some function f: R\\^N -\\> R\\^N usually requires N\ncalls to [autograd.grad]{.title-ref}, one per Jacobian row.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Setup\nN = 5\ndef f(x):\n    return x ** 2\n\nx = torch.randn(N, requires_grad=True)\ny = f(x)\nbasis_vectors = torch.eye(N)\n\n# Sequential approach: one call to autograd.grad per Jacobian row\njacobian_rows = [torch.autograd.grad(y, x, v, retain_graph=True)[0]\n                 for v in basis_vectors.unbind()]\njacobian = torch.stack(jacobian_rows)\n\n# Using `vmap`, we can vectorize the whole computation, computing the\n# Jacobian in a single call to `autograd.grad`.\ndef get_vjp(v):\n    return torch.autograd.grad(y, x, v)[0]\n\njacobian_vmap = vmap(get_vjp)(basis_vectors)\nassert torch.allclose(jacobian_vmap, jacobian)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The third main use case for vmap is computing per-sample-gradients. This\nis something that the vmap prototype cannot handle performantly right\nnow. We\\'re not sure what the API for computing per-sample-gradients\nshould be, but if you have ideas, please let us know.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# A single weight vector; the feature size matches the commented-out\n# batch of samples below.\nweight = torch.randn(5)\n\ndef model(sample, weight):\n    # do something...\n    return torch.dot(sample, weight)\n\ndef grad_sample(sample):\n    # Gradient of the model output with respect to `weight` for one sample.\n    return torch.autograd.functional.vjp(lambda weight: model(sample, weight), weight)[1]\n\n# The following doesn't actually work in the vmap prototype. But it\n# could be an API for computing per-sample-gradients.\n\n# batch_of_samples = torch.randn(64, 5)\n# vmap(grad_sample)(batch_of_samples)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }