{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Per-sample-gradients\n====================\n\nWhat is it?\n-----------\n\nPer-sample-gradient computation is computing the gradient for each and\nevery sample in a batch of data. It is a useful quantity in differential\nprivacy, meta-learning, and optimization research.\n\n```{=html}\n
NOTE: This tutorial requires PyTorch 2.0.0 or later.
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\ntorch.manual_seed(0)\n\n# Here's a simple CNN and loss function:\n\nclass SimpleCNN(nn.Module):\n    def __init__(self):\n        super(SimpleCNN, self).__init__()\n        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n        self.fc1 = nn.Linear(9216, 128)\n        self.fc2 = nn.Linear(128, 10)\n\n    def forward(self, x):\n        x = self.conv1(x)\n        x = F.relu(x)\n        x = self.conv2(x)\n        x = F.relu(x)\n        x = F.max_pool2d(x, 2)\n        x = torch.flatten(x, 1)\n        x = self.fc1(x)\n        x = F.relu(x)\n        x = self.fc2(x)\n        output = F.log_softmax(x, dim=1)\n        return output\n\ndef loss_fn(predictions, targets):\n    return F.nll_loss(predictions, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's generate a batch of dummy data and pretend that we're working with\nan MNIST dataset. The dummy images are 28 by 28 and we use a minibatch\nof size 64.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# use the GPU when one is available, otherwise fall back to the CPU\ndevice = 'cuda' if torch.cuda.is_available() else 'cpu'\n\nnum_models = 10\nbatch_size = 64\ndata = torch.randn(batch_size, 1, 28, 28, device=device)\n\n# one random class label in [0, 10) per sample\ntargets = torch.randint(10, (batch_size,), device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In regular model training, one would forward the minibatch through the\nmodel, and then call `.backward()` to compute gradients. 
This would\ngenerate an \\'average\\' gradient of the entire mini-batch:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = SimpleCNN().to(device=device)\npredictions = model(data)  # move the entire mini-batch through the model\n\nloss = loss_fn(predictions, targets)\nloss.backward()  # back propagate the 'average' gradient of this mini-batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In contrast to the above approach, per-sample-gradient computation is\nequivalent to:\n\n- for each individual sample of the data, perform a forward and a\n  backward pass to get an individual (per-sample) gradient.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_grad(sample, target):\n    sample = sample.unsqueeze(0)  # prepend batch dimension for processing\n    target = target.unsqueeze(0)\n\n    prediction = model(sample)\n    loss = loss_fn(prediction, target)\n\n    return torch.autograd.grad(loss, list(model.parameters()))\n\n\ndef compute_sample_grads(data, targets):\n    \"\"\" manually process each sample with per sample gradient \"\"\"\n    sample_grads = [compute_grad(data[i], targets[i]) for i in range(batch_size)]\n    sample_grads = zip(*sample_grads)\n    sample_grads = [torch.stack(shards) for shards in sample_grads]\n    return sample_grads\n\nper_sample_grads = compute_sample_grads(data, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`per_sample_grads[0]` is the per-sample-grad for `model.conv1.weight`.\n`model.conv1.weight.shape` is `[32, 1, 3, 3]`; notice how there is one\ngradient per sample in the batch, for a total of 64.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(per_sample_grads[0].shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Per-sample-grads, *the efficient way*, using function 
transforms\n================================================================\n\nWe can compute per-sample-gradients efficiently by using function\ntransforms.\n\nThe `torch.func` function transform API transforms over functions. Our\nstrategy is to define a function that computes the loss and then apply\ntransforms to construct a function that computes per-sample-gradients.\n\nWe\\'ll use the `torch.func.functional_call` function to treat an\n`nn.Module` like a function.\n\nFirst, let's extract the state from `model` into two dictionaries,\nparameters and buffers. We\\'ll be detaching them because we won\\'t use\nregular PyTorch autograd (e.g. `Tensor.backward()`, `torch.autograd.grad`).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import functional_call, vmap, grad\n\nparams = {k: v.detach() for k, v in model.named_parameters()}\nbuffers = {k: v.detach() for k, v in model.named_buffers()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let\\'s define a function to compute the loss of the model given a\nsingle input rather than a batch of inputs. It is important that this\nfunction accepts the parameters, the input, and the target, because we\nwill be transforming over them.\n\nNote - because the model was originally written to handle batches, we'll\nuse `torch.unsqueeze` to add a batch dimension.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def compute_loss(params, buffers, sample, target):\n    batch = sample.unsqueeze(0)\n    targets = target.unsqueeze(0)\n\n    predictions = functional_call(model, (params, buffers), (batch,))\n    loss = loss_fn(predictions, targets)\n    return loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, let's use the `grad` transform to create a new function that\ncomputes the gradient with respect to the first argument of\n`compute_loss` (i.e. 
the `params`).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ft_compute_grad = grad(compute_loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `ft_compute_grad` function computes the gradient for a single\n(sample, target) pair. We can use `vmap` to get it to compute the\ngradient over an entire batch of samples and targets. Note that we set\n`in_dims=(None, None, 0, 0)` because we wish to map `ft_compute_grad`\nover the 0th dimension of the data and targets, and use the same\n`params` and `buffers` for each.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 0, 0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Finally, let\\'s use our transformed function to compute\nper-sample-gradients:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ft_per_sample_grads = ft_compute_sample_grad(params, buffers, data, targets)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can double-check that the results using `grad` and `vmap` match the\nresults of hand processing each one individually:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for per_sample_grad, ft_per_sample_grad in zip(per_sample_grads, ft_per_sample_grads.values()):\n    assert torch.allclose(per_sample_grad, ft_per_sample_grad, atol=3e-3, rtol=1e-5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick note: there are limitations around what types of functions can\nbe transformed by `vmap`. The best functions to transform are ones that\nare pure functions: a function where the outputs are only determined by\nthe inputs, and that have no side effects (e.g. mutation). 
`vmap` is\nunable to handle mutation of arbitrary Python data structures, but it is\nable to handle many in-place PyTorch operations.\n\nPerformance comparison\n======================\n\nCurious about how the performance of `vmap` compares?\n\nCurrently the best results are obtained on newer GPUs such as the A100\n(Ampere) where we\\'ve seen up to 25x speedups on this example, but here\nare some results on our build machines:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_perf(first, first_descriptor, second, second_descriptor):\n    \"\"\"Takes torch.utils.benchmark measurements and reports the delta of second vs. first.\"\"\"\n    second_res = second.times[0]\n    first_res = first.times[0]\n\n    # relative difference between the two timings, as a positive fraction of the first\n    gain = abs(first_res - second_res) / first_res\n    final_gain = gain * 100\n\n    print(f\"Performance delta: {final_gain:.4f} percent improvement with {first_descriptor}\")\n\nfrom torch.utils.benchmark import Timer\n\nwithout_vmap = Timer(stmt=\"compute_sample_grads(data, targets)\", globals=globals())\nwith_vmap = Timer(stmt=\"ft_compute_sample_grad(params, buffers, data, targets)\", globals=globals())\nno_vmap_timing = without_vmap.timeit(100)\nwith_vmap_timing = with_vmap.timeit(100)\n\nprint(f'Per-sample-grads without vmap {no_vmap_timing}')\nprint(f'Per-sample-grads with vmap {with_vmap_timing}')\n\nget_perf(with_vmap_timing, \"vmap\", no_vmap_timing, \"no vmap\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are other optimized solutions to computing per-sample-gradients\nin PyTorch that also perform better than the naive method. But it's cool\nthat composing `vmap` and `grad` gives us a nice speedup.\n\nIn general, vectorization with `vmap` should be faster than running a\nfunction in a for-loop and competitive with manual batching. 
There are\nsome exceptions though, like if we haven't implemented the `vmap` rule\nfor a particular operation or if the underlying kernels weren't\noptimized for older hardware (GPUs). If you see any of these cases,\nplease let us know by opening an issue on GitHub.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }