{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Learn the Basics](intro.html) \\|\\|\n[Quickstart](quickstart_tutorial.html) \\|\\|\n[Tensors](tensorqs_tutorial.html) \\|\\| [Datasets &\nDataLoaders](data_tutorial.html) \\|\\|\n[Transforms](transforms_tutorial.html) \\|\\| [Build\nModel](buildmodel_tutorial.html) \\|\\|\n[Autograd](autogradqs_tutorial.html) \\|\\|\n[Optimization](optimization_tutorial.html) \\|\\| **Save & Load Model**\n\nSave and Load the Model\n=======================\n\nIn this section we will look at how to persist model state by saving and\nloading it, and how to run model predictions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchvision.models as models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving and Loading Model Weights\n================================\n\nPyTorch models store the learned parameters in an internal state\ndictionary, called `state_dict`. 
These can be persisted via the\n`torch.save` function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = models.vgg16(weights='IMAGENET1K_V1')\ntorch.save(model.state_dict(), 'model_weights.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load model weights, you need to create an instance of the same model\nfirst, and then load the parameters using the `load_state_dict()` method.\n\nIn the code below, we set `weights_only=True` to limit the functions\nexecuted during unpickling to only those necessary for loading weights.\nUsing `weights_only=True` is considered a best practice when loading\nweights.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = models.vgg16() # we do not specify ``weights``, i.e. create an untrained model\nmodel.load_state_dict(torch.load('model_weights.pth', weights_only=True))\nmodel.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
NOTE: Be sure to call the model.eval() method before running inference so that the dropout and batch normalization layers are set to evaluation mode. Failing to do so will yield inconsistent inference results.
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving and Loading Models with Shapes\n=====================================\n\nWhen loading model weights, we needed to instantiate the model class\nfirst, because the class defines the structure of a network. We might\nwant to save the structure of this class together with the model, in\nwhich case we can pass `model` (and not `model.state_dict()`) to the\nsaving function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.save(model, 'model.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can then load the model as demonstrated below.\n\nAs described in [Saving and loading\ntorch.nn.Modules](https://pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules),\nsaving `state_dict` is considered the best practice. However, below we\nuse `weights_only=False` because this involves loading the model, which\nis a legacy use case for `torch.save`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = torch.load('model.pth', weights_only=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
NOTE: This approach uses the Python pickle module to serialize the model, so the actual class definition must be available when the model is loaded.
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Related Tutorials\n=================\n\n- [Saving and Loading a General Checkpoint in\n PyTorch](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html)\n- [Tips for loading an nn.Module from a\n checkpoint](https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html?highlight=loading%20nn%20module%20from%20checkpoint)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }