{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving and Loading Models\n=========================\n\n**Author:** [Matthew Inkawhich](https://github.com/MatthewInkawhich)\n\nThis document provides solutions to a variety of use cases regarding the\nsaving and loading of PyTorch models. Feel free to read the whole\ndocument, or just skip to the code you need for a desired use case.\n\nWhen it comes to saving and loading models, there are three core\nfunctions to be familiar with:\n\n1) [torch.save](https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save):\n Saves a serialized object to disk. This function uses Python's\n [pickle](https://docs.python.org/3/library/pickle.html) utility for\n serialization. Models, tensors, and dictionaries of all kinds of\n objects can be saved using this function.\n2) [torch.load](https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load):\n Uses [pickle](https://docs.python.org/3/library/pickle.html)'s\n unpickling facilities to deserialize pickled object files to memory.\n This function also lets you specify the device to load the data into\n (see [Saving & Loading Model Across\n Devices](#saving-loading-model-across-devices)).\n3) [torch.nn.Module.load\\_state\\_dict](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict):\n Loads a model's parameter dictionary using a deserialized\n *state\\_dict*. 
For more information on *state\\_dict*, see [What is a\n state\\_dict?](#what-is-a-state-dict).\n\n**Contents:**\n\n- [What is a state\\_dict?](#what-is-a-state-dict)\n- [Saving & Loading Model for\n Inference](#saving-loading-model-for-inference)\n- [Saving & Loading a General\n Checkpoint](#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training)\n- [Saving Multiple Models in One\n File](#saving-multiple-models-in-one-file)\n- [Warmstarting Model Using Parameters from a Different\n Model](#warmstarting-model-using-parameters-from-a-different-model)\n- [Saving & Loading Model Across\n Devices](#saving-loading-model-across-devices)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is a `state_dict`?\n=======================\n\nIn PyTorch, the learnable parameters (i.e.\u00a0weights and biases) of a\n`torch.nn.Module` model are contained in the model's *parameters*\n(accessed with `model.parameters()`). A *state\\_dict* is simply a Python\ndictionary object that maps each parameter and buffer name (for example\n`conv1.weight`) to its tensor. Note that only layers with learnable\nparameters (convolutional layers, linear layers, etc.) and registered\nbuffers (batchnorm\\'s running\\_mean) have entries in the model's\n*state\\_dict*. 
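For example, in the minimal sketch below (a small, hypothetical `nn.Sequential` model chosen only for illustration), the `BatchNorm1d` layer contributes its registered buffers (`running_mean`, `running_var` and `num_batches_tracked`) in addition to its weight and bias, while the parameter-free `ReLU` contributes no entries at all. Each key is the dotted name of the submodule (here its index in the `Sequential`) followed by the tensor name:

```python
import torch.nn as nn

# Hypothetical toy model: Linear has parameters, BatchNorm1d has parameters
# and buffers, ReLU has neither
model = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3), nn.ReLU())

sd = model.state_dict()
for name, tensor in sd.items():
    print(name, tuple(tensor.shape))

# Entries that come from registered buffers rather than learnable parameters
param_names = {name for name, _ in model.named_parameters()}
buffer_names = sorted(set(sd) - param_names)
print(buffer_names)  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
```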
Optimizer objects (`torch.optim`)\nalso have a *state\\_dict*, which contains information about the\noptimizer\\'s state, as well as the hyperparameters used.\n\nBecause *state\\_dict* objects are Python dictionaries, they can be\neasily saved, updated, altered, and restored, adding a great deal of\nmodularity to PyTorch models and optimizers.\n\nExample:\n--------\n\nLet's take a look at the *state\\_dict* from the simple model used in the\n[Training a\nclassifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py)\ntutorial.\n\n``` {.python}\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\n# Define model\nclass TheModelClass(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(F.relu(self.conv1(x)))\n        x = self.pool(F.relu(self.conv2(x)))\n        x = x.view(-1, 16 * 5 * 5)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n# Initialize model\nmodel = TheModelClass()\n\n# Initialize optimizer\noptimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n\n# Print model's state_dict\nprint(\"Model's state_dict:\")\nfor param_tensor in model.state_dict():\n    print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n\n# Print optimizer's state_dict\nprint(\"Optimizer's state_dict:\")\nfor var_name in optimizer.state_dict():\n    print(var_name, \"\\t\", optimizer.state_dict()[var_name])\n```\n\n**Output:**\n\n``` {.sh}\nModel's state_dict:\nconv1.weight torch.Size([6, 3, 5, 5])\nconv1.bias torch.Size([6])\nconv2.weight torch.Size([16, 6, 5, 5])\nconv2.bias torch.Size([16])\nfc1.weight torch.Size([120, 400])\nfc1.bias torch.Size([120])\nfc2.weight torch.Size([84, 120])\nfc2.bias torch.Size([84])\nfc3.weight torch.Size([10, 84])\nfc3.bias 
torch.Size([10])\n\nOptimizer's state_dict:\nstate {}\nparam_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving & Loading Model for Inference\n====================================\n\nSave/Load `state_dict` (Recommended)\n------------------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, weights_only=True))\nmodel.eval()\n```\n\n> **NOTE:** The 1.6 release of PyTorch switched `torch.save` to use a new\n> zip file-based format. `torch.load` still retains the ability to load\n> files in the old format. If for any reason you want `torch.save` to use\n> the old format, pass the keyword argument\n> `_use_new_zipfile_serialization=False`.\n\nWhen saving a model for inference, it is only necessary to save the\ntrained model's learned parameters. Saving the model's *state\\_dict*\nwith the `torch.save()` function will give you the most flexibility for\nrestoring the model later, which is why it is the recommended method for\nsaving models.\n\nA common PyTorch convention is to save models using either a `.pt` or\n`.pth` file extension.\n\nRemember that you must call `model.eval()` to set dropout and batch\nnormalization layers to evaluation mode before running inference.\nFailing to do this will yield inconsistent inference results.\n\n> **NOTE:** The `load_state_dict()` function takes a dictionary object,\n> NOT a path to a saved object. This means that you must deserialize the\n> saved *state\\_dict* (for example with `torch.load()`) before you pass\n> it to the `load_state_dict()` function. For example, you CANNOT load\n> using `model.load_state_dict(PATH)`.\n\n> **NOTE:** If you only plan to keep the best performing model (according\n> to the acquired validation loss), don't forget that\n> `best_model_state = model.state_dict()` returns a reference to the\n> state and not its copy! You must serialize `best_model_state` or use\n> `best_model_state = deepcopy(model.state_dict())` (with `deepcopy`\n> imported from Python's `copy` module); otherwise your\n> `best_model_state` will keep getting updated by the subsequent training\n> iterations. As a result, `best_model_state` will end up holding the\n> state of the final, possibly overfitted, model.\n\nSave/Load Entire Model\n----------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model, PATH)\n```\n\n**Load:**\n\n``` {.python}\n# Model class must be defined somewhere\nmodel = torch.load(PATH, weights_only=False)\nmodel.eval()\n```\n\nThis save/load process uses the most intuitive syntax and involves the\nleast amount of code. Saving a model in this way will save the entire\nmodule using Python's\n[pickle](https://docs.python.org/3/library/pickle.html) module. The\ndisadvantage of this approach is that the serialized data is bound to\nthe specific classes and the exact directory structure used when the\nmodel is saved. This is because pickle does not save the model class\nitself. Rather, it saves a reference to the module containing the class,\nwhich is used during load time. Because of this, your code can break in\nvarious ways when used in other projects or after refactors.\n\nA common PyTorch convention is to save models using either a `.pt` or\n`.pth` file extension.\n\nRemember that you must call `model.eval()` to set dropout and batch\nnormalization layers to evaluation mode before running inference.\nFailing to do this will yield inconsistent inference results.\n\nExport/Load Model in TorchScript Format\n---------------------------------------\n\nOne common way to do inference with a trained model is to use\n[TorchScript](https://pytorch.org/docs/stable/jit.html), an intermediate\nrepresentation of a PyTorch model that can be run in Python as well as\nin a high-performance environment like C++. TorchScript is actually the\nrecommended model format for scaled inference and deployment.\n\n> **NOTE:** Using the TorchScript format, you will be able to load the\n> exported model and run inference without defining the model class.\n\n**Export:**\n\n``` {.python}\nmodel_scripted = torch.jit.script(model) # Export to TorchScript\nmodel_scripted.save('model_scripted.pt') # Save\n```\n\n**Load:**\n\n``` {.python}\nmodel = torch.jit.load('model_scripted.pt')\nmodel.eval()\n```\n\nRemember that you must call `model.eval()` to set dropout and batch\nnormalization layers to evaluation mode before running inference.\nFailing to do this will yield inconsistent inference results.\n\nFor more information on TorchScript, feel free to visit the dedicated\n[tutorials](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html).\nYou will get familiar with the tracing conversion and learn how to run a\nTorchScript module in a [C++\nenvironment](https://pytorch.org/tutorials/advanced/cpp_export.html).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving & Loading a General Checkpoint for Inference and/or Resuming Training\n============================================================================\n\nSave:\n-----\n\n``` {.python}\ntorch.save({\n    'epoch': epoch,\n    'model_state_dict': model.state_dict(),\n    'optimizer_state_dict': optimizer.state_dict(),\n    'loss': loss,\n    # ... any other items you want to save\n}, PATH)\n```\n\nLoad:\n-----\n\n``` {.python}\nmodel = TheModelClass(*args, **kwargs)\noptimizer = TheOptimizerClass(*args, **kwargs)\n\ncheckpoint = torch.load(PATH, weights_only=True)\nmodel.load_state_dict(checkpoint['model_state_dict'])\noptimizer.load_state_dict(checkpoint['optimizer_state_dict'])\nepoch = checkpoint['epoch']\nloss = checkpoint['loss']\n\nmodel.eval()\n# - or -\nmodel.train()\n```\n\nWhen saving a general checkpoint, to be used for either inference or\nresuming training, you must save more than just the model's\n*state\\_dict*. It is important to also save the optimizer\\'s\n*state\\_dict*, as this contains buffers and parameters that are updated\nas the model trains. 
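For instance, SGD with momentum keeps a per-parameter `momentum_buffer` that is created on the first `step()`; if you restore only the model, a freshly constructed optimizer starts again with empty state. The minimal sketch below (a toy `nn.Linear` model, random data, and an in-memory `io.BytesIO` buffer standing in for `PATH` are all illustrative assumptions) saves such a checkpoint after one training step and resumes from it:

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Toy model, optimizer and data (illustrative assumptions, not from the tutorial)
model = nn.Linear(3, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
x, y = torch.randn(8, 3), torch.randn(8, 1)

# One training step so that SGD creates its per-parameter momentum buffers
loss = nn.functional.mse_loss(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Save a general checkpoint (an in-memory buffer stands in for PATH)
buffer = io.BytesIO()
torch.save({
    'epoch': 1,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),
}, buffer)
buffer.seek(0)

# Resume: construct fresh objects first, then restore their saved state
new_model = nn.Linear(3, 1)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(buffer, weights_only=True)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
new_model.train()  # training mode, since we are resuming training

# The momentum buffer survived the round trip
old_buf = optimizer.state[model.weight]['momentum_buffer']
new_buf = new_optimizer.state[new_model.weight]['momentum_buffer']
print(torch.equal(old_buf, new_buf))  # True
```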
Other items that you may want to save are the epoch\nyou left off on, the latest recorded training loss, external\n`torch.nn.Embedding` layers, etc. As a result, such a checkpoint is\noften two to three times larger than the model alone.\n\nTo save multiple components, organize them in a dictionary and use\n`torch.save()` to serialize the dictionary. A common PyTorch convention\nis to save these checkpoints using the `.tar` file extension.\n\nTo load the items, first initialize the model and optimizer, then load\nthe dictionary locally using `torch.load()`. From here, you can easily\naccess the saved items by simply querying the dictionary as you would\nexpect.\n\nRemember that you must call `model.eval()` to set dropout and batch\nnormalization layers to evaluation mode before running inference.\nFailing to do this will yield inconsistent inference results. If you\nwish to resume training, call `model.train()` to ensure these layers\nare in training mode.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving Multiple Models in One File\n==================================\n\nSave:\n-----\n\n``` {.python}\ntorch.save({\n    'modelA_state_dict': modelA.state_dict(),\n    'modelB_state_dict': modelB.state_dict(),\n    'optimizerA_state_dict': optimizerA.state_dict(),\n    'optimizerB_state_dict': optimizerB.state_dict(),\n    # ... any other items you want to save\n}, PATH)\n```\n\nLoad:\n-----\n\n``` {.python}\nmodelA = TheModelAClass(*args, **kwargs)\nmodelB = TheModelBClass(*args, **kwargs)\noptimizerA = TheOptimizerAClass(*args, **kwargs)\noptimizerB = TheOptimizerBClass(*args, **kwargs)\n\ncheckpoint = torch.load(PATH, weights_only=True)\nmodelA.load_state_dict(checkpoint['modelA_state_dict'])\nmodelB.load_state_dict(checkpoint['modelB_state_dict'])\noptimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])\noptimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])\n\nmodelA.eval()\nmodelB.eval()\n# - or -\nmodelA.train()\nmodelB.train()\n```\n\nWhen saving a model composed of multiple 
`torch.nn.Module` instances, such as a\nGAN, a sequence-to-sequence model, or an ensemble of models, you follow\nthe same approach as when you are saving a general checkpoint. In other\nwords, save a dictionary of each model's *state\\_dict* and corresponding\noptimizer. As mentioned before, you can save any other items that may\naid you in resuming training by simply appending them to the dictionary.\n\nA common PyTorch convention is to save these checkpoints using the\n`.tar` file extension.\n\nTo load the models, first initialize the models and optimizers, then\nload the dictionary locally using `torch.load()`. From here, you can\neasily access the saved items by simply querying the dictionary as you\nwould expect.\n\nRemember that you must call `model.eval()` to set dropout and batch\nnormalization layers to evaluation mode before running inference.\nFailing to do this will yield inconsistent inference results. If you\nwish to resume training, call `model.train()` to set these layers to\ntraining mode.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Warmstarting Model Using Parameters from a Different Model\n==========================================================\n\nSave:\n-----\n\n``` {.python}\ntorch.save(modelA.state_dict(), PATH)\n```\n\nLoad:\n-----\n\n``` {.python}\nmodelB = TheModelBClass(*args, **kwargs)\nmodelB.load_state_dict(torch.load(PATH, weights_only=True), strict=False)\n```\n\nPartially loading a model or loading a partial model are common\nscenarios when doing transfer learning or training a new complex model.\nLeveraging trained parameters, even if only a few are usable, will help\nto warmstart the training process and hopefully help your model converge\nmuch faster than training from scratch.\n\nWhether you are loading from a partial *state\\_dict*, which is missing\nsome keys, or loading a *state\\_dict* with more keys than the model that\nyou are loading into, you can set the `strict` argument to **False** in\nthe `load_state_dict()` 
function to ignore non-matching keys. Note that `strict=False` only\nignores missing and unexpected keys: a key that is present in both but\nwith a different tensor shape still raises an error.\n\nIf you want to load parameters from one layer to another, but some keys\ndo not match, simply change the name of the parameter keys in the\n*state\\_dict* that you are loading to match the keys in the model that\nyou are loading into.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving & Loading Model Across Devices\n=====================================\n\nSave on GPU, Load on CPU\n------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\ndevice = torch.device('cpu')\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))\n```\n\nWhen loading a model on a CPU that was trained with a GPU, pass\n`torch.device('cpu')` to the `map_location` argument in the\n`torch.load()` function. In this case, the storages underlying the\ntensors are dynamically remapped to the CPU device using the\n`map_location` argument.\n\nSave on GPU, Load on GPU\n------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\ndevice = torch.device(\"cuda\")\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, weights_only=True))\nmodel.to(device)\n# Make sure to call input = input.to(device) on any input tensors that you feed to the model\n```\n\nWhen loading a model on a GPU that was trained and saved on GPU, simply\nmove the initialized `model` to the GPU using\n`model.to(torch.device('cuda'))`. Also, be sure to use the\n`.to(torch.device('cuda'))` function on all model inputs to prepare the\ndata for the model. Note that calling `my_tensor.to(device)` returns a\nnew copy of `my_tensor` on GPU. 
It does NOT overwrite `my_tensor`.\nTherefore, remember to manually overwrite tensors:\n`my_tensor = my_tensor.to(torch.device('cuda'))`.\n\nSave on CPU, Load on GPU\n------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\ndevice = torch.device(\"cuda\")\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, weights_only=True, map_location=\"cuda:0\")) # Choose whatever GPU device number you want\nmodel.to(device)\n# Make sure to call input = input.to(device) on any input tensors that you feed to the model\n```\n\nWhen loading a model on a GPU that was trained and saved on CPU, set the\n`map_location` argument in the `torch.load()` function to\n`cuda:device_id`. This loads the model to a given GPU device. Next, be\nsure to call `model.to(torch.device('cuda'))` to convert the model's\nparameter tensors to CUDA tensors. Finally, be sure to use the\n`.to(torch.device('cuda'))` function on all model inputs to prepare the\ndata for the model on the GPU. Note that calling\n`my_tensor.to(device)` returns a new copy of `my_tensor` on GPU. It does\nNOT overwrite `my_tensor`. Therefore, remember to manually overwrite\ntensors: `my_tensor = my_tensor.to(torch.device('cuda'))`.\n\nSaving `torch.nn.DataParallel` Models\n-------------------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.module.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, map_location=device, weights_only=True))  # Load to whatever device you want\n```\n\n`torch.nn.DataParallel` is a model wrapper that enables parallel GPU\nutilization. To save a `DataParallel` model generically, save the\n`model.module.state_dict()`. 
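A minimal sketch of why `model.module` matters (the toy `nn.Linear` model and the in-memory `io.BytesIO` buffer standing in for `PATH` are illustrative assumptions): the wrapper registers the original model as its submodule `module`, so every key in the wrapper's own `state_dict` gains a `module.` prefix, whereas the keys of `model.module.state_dict()` match a plain, unwrapped model directly. On a CPU-only machine the wrapper does not replicate anything, so the sketch should run there too.

```python
import io

import torch
import torch.nn as nn

# Toy model standing in for TheModelClass (illustrative assumption)
model = nn.Linear(4, 2)
dp_model = nn.DataParallel(model)

# The wrapper registers the original model as the submodule 'module',
# so every key in the wrapper's own state_dict gains a 'module.' prefix
wrapped_keys = list(dp_model.state_dict().keys())
plain_keys = list(dp_model.module.state_dict().keys())
print(wrapped_keys)  # ['module.weight', 'module.bias']
print(plain_keys)    # ['weight', 'bias']

# Save the unwrapped module's state_dict (an in-memory buffer stands in for PATH)
buffer = io.BytesIO()
torch.save(dp_model.module.state_dict(), buffer)
buffer.seek(0)

# Load into a plain, unwrapped model on the CPU; no key renaming is needed
restored = nn.Linear(4, 2)
restored.load_state_dict(torch.load(buffer, map_location='cpu', weights_only=True))
print(torch.equal(restored.weight, model.weight.cpu()))  # True
```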
Saving the unwrapped module's *state\\_dict* gives you the flexibility to load\nthe model any way you want, onto any device you want.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }