{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[Introduction](introyt1_tutorial.html) \\|\\|\n[Tensors](tensors_deeper_tutorial.html) \\|\\|\n[Autograd](autogradyt_tutorial.html) \\|\\| [Building\nModels](modelsyt_tutorial.html) \\|\\| [TensorBoard\nSupport](tensorboardyt_tutorial.html) \\|\\| **Training Models** \\|\\|\n[Model Understanding](captumyt.html)\n\nTraining with PyTorch\n=====================\n\nFollow along with the video on\n[YouTube](https://www.youtube.com/watch?v=jF43_wj_DCQ).\n\nIntroduction\n------------\n\nIn past videos, we've discussed and demonstrated:\n\n- Building models with the neural network layers and functions of the\n torch.nn module\n- The mechanics of automated gradient computation, which is central to\n gradient-based model training\n- Using TensorBoard to visualize training progress and other\n activities\n\nIn this video, we'll be adding some new tools to your inventory:\n\n- We'll get familiar with the dataset and dataloader abstractions, and\n how they ease the process of feeding data to your model during a\n training loop\n- We'll discuss specific loss functions and when to use them\n- We'll look at PyTorch optimizers, which implement algorithms to\n adjust model weights based on the outcome of a loss function\n\nFinally, we'll pull all of these together and see a full PyTorch\ntraining loop in action.\n\nDataset and DataLoader\n----------------------\n\nThe `Dataset` and `DataLoader` classes encapsulate the process of\npulling your data from storage and exposing it to your training loop in\nbatches.\n\nThe `Dataset` is responsible for accessing and processing single\ninstances of data.\n\nThe `DataLoader` pulls instances of data from the `Dataset` (either\nautomatically or with a sampler that you define), collects them in\nbatches, and returns them for consumption by your training loop. The\n`DataLoader` works with all kinds of datasets, regardless of the type of\ndata they contain.\n\nFor this tutorial, we'll be using the Fashion-MNIST dataset provided by\nTorchVision. 
We use `torchvision.transforms.Normalize()` to zero-center\nand normalize the distribution of the image tile content, and download\nboth training and validation data splits.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torchvision\nimport torchvision.transforms as transforms\n\n# PyTorch TensorBoard support\nfrom torch.utils.tensorboard import SummaryWriter\nfrom datetime import datetime\n\n\ntransform = transforms.Compose(\n    [transforms.ToTensor(),\n     transforms.Normalize((0.5,), (0.5,))])\n\n# Create datasets for training & validation, download if necessary\ntraining_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)\nvalidation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)\n\n# Create data loaders for our datasets; shuffle for training, not for validation\ntraining_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)\nvalidation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)\n\n# Class labels\nclasses = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')\n\n# Report split sizes\nprint('Training set has {} instances'.format(len(training_set)))\nprint('Validation set has {} instances'.format(len(validation_set)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As always, let's visualize the data as a sanity check:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\nimport numpy as np\n\n# Helper function for inline image display\ndef matplotlib_imshow(img, one_channel=False):\n    if one_channel:\n        img = img.mean(dim=0)\n    img = img / 2 + 0.5     # unnormalize\n    npimg = img.numpy()\n    if one_channel:\n        plt.imshow(npimg, cmap=\"Greys\")\n    else:\n        plt.imshow(np.transpose(npimg, (1, 2, 0)))\n\ndataiter = iter(training_loader)\nimages, labels = next(dataiter)\n\n# Create a grid from the images and show them\nimg_grid = torchvision.utils.make_grid(images)\nmatplotlib_imshow(img_grid, one_channel=True)\nprint(' '.join(classes[labels[j]] for j in range(4)))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Model\n=========\n\nThe model we'll use in this example is a variant of LeNet-5 - it should\nbe familiar if you've watched the previous videos in this series.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch.nn as nn\nimport torch.nn.functional as F\n\n# PyTorch models inherit from torch.nn.Module\nclass GarmentClassifier(nn.Module):\n    def __init__(self):\n        super(GarmentClassifier, self).__init__()\n        self.conv1 = nn.Conv2d(1, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 4 * 4, 120)\n        self.fc2 = nn.Linear(120, 84)\n        self.fc3 = nn.Linear(84, 10)\n\n    def forward(self, x):\n        x = self.pool(F.relu(self.conv1(x)))\n        x = self.pool(F.relu(self.conv2(x)))\n        x = x.view(-1, 16 * 4 * 4)\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x\n\n\nmodel = GarmentClassifier()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loss Function\n=============\n\nFor this example, we'll be using a cross-entropy loss. 
For demonstration\npurposes, we'll create batches of dummy output and label values, run\nthem through the loss function, and examine the result.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "loss_fn = torch.nn.CrossEntropyLoss()\n\n# NB: Loss functions expect data in batches, so we're creating batches of 4\n# Represents the model's confidence in each of the 10 classes for a given input\ndummy_outputs = torch.rand(4, 10)\n# Represents the correct class among the 10 being tested\ndummy_labels = torch.tensor([1, 5, 3, 7])\n\nprint(dummy_outputs)\nprint(dummy_labels)\n\n# With the default reduction ('mean'), the result is averaged over the batch\nloss = loss_fn(dummy_outputs, dummy_labels)\nprint('Mean loss for this batch: {}'.format(loss.item()))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizer\n=========\n\nFor this example, we'll be using simple [stochastic gradient\ndescent](https://pytorch.org/docs/stable/optim.html) with momentum.\n\nIt can be instructive to try some variations on this optimization\nscheme:\n\n- Learning rate determines the size of the steps the optimizer takes.\n What does a different learning rate do to your training results,\n in terms of accuracy and convergence time?\n- Momentum nudges the optimizer in the direction of strongest gradient\n over multiple steps. What does changing this value do to your\n results?\n- Try some different optimization algorithms, such as averaged SGD,\n Adagrad, or Adam. How do your results differ?\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Optimizers specified in the torch.optim package\noptimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Training Loop\n=================\n\nBelow, we have a function that performs one training epoch. 
It\nenumerates data from the DataLoader, and on each pass of the loop does\nthe following:\n\n- Gets a batch of training data from the DataLoader\n- Zeros the optimizer's gradients\n- Performs an inference - that is, gets predictions from the model for\n an input batch\n- Calculates the loss for that set of predictions vs. the labels on\n the dataset\n- Calculates the backward gradients over the learning weights\n- Tells the optimizer to perform one learning step - that is, adjust\n the model's learning weights based on the observed gradients for\n this batch, according to the optimization algorithm we chose\n- Every 1000 batches, reports the average per-batch loss over those\n batches\n\nFinally, the function returns the average per-batch loss for the last\n1000 batches, for comparison with a validation run.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_one_epoch(epoch_index, tb_writer):\n    running_loss = 0.\n    last_loss = 0.\n\n    # Here, we use enumerate(training_loader) instead of\n    # iter(training_loader) so that we can track the batch\n    # index and do some intra-epoch reporting\n    for i, data in enumerate(training_loader):\n        # Every data instance is an input + label pair\n        inputs, labels = data\n\n        # Zero your gradients for every batch!\n        optimizer.zero_grad()\n\n        # Make predictions for this batch\n        outputs = model(inputs)\n\n        # Compute the loss and its gradients\n        loss = loss_fn(outputs, labels)\n        loss.backward()\n\n        # Adjust learning weights\n        optimizer.step()\n\n        # Gather data and report\n        running_loss += loss.item()\n        if i % 1000 == 999:\n            last_loss = running_loss / 1000 # loss per batch\n            print(' batch {} loss: {}'.format(i + 1, last_loss))\n            tb_x = epoch_index * len(training_loader) + i + 1\n            tb_writer.add_scalar('Loss/train', last_loss, tb_x)\n            running_loss = 0.\n\n    return last_loss" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Per-Epoch Activity\n==================\n\nThere are a couple of 
things we'll want to do once per epoch:\n\n- Perform validation by checking our relative loss on a set of data\n that was not used for training, and report this\n- Save a copy of the model\n\nHere, we'll do our reporting in TensorBoard. This will require going to\nthe command line to start TensorBoard, and opening it in another browser\ntab.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Initializing in a separate cell so we can easily add more epochs to the same run\ntimestamp = datetime.now().strftime('%Y%m%d_%H%M%S')\nwriter = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))\nepoch_number = 0\n\nEPOCHS = 5\n\nbest_vloss = 1_000_000.\n\nfor epoch in range(EPOCHS):\n    print('EPOCH {}:'.format(epoch_number + 1))\n\n    # Put the model in training mode, and do a pass over the data\n    model.train(True)\n    avg_loss = train_one_epoch(epoch_number, writer)\n\n\n    running_vloss = 0.0\n    # Set the model to evaluation mode, disabling dropout and using population\n    # statistics for batch normalization.\n    model.eval()\n\n    # Disable gradient computation and reduce memory consumption.\n    with torch.no_grad():\n        for i, vdata in enumerate(validation_loader):\n            vinputs, vlabels = vdata\n            voutputs = model(vinputs)\n            vloss = loss_fn(voutputs, vlabels)\n            running_vloss += vloss\n\n    avg_vloss = running_vloss / (i + 1)\n    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))\n\n    # Log the running loss averaged per batch\n    # for both training and validation\n    writer.add_scalars('Training vs. Validation Loss',\n                    { 'Training' : avg_loss, 'Validation' : avg_vloss },\n                    epoch_number + 1)\n    writer.flush()\n\n    # Track best performance, and save the model's state\n    if avg_vloss < best_vloss:\n        best_vloss = avg_vloss\n        model_path = 'model_{}_{}'.format(timestamp, epoch_number)\n        torch.save(model.state_dict(), model_path)\n\n    epoch_number += 1" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To load a saved version of the model:\n\n``` {.python}\nsaved_model = GarmentClassifier()\nsaved_model.load_state_dict(torch.load(PATH))\n```\n\nOnce you've loaded the model, it's ready for whatever you need it for:\nmore training, inference, or analysis.\n\nNote that if your model has constructor parameters that affect model\nstructure, you'll need to provide them and configure the model\nidentically to the state in which it was saved.\n\nOther Resources\n===============\n\n- Docs on the [data\n utilities](https://pytorch.org/docs/stable/data.html), including\n Dataset and DataLoader, at pytorch.org\n- A [note on the use of pinned\n memory](https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning)\n for GPU training\n- Documentation on the datasets available in\n [TorchVision](https://pytorch.org/vision/stable/datasets.html),\n [TorchText](https://pytorch.org/text/stable/datasets.html), and\n [TorchAudio](https://pytorch.org/audio/stable/datasets.html)\n- Documentation on the [loss\n functions](https://pytorch.org/docs/stable/nn.html#loss-functions)\n available in PyTorch\n- Documentation on the [torch.optim\n package](https://pytorch.org/docs/stable/optim.html), which includes\n optimizers and related tools, such as learning rate scheduling\n- A detailed [tutorial on saving and loading\n models](https://pytorch.org/tutorials/beginner/saving_loading_models.html)\n- The [Tutorials section of\n pytorch.org](https://pytorch.org/tutorials/) contains tutorials on a\n broad variety of training tasks, including classification in\n different domains, 
generative adversarial networks, reinforcement\n learning, and more\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }