{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Learn the Basics](intro.html) \\|\\|\n[Quickstart](quickstart_tutorial.html) \\|\\|\n[Tensors](tensorqs_tutorial.html) \\|\\| [Datasets &\nDataLoaders](data_tutorial.html) \\|\\|\n[Transforms](transforms_tutorial.html) \\|\\| [Build\nModel](buildmodel_tutorial.html) \\|\\|\n[Autograd](autogradqs_tutorial.html) \\|\\|\n[Optimization](optimization_tutorial.html) \\|\\| **Save & Load Model**\n\nSave and Load the Model\n=======================\n\nIn this section we will look at how to persist model state with saving,\nloading and running model predictions.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchvision.models as models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Saving and Loading Model Weights\n================================\n\nPyTorch models store the learned parameters in an internal state\ndictionary, called `state_dict`. These can be persisted via the\n`torch.save` method:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = models.vgg16(weights='IMAGENET1K_V1')\ntorch.save(model.state_dict(), 'model_weights.pth')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load model weights, you need to create an instance of the same model\nfirst, and then load the parameters using `load_state_dict()` method.\n\nIn the code below, we set `weights_only=True` to limit the functions\nexecuted during unpickling to only those necessary for loading weights.\nUsing `weights_only=True` is considered a best practice when loading\nweights.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model\nmodel.load_state_dict(torch.load('model_weights.pth', weights_only=True))\nmodel.eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
be sure to call model.eval()
method before inferencing to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.
This approach uses Python pickle module when serializing the model, thus it relies on the actual class definition to be available when loading the model.
\n```\n```{=html}\n
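, { "cell_type": "markdown", "metadata": {}, "source": [ "The short sketch below is not part of the original tutorial; it just\nillustrates why calling `model.eval()` matters: a dropout layer randomly\nzeroes activations in training mode but acts as the identity in evaluation\nmode, so the same input can produce different outputs depending on the mode.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Minimal sketch: dropout behaves differently in train vs. eval mode.\nimport torch\nimport torch.nn as nn\n\ntorch.manual_seed(0)\ndropout = nn.Dropout(p=0.5)\nx = torch.ones(1, 8)\n\ndropout.train()   # training mode: about half the activations are zeroed (and the rest rescaled)\nprint(dropout(x))\n\ndropout.eval()    # evaluation mode: dropout is a no-op, so the output equals the input\nprint(dropout(x))" ] }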
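, { "cell_type": "markdown", "metadata": {}, "source": [ "As a hedged illustration of the approach described in the second note, the\nsketch below (not part of the original tutorial) saves the entire model\nobject rather than just its `state_dict`, and loads it back with\n`weights_only=False`. The file name `model.pth` is only a placeholder;\nbecause the whole object is pickled, the class definition (here\n`torchvision.models.vgg16`) must be importable when the file is loaded.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Sketch: serialize the full model object (class structure plus weights) via pickle.\ntorch.save(model, 'model.pth')\n\n# Loading runs the pickle machinery, so the class definition must be importable;\n# weights_only=False opts in to full unpickling (only do this for files you trust).\nmodel2 = torch.load('model.pth', weights_only=False)\nmodel2.eval()" ] }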