{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving and Loading Models\n=========================\n\n**Author:** [Matthew Inkawhich](https://github.com/MatthewInkawhich)\n\nThis document provides solutions to a variety of use cases regarding the\nsaving and loading of PyTorch models. Feel free to read the whole\ndocument, or just skip to the code you need for a desired use case.\n\nWhen it comes to saving and loading models, there are three core\nfunctions to be familiar with:\n\n1) [torch.save](https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save):\n Saves a serialized object to disk. This function uses Python's\n [pickle](https://docs.python.org/3/library/pickle.html) utility for\n serialization. Models, tensors, and dictionaries of all kinds of\n objects can be saved using this function.\n2) [torch.load](https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load):\n Uses [pickle](https://docs.python.org/3/library/pickle.html)'s\n unpickling facilities to deserialize pickled object files to memory.\n This function also facilitates the device to load the data into (see\n [Saving & Loading Model Across\n Devices](#saving-loading-model-across-devices)).\n3) [torch.nn.Module.load\\_state\\_dict](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict):\n Loads a model's parameter dictionary using a deserialized\n *state\\_dict*. For more information on *state\\_dict*, see [What is a\n state\\_dict?](#what-is-a-state-dict).\n\n**Contents:**\n\n- [What is a state\\_dict?](#what-is-a-state-dict)\n- [Saving & Loading Model for\n Inference](#saving-loading-model-for-inference)\n- [Saving & Loading a General\n Checkpoint](#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training)\n- [Saving Multiple Models in One\n File](#saving-multiple-models-in-one-file)\n- [Warmstarting Model Using Parameters from a Different\n Model](#warmstarting-model-using-parameters-from-a-different-model)\n- [Saving & Loading Model Across\n Devices](#saving-loading-model-across-devices)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What is a `state_dict`?\n=======================\n\nIn PyTorch, the learnable parameters (i.e.\u00a0weights and biases) of an\n`torch.nn.Module` model are contained in the model's *parameters*\n(accessed with `model.parameters()`). A *state\\_dict* is simply a Python\ndictionary object that maps each layer to its parameter tensor. Note\nthat only layers with learnable parameters (convolutional layers, linear\nlayers, etc.) and registered buffers (batchnorm\\'s running\\_mean) have\nentries in the model's *state\\_dict*. 
Optimizer objects (`torch.optim`)\nalso have a *state\\_dict*, which contains information about the\noptimizer\\'s state, as well as the hyperparameters used.\n\nBecause *state\\_dict* objects are Python dictionaries, they can be\neasily saved, updated, altered, and restored, adding a great deal of\nmodularity to PyTorch models and optimizers.\n\nExample:\n--------\n\nLet's take a look at the *state\\_dict* from the simple model used in the\n[Training a\nclassifier](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py)\ntutorial.\n\n``` {.python}\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\n\n# Define model\nclass TheModelClass(nn.Module):\n    def __init__(self):\n        super(TheModelClass, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(F.relu(self.conv1(x)))\n        x = self.pool(F.relu(self.conv2(x)))\n        x = x.view(-1, 16 * 5 * 5)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n# Initialize model\nmodel = TheModelClass()\n\n# Initialize optimizer\noptimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)\n\n# Print model's state_dict\nprint(\"Model's state_dict:\")\nfor param_tensor in model.state_dict():\n    print(param_tensor, \"\\t\", model.state_dict()[param_tensor].size())\n\n# Print optimizer's state_dict\nprint(\"Optimizer's state_dict:\")\nfor var_name in optimizer.state_dict():\n    print(var_name, \"\\t\", optimizer.state_dict()[var_name])\n```\n\n**Output:**\n\n``` {.sh}\nModel's state_dict:\nconv1.weight torch.Size([6, 3, 5, 5])\nconv1.bias torch.Size([6])\nconv2.weight torch.Size([16, 6, 5, 5])\nconv2.bias torch.Size([16])\nfc1.weight torch.Size([120, 400])\nfc1.bias torch.Size([120])\nfc2.weight torch.Size([84, 120])\nfc2.bias torch.Size([84])\nfc3.weight torch.Size([10, 84])\nfc3.bias torch.Size([10])\n\nOptimizer's state_dict:\nstate {}\nparam_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving & Loading Model for Inference\n====================================\n\nSave/Load `state_dict` (Recommended)\n------------------------------------\n\n**Save:**\n\n``` {.python}\ntorch.save(model.state_dict(), PATH)\n```\n\n**Load:**\n\n``` {.python}\nmodel = TheModelClass(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH, weights_only=True))\nmodel.eval()\n```\n\n```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p>The 1.6 release of PyTorch switched <code>torch.save</code> to use a new zip file-based format. <code>torch.load</code> still retains the ability to load files in the old format. If for any reason you want <code>torch.save</code> to use the old format, pass the kwarg parameter <code>_use_new_zipfile_serialization=False</code>.</p></div>\n```\n
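\nFor example, a minimal sketch of forcing the legacy format (assuming the same `model` and `PATH` as above):\n\n``` {.python}\n# Save in the pre-1.6, non-zip serialization format\ntorch.save(model.state_dict(), PATH, _use_new_zipfile_serialization=False)\n```\n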
\n```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>Notice that the <code>load_state_dict()</code> function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved <em>state_dict</em> before you pass it to the <code>load_state_dict()</code> function. For example, you CANNOT load using <code>model.load_state_dict(PATH)</code>.</p></div>\n```\n
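\nAs a short sketch (assuming `PATH` points at a file written by `torch.save(model.state_dict(), PATH)`):\n\n``` {.python}\n# WRONG: load_state_dict() expects a dict, not a file path\n# model.load_state_dict(PATH)\n\n# RIGHT: deserialize the state_dict first, then load it\nstate_dict = torch.load(PATH, weights_only=True)\nmodel.load_state_dict(state_dict)\n```\n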
\n```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>If you only plan to keep the best performing model (according to the acquired validation loss), don't forget that <code>best_model_state = model.state_dict()</code> returns a reference to the state and not its copy! You must serialize <code>best_model_state</code> or use <code>best_model_state = deepcopy(model.state_dict())</code>; otherwise your best <code>best_model_state</code> will keep getting updated by the subsequent training iterations. As a result, the final model state will be the state of the overfitted model.</p></div>\n```\n
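\nA minimal sketch of tracking the best model this way (`num_epochs`, `train_one_epoch()`, and `evaluate()` are hypothetical placeholders, not part of this tutorial):\n\n``` {.python}\nfrom copy import deepcopy\n\nbest_loss = float('inf')\nbest_model_state = None\nfor epoch in range(num_epochs):\n    train_one_epoch(model, optimizer)  # hypothetical training step\n    val_loss = evaluate(model)         # hypothetical validation pass\n    if val_loss < best_loss:\n        best_loss = val_loss\n        # deepcopy detaches the snapshot from future training updates\n        best_model_state = deepcopy(model.state_dict())\n\ntorch.save(best_model_state, 'best_model.pt')\n```\n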
\n```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>Using the TorchScript format, you will be able to load the exported model and run inference without defining the model class.</p></div>\n```\n
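\nFor instance, a sketch using the standard TorchScript entry points (`torch.jit.script`, `torch.jit.load`); the file name is illustrative:\n\n``` {.python}\n# Export: compile the model to TorchScript and save a self-contained archive\nmodel_scripted = torch.jit.script(model)\nmodel_scripted.save('model_scripted.pt')\n\n# Load: no TheModelClass definition is needed at load time\nmodel = torch.jit.load('model_scripted.pt')\nmodel.eval()\n```\n\n```{=html}\n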