{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model ensembling\n================\n\nThis tutorial illustrates how to vectorize model ensembling using\n`torch.vmap`.\n\nWhat is model ensembling?\n-------------------------\n\nModel ensembling combines the predictions from multiple models together.\nTraditionally this is done by running each model on some inputs\nseparately and then combining the predictions. However, if you\\'re\nrunning models with the same architecture, then it may be possible to\ncombine them together using `torch.vmap`. `vmap` is a function transform\nthat maps functions across dimensions of the input tensors. One of its\nuse cases is eliminating for-loops and speeding them up through\nvectorization.\n\nLet\\'s demonstrate how to do this using an ensemble of simple MLPs.\n\n```{=html}\n
NOTE: This tutorial requires PyTorch 2.0.0 or later.
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\ntorch.manual_seed(0)\n\n# Here's a simple MLP\nclass SimpleMLP(nn.Module):\n    def __init__(self):\n        super(SimpleMLP, self).__init__()\n        self.fc1 = nn.Linear(784, 128)\n        self.fc2 = nn.Linear(128, 128)\n        self.fc3 = nn.Linear(128, 10)\n\n    def forward(self, x):\n        # Flatten each image to a 784-element vector, keeping the batch dimension\n        x = x.flatten(1)\n        x = self.fc1(x)\n        x = F.relu(x)\n        x = self.fc2(x)\n        x = F.relu(x)\n        x = self.fc3(x)\n        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's generate a batch of dummy data and pretend that we're working with\nthe MNIST dataset. The dummy images are therefore 28 by 28, and each\nminibatch has size 64. Furthermore, let's say we want to combine the\npredictions from 10 different models.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "device = 'cuda'\nnum_models = 10\n\n# 100 minibatches, each holding 64 single-channel 28x28 images\ndata = torch.randn(100, 64, 1, 28, 28, device=device)\ntargets = torch.randint(10, (6400,), device=device)\n\n# Each model is initialized independently\nmodels = [SimpleMLP().to(device) for _ in range(num_models)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have a couple of options for generating predictions. Maybe we want to\ngive each model a different randomized minibatch of data. 
Alternatively,\nmaybe we want to run the same minibatch of data through each model (e.g.\nif we were testing the effect of different model initializations).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Option 1: a different minibatch for each model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "minibatches = data[:num_models]\npredictions_diff_minibatch_loop = [model(minibatch) for model, minibatch in zip(models, minibatches)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Option 2: the same minibatch for every model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "minibatch = data[0]\npredictions2 = [model(minibatch) for model in models]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using `vmap` to vectorize the ensemble\n======================================\n\nLet\\'s use `vmap` to speed up the for-loop. We must first prepare the\nmodels for use with `vmap`.\n\nFirst, let's combine the states of the models by stacking each\nparameter. For example, `models[i].fc1.weight` has shape `[128, 784]`\n(`nn.Linear` stores its weight as `[out_features, in_features]`); we\nare going to stack the `.fc1.weight` of each of the 10 models to produce\na big weight of shape `[10, 128, 784]`.\n\nPyTorch offers the `torch.func.stack_module_state` convenience function\nto do this.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import stack_module_state\n\nparams, buffers = stack_module_state(models)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we need to define a function to `vmap` over. Given parameters,\nbuffers, and inputs, the function should run the model using those\nparameters, buffers, and inputs. 
We\\'ll use `torch.func.functional_call`\nto help out:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.func import functional_call\nimport copy\n\n# Construct a \"stateless\" version of one of the models. It is \"stateless\" in\n# the sense that the parameters are meta Tensors and do not have storage.\nbase_model = copy.deepcopy(models[0])\nbase_model = base_model.to('meta')\n\ndef fmodel(params, buffers, x):\n    # Run base_model's forward pass with the supplied params and buffers\n    return functional_call(base_model, (params, buffers), (x,))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Option 1: get predictions using a different minibatch for each model.\n\nBy default, `vmap` maps a function across the first dimension of all\ninputs to the passed-in function. After using `stack_module_state`, each\nof the `params` and `buffers` has an additional dimension of size\n`num_models` at the front, and `minibatches` also has a leading\ndimension of size `num_models`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print([p.size(0) for p in params.values()]) # show the leading 'num_models' dimension\n\nassert minibatches.shape == (num_models, 64, 1, 28, 28) # verify minibatches has a leading dimension of size 'num_models'\n\nfrom torch import vmap\n\npredictions1_vmap = vmap(fmodel)(params, buffers, minibatches)\n\n# verify the ``vmap`` predictions match the for-loop predictions\nassert torch.allclose(predictions1_vmap, torch.stack(predictions_diff_minibatch_loop), atol=1e-3, rtol=1e-5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Option 2: get predictions using the same minibatch of data.\n\n`vmap` has an `in_dims` argument that specifies which dimensions to map\nover. 
By passing `None` for the input's entry in `in_dims`, we tell `vmap`\nto apply the same minibatch to all 10 models.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "predictions2_vmap = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)\n\nassert torch.allclose(predictions2_vmap, torch.stack(predictions2), atol=1e-3, rtol=1e-5)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A quick note: there are limitations around what types of functions can\nbe transformed by `vmap`. The best functions to transform are pure\nfunctions: functions whose outputs are determined only by their inputs\nand that have no side effects (e.g. mutation). `vmap` is unable\nto handle mutation of arbitrary Python data structures, but it is able\nto handle many in-place PyTorch operations.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Performance\n===========\n\nCurious about performance numbers? Here\\'s how the numbers look.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.utils.benchmark import Timer\nwithout_vmap = Timer(\n    stmt=\"[model(minibatch) for model, minibatch in zip(models, minibatches)]\",\n    globals=globals())\nwith_vmap = Timer(\n    stmt=\"vmap(fmodel)(params, buffers, minibatches)\",\n    globals=globals())\nprint(f'Predictions without vmap {without_vmap.timeit(100)}')\nprint(f'Predictions with vmap {with_vmap.timeit(100)}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There\\'s a large speedup using `vmap`!\n\nIn general, vectorization with `vmap` should be faster than running a\nfunction in a for-loop and competitive with manual batching. There are\nsome exceptions though, such as when we haven't implemented the `vmap`\nrule for a particular operation or when the underlying kernels weren't\noptimized for older hardware (GPUs). 
If you see any of these cases,\nplease let us know by opening an issue on GitHub.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }