{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "How to save memory by fusing the optimizer step into the backward pass\n======================================================================\n\nHello there! This tutorial aims to showcase one way of reducing the\nmemory footprint of a training loop by reducing the memory taken by the\n*gradients*. Say you have a model and you\'re interested in ways to\noptimize memory to avoid `Out of Memory` (OOM) errors or simply to squeeze\nmore out of your GPU. Well, you *might* be in luck (if gradients\ntake up a sizable portion of your memory and you do not need to do gradient\naccumulation). We will explore the following:\n\n1. What takes up memory during your training or finetuning loop,\n2. How to capture and visualize memory snapshots to determine the\n bottleneck,\n3. The new `Tensor.register_post_accumulate_grad_hook(hook)` API, and\n finally,\n4. How everything fits together in 10 lines to achieve memory savings.\n\nTo run this tutorial, you will need:\n\n- PyTorch 2.1.0 or newer with `torchvision`\n- 1 CUDA GPU if you\'d like to run the memory visualizations locally.\n Otherwise, this technique offers similar benefits on any device.\n\nLet us start by importing the required modules and models. We will use a\nvision transformer model from torchvision, but feel free to substitute\nwith your own model. 
We will also use `torch.optim.Adam` as our\noptimizer, but, again, feel free to substitute with your own optimizer.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torchvision import models\nfrom pickle import dump\n\nmodel = models.vit_l_16(weights='DEFAULT').cuda()\noptimizer = torch.optim.Adam(model.parameters())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now let\'s define our typical training loop. You should use real images\nwhen training, but for the purposes of this tutorial, we are passing in\nfake inputs and not worrying about loading any actual data.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "IMAGE_SIZE = 224\n\ndef train(model, optimizer):\n    # create our fake image input: tensor shape is batch_size, channels, height, width\n    fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()\n\n    # call our forward and backward\n    loss = model(fake_image)\n    loss.sum().backward()\n\n    # optimizer update\n    optimizer.step()\n    optimizer.zero_grad()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Memory usage during training\n============================\n\nWe are about to look at some memory snapshots, so we should be prepared\nto analyze them properly. Typically, training memory consists of:\n\n> - Model parameters (size P)\n> - Activations that are saved for the backward pass (size A)\n> - Gradients, which are the same size as the model parameters, so\n> size G = P.\n> - Optimizer state, which is proportional to the size of the\n> parameters. 
In this case, the state for Adam requires 2x the model\n> parameters, so size O = 2P.\n> - Intermediate tensors, which are allocated throughout the computation.\n> We will not worry about them for now as they are usually small and\n> ephemeral.\n\nCapturing and visualizing memory snapshots\n==========================================\n\nLet\'s capture a memory snapshot! As your code runs, consider what you\nmay expect the CUDA memory timeline to look like.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# tell CUDA to start recording memory allocations\ntorch.cuda.memory._record_memory_history(enabled='all')\n\n# train 3 steps\nfor _ in range(3):\n    train(model, optimizer)\n\n# save a snapshot of the memory allocations\ns = torch.cuda.memory._snapshot()\nwith open(\"snapshot.pickle\", \"wb\") as f:\n    dump(s, f)\n\n# tell CUDA to stop recording memory allocations now\ntorch.cuda.memory._record_memory_history(enabled=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now open up the snapshot in the CUDA Memory Visualizer at\n by dragging and dropping the\n`snapshot.pickle` file. Does the memory timeline match your\nexpectations?\n\n![](https://pytorch.org/tutorials/_static/img/optim_step_in_bwd/snapshot.jpg)\n\nThe model parameters have already been loaded in memory before the\ntraining step, so we see a chunk of memory devoted to the weights right\noff the bat. As we start our forward pass, memory is allocated gradually\nfor the activations, or the tensors we are saving to be able to compute\ngradients in the backward pass. Once we start the backward pass, the\nactivations are gradually freed while memory of the gradients starts\nbuilding up.\n\nLastly, as the optimizer kicks in, its state will be lazily initialized,\nso we should see the optimizer state memory gradually increase during\nthe optimizer step of the first training loop only. 
In future loops, the\noptimizer memory will remain and be updated in-place. The memory for the\ngradients is then freed accordingly at the end of every training loop\nwhen `zero_grad` is called.\n\nWhere is the memory bottleneck in this training loop? Or, in other\nwords, where is the peak memory?\n\nThe peak memory usage is during the optimizer step! Note the memory then\nconsists of \~1.2GB of parameters, \~1.2GB of gradients, and\n\~2.4GB=2\*1.2GB of the optimizer state as expected. The last \~1.2GB\ncomes from the Adam optimizer requiring memory for intermediates, for a\ntotal of \~6GB of peak memory. Technically, you can remove the need for the\nlast 1.2GB for optimizer intermediates if you set\n`Adam(model.parameters(), foreach=False)`, which would trade off runtime\nfor memory. If switching off the `foreach` runtime optimization gives\nyou sufficient memory savings, nice, but please read on if\nyou\'re curious how this tutorial can help you do better! With the\ntechnique we will soon introduce, we will reduce peak memory by removing\nthe need for the \~1.2GB of **gradients memory** as well as **optimizer\nintermediates memory**. Now, what would you expect the new peak memory\nto be? The answer will be revealed in the *next* snapshot.\n\nDISCLAIMER: This technique is **not** for all\n=============================================\n\nBefore we get too excited, we have to consider whether this technique is\napplicable for *your* use case. This is NOT a silver bullet!\nThe technique of fusing the optimizer step into the backward only\ntargets reducing *gradient* memory (and as a side effect also optimizer\nintermediates memory). Thus, the more sizable the memory taken up by the\ngradients, the more significant the memory reduction. 
In our example\nabove, the gradients eat up 20% of the memory pie, which is quite\nsizable!\n\nThis may not be the case for you. For example, if your weights are\nalready tiny (say, due to applying LoRA), then the gradients do not\ntake much space in your training loop and the wins are way less\nexciting. In that case, you should first try other techniques like\nactivation checkpointing, distributed training, quantization, or\nreducing the batch size. Then, when the gradients are part of the\nbottleneck again, come back to this tutorial!\n\nStill here? Cool, let\'s introduce our new\n`register_post_accumulate_grad_hook(hook)` API on Tensor.\n\n`Tensor.register_post_accumulate_grad_hook(hook)` API and our technique\n=======================================================================\n\nOur technique relies on not having to save the gradients during\n`backward()`. Instead, once a gradient has been accumulated, we will\nimmediately apply the optimizer to the corresponding parameter and drop\nthat gradient entirely! This removes the need for holding onto a big\nbuffer of gradients until the optimizer step.\n\nSo how can we unlock the behavior of applying the optimizer more\neagerly? In our 2.1 release, we\'ve added a new API\n`torch.Tensor.register_post_accumulate_grad_hook`\nthat allows us to add a hook onto a Tensor once its\n`.grad` field has been accumulated. We will encapsulate the optimizer\nstep into this hook. How?\n\nHow everything fits together in 10 lines\n========================================\n\nRemember our model and optimizer setup from the beginning? 
I\\'ll leave\nthem commented out below so we don\\'t spend resources rerunning the\ncode.\n\n``` {.python}\nmodel = models.vit_l_16(weights='DEFAULT').cuda()\noptimizer = torch.optim.Adam(model.parameters())\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Instead of having just *one* optimizer, we will have a ``dict`` of optimizers\n# for every parameter so we could reference them in our hook.\noptimizer_dict = {p: torch.optim.Adam([p], foreach=False) for p in model.parameters()}\n\n# Define our hook, which will call the optimizer ``step()`` and ``zero_grad()``\ndef optimizer_hook(parameter) -> None:\n optimizer_dict[parameter].step()\n optimizer_dict[parameter].zero_grad()\n\n# Register the hook onto every parameter\nfor p in model.parameters():\n p.register_post_accumulate_grad_hook(optimizer_hook)\n\n# Now remember our previous ``train()`` function? Since the optimizer has been\n# fused into the backward, we can remove the optimizer step and zero_grad calls.\ndef train(model):\n # create our fake image input: tensor shape is batch_size, channels, height, width\n fake_image = torch.rand(1, 3, IMAGE_SIZE, IMAGE_SIZE).cuda()\n\n # call our forward and backward\n loss = model.forward(fake_image)\n loss.sum().backward()\n\n # optimizer update --> no longer needed!\n # optimizer.step()\n # optimizer.zero_grad()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "That took about 10 lines of changes in our sample model, which is neat.\nHowever, for real models, it could be a fairly intrusive change to\nswitch out the optimizer for an optimizer dictionary, especially for\nthose who use `LRScheduler`s or manipulate optimizer configuration\nthroughout the training epochs. Working out this API with those changes\nwill be more involved and will likely require moving more configuration\ninto global state but should not be impossible. 
That said, a next step\nfor PyTorch is to make this API easier to adopt with LRSchedulers and\nother features you are already used to.\n\nBut let me get back to convincing you that this technique is worth it.\nWe will consult our friend, the memory snapshot.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# delete optimizer memory from before to get a clean slate for the next\n# memory snapshot\ndel optimizer\n\n# tell CUDA to start recording memory allocations\ntorch.cuda.memory._record_memory_history(enabled='all')\n\n# train 3 steps. note that we no longer pass the optimizer into train()\nfor _ in range(3):\n    train(model)\n\n# save a snapshot of the memory allocations\ns = torch.cuda.memory._snapshot()\nwith open(\"snapshot-opt-in-bwd.pickle\", \"wb\") as f:\n    dump(s, f)\n\n# tell CUDA to stop recording memory allocations now\ntorch.cuda.memory._record_memory_history(enabled=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Yes, take some time to drag your snapshot into the CUDA Memory\nVisualizer.\n\n![](https://pytorch.org/tutorials/_static/img/optim_step_in_bwd/snapshot_opt_in_bwd.jpg)\n\nSeveral major observations:\n\n1. There is no more optimizer step! Right... we fused that into the\n   backward.\n2. As a result, the backward drags longer and there are more random\n   allocations for intermediates. This is expected, as the\n   optimizer step requires intermediates.\n3. Most importantly, the peak memory is lower! It is now \~4GB\n   (which I hope maps closely to your earlier expectation).\n\nNote that there is no longer any big chunk of memory allocated for the\ngradients compared to before, accounting for \~1.2GB of memory savings.\nInstead, we\'ve freed each gradient very quickly after it has been\ncomputed by moving the optimizer step as far ahead as we can. Woohoo! 
By\nthe way, the other \~1.2GB of memory savings comes from breaking apart\nthe optimizer into per-parameter optimizers, so the intermediates have\nproportionally shrunk. This detail is *less important* than\nthe gradient memory savings, as you can get optimizer intermediates\nsavings from just setting `foreach=False` without this technique.\n\nYou may rightly be wondering: if we saved 2.4GB of memory, why is the\npeak memory NOT 6GB - 2.4GB = 3.6GB? Well, the peak has moved! The peak\nis now near the start of the backward step, when we still have\nactivations in memory, whereas before, the peak was during the optimizer\nstep when the activations had been freed. The \~0.4GB difference\nbetween \~4.0GB and \~3.6GB is thus due to the activation memory.\nOne can then imagine that this technique can be coupled with activation\ncheckpointing for more memory wins.\n\nConclusion\n==========\n\nIn this tutorial, we learned about the memory saving technique of fusing\nthe optimizer into the backward step through the new\n`Tensor.register_post_accumulate_grad_hook()` API and *when* to apply\nthis technique (when gradient memory is significant). Along the way, we\nalso learned about memory snapshots, which are generally useful in\nmemory optimization.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }