{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Knowledge Distillation Tutorial\n===============================\n\n**Author**: [Alexandros Chariton](https://github.com/AlexandrosChrtn)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Knowledge distillation is a technique that enables knowledge transfer\nfrom large, computationally expensive models to smaller ones without\nlosing validity. This allows for deployment on less powerful hardware,\nmaking evaluation faster and more efficient.\n\nIn this tutorial, we will run a number of experiments focused on\nimproving the accuracy of a lightweight neural network, using a more\npowerful network as a teacher. The computational cost and the speed of\nthe lightweight network will remain unaffected; our intervention only\nfocuses on its weights, not on its forward pass. Applications of this\ntechnology can be found in devices such as drones or mobile phones. 
In\nthis tutorial, we do not use any external packages as everything we need\nis available in `torch` and `torchvision`.\n\nIn this tutorial, you will learn:\n\n- How to modify model classes to extract hidden representations and\n use them for further calculations\n- How to modify regular train loops in PyTorch to include additional\n losses on top of, for example, cross-entropy for classification\n- How to improve the performance of lightweight models by using more\n complex models as teachers\n\nPrerequisites\n=============\n\n- 1 GPU, 4GB of memory\n- PyTorch v2.0 or later\n- CIFAR-10 dataset (downloaded by the script and saved in a directory\n called `./data`)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.optim as optim\nimport torchvision.transforms as transforms\nimport torchvision.datasets as datasets\n\n# Check if an accelerator (for example, a GPU)\n# is available, and if not, use the CPU\ndevice = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else \"cpu\"\nprint(f\"Using {device} device\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loading CIFAR-10\n================\n\nCIFAR-10 is a popular image dataset with ten classes. Our objective is\nto predict one of the following classes for each input image.\n\n![Example of CIFAR-10\nimages](https://pytorch.org/tutorials//../_static/img/cifar10.png){.align-center}\n\nThe input images are RGB, so they have 3 channels and are 32x32 pixels.\nBasically, each image is described by 3 x 32 x 32 = 3072 numbers ranging\nfrom 0 to 255. A common practice in neural networks is to normalize the\ninput, which is done for multiple reasons, including avoiding saturation\nin commonly used activation functions and increasing numerical\nstability. Our normalization process consists of subtracting the mean\nand dividing by the standard deviation along each channel. 
The values\n`mean=[0.485, 0.456, 0.406]` and `std=[0.229, 0.224, 0.225]`\nare the widely used ImageNet channel statistics rather than values\ncomputed on CIFAR-10 itself; they are commonly reused for other\nnatural-image datasets. Notice how we use these values for the test set\nas well, without recomputing the mean and standard deviation from\nscratch. This is because the network was trained on features produced by\nsubtracting and dividing the numbers above, and we want to maintain\nconsistency. Furthermore, in real life, we would not be able to compute\nthe mean and standard deviation of the test set since, under our\nassumptions, this data would not be accessible at that point.\n\nAs a closing point, we often refer to this held-out set as the\nvalidation set, and we use a separate set, called the test set, after\noptimizing a model\\'s performance on the validation set. This is done to\navoid selecting a model based on the greedy and biased optimization of a\nsingle metric.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Below we are preprocessing data for CIFAR-10. We use an arbitrary batch size of 128.\ntransforms_cifar = transforms.Compose([\n transforms.ToTensor(),\n transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n])\n\n# Loading the CIFAR-10 dataset:\ntrain_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms_cifar)\ntest_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transforms_cifar)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
NOTE:
\n```\n```{=html}\n
\n```\n```{=html}\n

This section is only for CPU users who want quick results. Use this option only if you are interested in a small-scale experiment; keep in mind that the code should run fairly quickly on any GPU. To do so, select only the first num_images_to_keep images from the train/test datasets:

from torch.utils.data import Subset

\n

num_images_to_keep = 2000

\n

train_dataset = Subset(train_dataset, range(min(num_images_to_keep, 50_000)))

\n

test_dataset = Subset(test_dataset, range(min(num_images_to_keep, 10_000)))

\n```\n```{=html}\n
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "#Dataloaders\ntrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)\ntest_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Defining model classes and utility functions\n============================================\n\nNext, we need to define our model classes. Several user-defined\nparameters need to be set here. We use two different architectures,\nkeeping the number of filters fixed across our experiments to ensure\nfair comparisons. Both architectures are Convolutional Neural Networks\n(CNNs) with a different number of convolutional layers that serve as\nfeature extractors, followed by a classifier with 10 classes. The number\nof filters and neurons is smaller for the students.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Deeper neural network class to be used as teacher:\nclass DeepNN(nn.Module):\n def __init__(self, num_classes=10):\n super(DeepNN, self).__init__()\n self.features = nn.Sequential(\n nn.Conv2d(3, 128, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.Conv2d(128, 64, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n nn.Conv2d(64, 64, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.Conv2d(64, 32, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n )\n self.classifier = nn.Sequential(\n nn.Linear(2048, 512),\n nn.ReLU(),\n nn.Dropout(0.1),\n nn.Linear(512, num_classes)\n )\n\n def forward(self, x):\n x = self.features(x)\n x = torch.flatten(x, 1)\n x = self.classifier(x)\n return x\n\n# Lightweight neural network class to be used as student:\nclass LightNN(nn.Module):\n def __init__(self, num_classes=10):\n super(LightNN, 
self).__init__()\n self.features = nn.Sequential(\n nn.Conv2d(3, 16, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n nn.Conv2d(16, 16, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n )\n self.classifier = nn.Sequential(\n nn.Linear(1024, 256),\n nn.ReLU(),\n nn.Dropout(0.1),\n nn.Linear(256, num_classes)\n )\n\n def forward(self, x):\n x = self.features(x)\n x = torch.flatten(x, 1)\n x = self.classifier(x)\n return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We employ 2 functions to help us produce and evaluate the results on our\noriginal classification task. One function is called `train` and takes\nthe following arguments:\n\n- `model`: A model instance to train (update its weights) via this\n function.\n- `train_loader`: We defined our `train_loader` above, and its job is\n to feed the data into the model.\n- `epochs`: How many times we loop over the dataset.\n- `learning_rate`: The learning rate determines how large our steps\n towards convergence should be. Too large or too small steps can be\n detrimental.\n- `device`: Determines the device to run the workload on. Can be\n either CPU or GPU depending on availability.\n\nOur test function is similar, but it will be invoked with `test_loader`\nto load images from the test set.\n\n![Train both networks with Cross-Entropy. 
The student will be used as a\nbaseline:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/ce_only.png){.align-center}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train(model, train_loader, epochs, learning_rate, device):\n criterion = nn.CrossEntropyLoss()\n optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n\n model.train()\n\n for epoch in range(epochs):\n running_loss = 0.0\n for inputs, labels in train_loader:\n # inputs: A collection of batch_size images\n # labels: A vector of dimensionality batch_size with integers denoting class of each image\n inputs, labels = inputs.to(device), labels.to(device)\n\n optimizer.zero_grad()\n outputs = model(inputs)\n\n # outputs: Output of the network for the collection of images. A tensor of dimensionality batch_size x num_classes\n # labels: The actual labels of the images. Vector of dimensionality batch_size\n loss = criterion(outputs, labels)\n loss.backward()\n optimizer.step()\n\n running_loss += loss.item()\n\n print(f\"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}\")\n\ndef test(model, test_loader, device):\n model.to(device)\n model.eval()\n\n correct = 0\n total = 0\n\n with torch.no_grad():\n for inputs, labels in test_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n\n outputs = model(inputs)\n _, predicted = torch.max(outputs.data, 1)\n\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n\n accuracy = 100 * correct / total\n print(f\"Test Accuracy: {accuracy:.2f}%\")\n return accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cross-entropy runs\n==================\n\nFor reproducibility, we need to set the torch manual seed. We train\nnetworks using different methods, so to compare them fairly, it makes\nsense to initialize the networks with the same weights. 
Start by\ntraining the teacher network using cross-entropy:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(42)\nnn_deep = DeepNN(num_classes=10).to(device)\ntrain(nn_deep, train_loader, epochs=10, learning_rate=0.001, device=device)\ntest_accuracy_deep = test(nn_deep, test_loader, device)\n\n# Instantiate the lightweight network:\ntorch.manual_seed(42)\nnn_light = LightNN(num_classes=10).to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We instantiate one more lightweight network model to compare their\nperformances. Back propagation is sensitive to weight initialization, so\nwe need to make sure these two networks have the exact same\ninitialization.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(42)\nnew_nn_light = LightNN(num_classes=10).to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To ensure we have created a copy of the first network, we inspect the\nnorm of its first layer. 
If the norms match, we can be reasonably confident\nthat the networks are indeed the same.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Print the norm of the first layer of the initial lightweight model\nprint(\"Norm of 1st layer of nn_light:\", torch.norm(nn_light.features[0].weight).item())\n# Print the norm of the first layer of the new lightweight model\nprint(\"Norm of 1st layer of new_nn_light:\", torch.norm(new_nn_light.features[0].weight).item())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Print the total number of parameters in each model:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "total_params_deep = \"{:,}\".format(sum(p.numel() for p in nn_deep.parameters()))\nprint(f\"DeepNN parameters: {total_params_deep}\")\ntotal_params_light = \"{:,}\".format(sum(p.numel() for p in nn_light.parameters()))\nprint(f\"LightNN parameters: {total_params_light}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Train and test the lightweight network with cross entropy loss:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "train(nn_light, train_loader, epochs=10, learning_rate=0.001, device=device)\ntest_accuracy_light_ce = test(nn_light, test_loader, device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Based on test accuracy, we can now compare the deeper\nnetwork that is to be used as a teacher with the lightweight network\nthat is our supposed student. So far, the teacher has not intervened\nin the student\\'s training, so this performance is achieved by the student\non its own. 
The metrics so far can be seen with the following lines:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(f\"Teacher accuracy: {test_accuracy_deep:.2f}%\")\nprint(f\"Student accuracy: {test_accuracy_light_ce:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Knowledge distillation run\n==========================\n\nNow let\\'s try to improve the test accuracy of the student network by\nincorporating the teacher. Knowledge distillation is a straightforward\ntechnique to achieve this, based on the fact that both networks output a\nprobability distribution over our classes. Therefore, the two networks\nshare the same number of output neurons. The method works by\nincorporating an additional loss into the traditional cross entropy\nloss, which is based on the softmax output of the teacher network. The\nassumption is that the output activations of a properly trained teacher\nnetwork carry additional information that can be leveraged by a student\nnetwork during training. The original work suggests that utilizing\nratios of smaller probabilities in the soft targets can help achieve the\nunderlying objective of deep neural networks, which is to create a\nsimilarity structure over the data where similar objects are mapped\ncloser together. For example, in CIFAR-10, a truck could be mistaken for\nan automobile or airplane, if its wheels are present, but it is less\nlikely to be mistaken for a dog. Therefore, it makes sense to assume\nthat valuable information resides not only in the top prediction of a\nproperly trained model but in the entire output distribution. 
However,\ncross entropy alone does not sufficiently exploit this information as\nthe activations for non-predicted classes tend to be so small that\npropagated gradients do not meaningfully change the weights to construct\nthis desirable vector space.\n\nAs we continue defining our first helper function that introduces a\nteacher-student dynamic, we need to include a few extra parameters:\n\n- `T`: Temperature controls the smoothness of the output\n distributions. Larger `T` leads to smoother distributions, thus\n smaller probabilities get a larger boost.\n- `soft_target_loss_weight`: A weight assigned to the extra objective\n we\\'re about to include.\n- `ce_loss_weight`: A weight assigned to cross-entropy. Tuning these\n weights pushes the network towards optimizing for either objective.\n\n![Distillation loss is calculated from the logits of the networks. It\nonly returns gradients to the\nstudent:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/distillation_output_loss.png){.align-center}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_knowledge_distillation(teacher, student, train_loader, epochs, learning_rate, T, soft_target_loss_weight, ce_loss_weight, device):\n ce_loss = nn.CrossEntropyLoss()\n optimizer = optim.Adam(student.parameters(), lr=learning_rate)\n\n teacher.eval() # Teacher set to evaluation mode\n student.train() # Student to train mode\n\n for epoch in range(epochs):\n running_loss = 0.0\n for inputs, labels in train_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n\n optimizer.zero_grad()\n\n # Forward pass with the teacher model - do not save gradients here as we do not change the teacher's weights\n with torch.no_grad():\n teacher_logits = teacher(inputs)\n\n # Forward pass with the student model\n student_logits = student(inputs)\n\n #Soften the student logits by applying softmax first and log() second\n soft_targets = 
nn.functional.softmax(teacher_logits / T, dim=-1)\n soft_prob = nn.functional.log_softmax(student_logits / T, dim=-1)\n\n # Calculate the soft targets loss. Scaled by T**2 as suggested by the authors of the paper \"Distilling the knowledge in a neural network\"\n soft_targets_loss = torch.sum(soft_targets * (soft_targets.log() - soft_prob)) / soft_prob.size()[0] * (T**2)\n\n # Calculate the true label loss\n label_loss = ce_loss(student_logits, labels)\n\n # Weighted sum of the two losses\n loss = soft_target_loss_weight * soft_targets_loss + ce_loss_weight * label_loss\n\n loss.backward()\n optimizer.step()\n\n running_loss += loss.item()\n\n print(f\"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}\")\n\n# Apply ``train_knowledge_distillation`` with a temperature of 2. Arbitrarily set the weights to 0.75 for CE and 0.25 for distillation loss.\ntrain_knowledge_distillation(teacher=nn_deep, student=new_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, T=2, soft_target_loss_weight=0.25, ce_loss_weight=0.75, device=device)\ntest_accuracy_light_ce_and_kd = test(new_nn_light, test_loader, device)\n\n# Compare the student test accuracy with and without the teacher, after distillation\nprint(f\"Teacher accuracy: {test_accuracy_deep:.2f}%\")\nprint(f\"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%\")\nprint(f\"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cosine loss minimization run\n============================\n\nFeel free to play around with the temperature parameter that controls\nthe softness of the softmax function and the loss coefficients. In\nneural networks, it is easy to include additional loss functions to the\nmain objectives to achieve goals like better generalization. Let\\'s try\nincluding an objective for the student, but now let\\'s focus on their\nhidden states rather than their output layers. 
Our goal is to convey\ninformation from the teacher\\'s representation to the student by\nincluding a naive loss function, whose minimization implies that the\nflattened vectors that are subsequently passed to the classifiers have\nbecome more *similar* as the loss decreases. Of course, the teacher does\nnot update its weights, so the minimization depends only on the\nstudent\\'s weights. The rationale behind this method is that we are\noperating under the assumption that the teacher model has a better\ninternal representation that is unlikely to be achieved by the student\nwithout external intervention, therefore we artificially push the\nstudent to mimic the internal representation of the teacher. Whether or\nnot this will end up helping the student is not straightforward, though,\nbecause pushing the lightweight network to reach this point could be a\ngood thing, assuming that we have found an internal representation that\nleads to better test accuracy, but it could also be harmful because the\nnetworks have different architectures and the student does not have the\nsame learning capacity as the teacher. In other words, there is no\nreason for these two vectors, the student\\'s and the teacher\\'s to match\nper component. The student could reach an internal representation that\nis a permutation of the teacher\\'s and it would be just as efficient.\nNonetheless, we can still run a quick experiment to figure out the\nimpact of this method. We will be using the `CosineEmbeddingLoss` which\nis given by the following formula:\n\n![Formula for\nCosineEmbeddingLoss](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/cosine_embedding_loss.png){.align-center\nwidth=\"450px\"}\n\nObviously, there is one thing that we need to resolve first. 
When we\napplied distillation to the output layer we mentioned that both networks\nhave the same number of neurons, equal to the number of classes.\nHowever, this is not the case for the layer following our convolutional\nlayers. Here, the teacher has more neurons than the student after the\nflattening of the final convolutional layer. Our loss function accepts\ntwo vectors of equal dimensionality as inputs, therefore we need to\nsomehow match them. We will solve this by including an average pooling\nlayer after the teacher\\'s convolutional layer to reduce its\ndimensionality to match that of the student.\n\nTo proceed, we will modify our model classes, or create new ones. Now,\nthe forward function returns not only the logits of the network but also\nthe flattened hidden representation after the convolutional layer. We\ninclude the aforementioned pooling for the modified teacher.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ModifiedDeepNNCosine(nn.Module):\n def __init__(self, num_classes=10):\n super(ModifiedDeepNNCosine, self).__init__()\n self.features = nn.Sequential(\n nn.Conv2d(3, 128, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.Conv2d(128, 64, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n nn.Conv2d(64, 64, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.Conv2d(64, 32, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n )\n self.classifier = nn.Sequential(\n nn.Linear(2048, 512),\n nn.ReLU(),\n nn.Dropout(0.1),\n nn.Linear(512, num_classes)\n )\n\n def forward(self, x):\n x = self.features(x)\n flattened_conv_output = torch.flatten(x, 1)\n x = self.classifier(flattened_conv_output)\n flattened_conv_output_after_pooling = torch.nn.functional.avg_pool1d(flattened_conv_output, 2)\n return x, flattened_conv_output_after_pooling\n\n# Create a similar student class where we return a tuple. 
We do not apply pooling after flattening.\nclass ModifiedLightNNCosine(nn.Module):\n def __init__(self, num_classes=10):\n super(ModifiedLightNNCosine, self).__init__()\n self.features = nn.Sequential(\n nn.Conv2d(3, 16, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n nn.Conv2d(16, 16, kernel_size=3, padding=1),\n nn.ReLU(),\n nn.MaxPool2d(kernel_size=2, stride=2),\n )\n self.classifier = nn.Sequential(\n nn.Linear(1024, 256),\n nn.ReLU(),\n nn.Dropout(0.1),\n nn.Linear(256, num_classes)\n )\n\n def forward(self, x):\n x = self.features(x)\n flattened_conv_output = torch.flatten(x, 1)\n x = self.classifier(flattened_conv_output)\n return x, flattened_conv_output\n\n# We do not have to train the modified deep network from scratch of course, we just load its weights from the trained instance\nmodified_nn_deep = ModifiedDeepNNCosine(num_classes=10).to(device)\nmodified_nn_deep.load_state_dict(nn_deep.state_dict())\n\n# Once again ensure the norm of the first layer is the same for both networks\nprint(\"Norm of 1st layer for deep_nn:\", torch.norm(nn_deep.features[0].weight).item())\nprint(\"Norm of 1st layer for modified_deep_nn:\", torch.norm(modified_nn_deep.features[0].weight).item())\n\n# Initialize a modified lightweight network with the same seed as our other lightweight instances. This will be trained from scratch to examine the effectiveness of cosine loss minimization.\ntorch.manual_seed(42)\nmodified_nn_light = ModifiedLightNNCosine(num_classes=10).to(device)\nprint(\"Norm of 1st layer:\", torch.norm(modified_nn_light.features[0].weight).item())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Naturally, we need to change the train loop because now the model\nreturns a tuple `(logits, hidden_representation)`. 
Using a sample input\ntensor, we can print the shapes of both outputs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create a sample input tensor\nsample_input = torch.randn(128, 3, 32, 32).to(device) # Batch size: 128, Channels: 3, Image size: 32x32\n\n# Pass the input through the student\nlogits, hidden_representation = modified_nn_light(sample_input)\n\n# Print the shapes of the tensors\nprint(\"Student logits shape:\", logits.shape) # batch_size x total_classes\nprint(\"Student hidden representation shape:\", hidden_representation.shape) # batch_size x hidden_representation_size\n\n# Pass the input through the teacher\nlogits, hidden_representation = modified_nn_deep(sample_input)\n\n# Print the shapes of the tensors\nprint(\"Teacher logits shape:\", logits.shape) # batch_size x total_classes\nprint(\"Teacher hidden representation shape:\", hidden_representation.shape) # batch_size x hidden_representation_size" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In our case, `hidden_representation_size` is `1024`. This is the\nflattened feature map of the final convolutional layer of the student\nand as you can see, it is the input for its classifier. It is `1024` for\nthe teacher too, because we made it so with `avg_pool1d` from `2048`.\nThe loss applied here only affects the student\\'s weights in the layers\nthat precede the loss calculation, that is, its feature extractor.\nIn other words, it does not affect the classifier\nof the student. 
The modified training loop is the following:\n\n![In Cosine Loss minimization, we want to maximize the cosine similarity\nof the two representations by returning gradients to the\nstudent:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/cosine_loss_distillation.png){.align-center}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_cosine_loss(teacher, student, train_loader, epochs, learning_rate, hidden_rep_loss_weight, ce_loss_weight, device):\n ce_loss = nn.CrossEntropyLoss()\n cosine_loss = nn.CosineEmbeddingLoss()\n optimizer = optim.Adam(student.parameters(), lr=learning_rate)\n\n teacher.to(device)\n student.to(device)\n teacher.eval() # Teacher set to evaluation mode\n student.train() # Student to train mode\n\n for epoch in range(epochs):\n running_loss = 0.0\n for inputs, labels in train_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n\n optimizer.zero_grad()\n\n # Forward pass with the teacher model and keep only the hidden representation\n with torch.no_grad():\n _, teacher_hidden_representation = teacher(inputs)\n\n # Forward pass with the student model\n student_logits, student_hidden_representation = student(inputs)\n\n # Calculate the cosine loss. Target is a vector of ones. 
From the loss formula above we can see that is the case where loss minimization leads to cosine similarity increase.\n hidden_rep_loss = cosine_loss(student_hidden_representation, teacher_hidden_representation, target=torch.ones(inputs.size(0)).to(device))\n\n # Calculate the true label loss\n label_loss = ce_loss(student_logits, labels)\n\n # Weighted sum of the two losses\n loss = hidden_rep_loss_weight * hidden_rep_loss + ce_loss_weight * label_loss\n\n loss.backward()\n optimizer.step()\n\n running_loss += loss.item()\n\n print(f\"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We need to modify our test function for the same reason. Here we ignore\nthe hidden representation returned by the model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def test_multiple_outputs(model, test_loader, device):\n model.to(device)\n model.eval()\n\n correct = 0\n total = 0\n\n with torch.no_grad():\n for inputs, labels in test_loader:\n inputs, labels = inputs.to(device), labels.to(device)\n\n outputs, _ = model(inputs) # Disregard the second tensor of the tuple\n _, predicted = torch.max(outputs.data, 1)\n\n total += labels.size(0)\n correct += (predicted == labels).sum().item()\n\n accuracy = 100 * correct / total\n print(f\"Test Accuracy: {accuracy:.2f}%\")\n return accuracy" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this case, we could easily include both knowledge distillation and\ncosine loss minimization in the same function. It is common to combine\nmethods to achieve better performance in teacher-student paradigms. 
For\nnow, we can run a simple train-test session.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Train and test the lightweight network with cross-entropy and cosine loss\ntrain_cosine_loss(teacher=modified_nn_deep, student=modified_nn_light, train_loader=train_loader, epochs=10, learning_rate=0.001, hidden_rep_loss_weight=0.25, ce_loss_weight=0.75, device=device)\ntest_accuracy_light_ce_and_cosine_loss = test_multiple_outputs(modified_nn_light, test_loader, device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Intermediate regressor run\n==========================\n\nOur naive minimization does not guarantee better results for several\nreasons, one being the dimensionality of the vectors. Cosine similarity\ngenerally works better than Euclidean distance for vectors of higher\ndimensionality, but we were dealing with vectors with 1024 components\neach, so it is much harder to extract meaningful similarities.\nFurthermore, as we mentioned, pushing towards a match of the hidden\nrepresentation of the teacher and the student is not supported by\ntheory. There are no good reasons why we should be aiming for a 1:1\nmatch of these vectors. We will provide a final example of training\nintervention by including an extra network called a regressor. The\nobjective is to first extract the feature map of the teacher after a\nconvolutional layer, then extract a feature map of the student after a\nconvolutional layer, and finally try to match these maps. However, this\ntime, we will introduce a regressor between the networks to facilitate\nthe matching process. The regressor will be trainable and ideally will\ndo a better job than our naive cosine loss minimization scheme. 
Its main\njob is to match the dimensionality of these feature maps so that we can\nproperly define a loss function between the teacher and the student.\nDefining such a loss function provides a teaching \\\"path,\\\" which is\nbasically a flow to back-propagate gradients that will change the\nstudent\\'s weights. Focusing on the output of the convolutional layers\nright before each classifier for our original networks, we have the\nfollowing shapes:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Pass the sample input only from the convolutional feature extractor\nconvolutional_fe_output_student = nn_light.features(sample_input)\nconvolutional_fe_output_teacher = nn_deep.features(sample_input)\n\n# Print their shapes\nprint(\"Student's feature extractor output shape: \", convolutional_fe_output_student.shape)\nprint(\"Teacher's feature extractor output shape: \", convolutional_fe_output_teacher.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We have 32 filters for the teacher and 16 filters for the student. We\nwill include a trainable layer that converts the feature map of the\nstudent to the shape of the feature map of the teacher. 
In practice, we\nmodify the lightweight class to return the hidden state after an\nintermediate regressor that matches the sizes of the convolutional\nfeature maps and the teacher class to return the output of the final\nconvolutional layer without pooling or flattening.\n\n![The trainable layer matches the shapes of the intermediate tensors and\nMean Squared Error (MSE) is properly\ndefined:](https://pytorch.org/tutorials//../_static/img/knowledge_distillation/fitnets_knowledge_distill.png){.align-center}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class ModifiedDeepNNRegressor(nn.Module):\n    def __init__(self, num_classes=10):\n        super(ModifiedDeepNNRegressor, self).__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 128, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(128, 64, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n            nn.Conv2d(64, 64, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.Conv2d(64, 32, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n        )\n        self.classifier = nn.Sequential(\n            nn.Linear(2048, 512),\n            nn.ReLU(),\n            nn.Dropout(0.1),\n            nn.Linear(512, num_classes)\n        )\n\n    def forward(self, x):\n        x = self.features(x)\n        # Keep the convolutional feature map so it can be matched against the student's regressor output\n        conv_feature_map = x\n        x = torch.flatten(x, 1)\n        x = self.classifier(x)\n        return x, conv_feature_map\n\nclass ModifiedLightNNRegressor(nn.Module):\n    def __init__(self, num_classes=10):\n        super(ModifiedLightNNRegressor, self).__init__()\n        self.features = nn.Sequential(\n            nn.Conv2d(3, 16, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n            nn.Conv2d(16, 16, kernel_size=3, padding=1),\n            nn.ReLU(),\n            nn.MaxPool2d(kernel_size=2, stride=2),\n        )\n        # Include an extra regressor (in our case a single convolution with no activation, so it is linear)\n        self.regressor = nn.Sequential(\n            nn.Conv2d(16, 32, kernel_size=3, padding=1)\n        )\n        self.classifier = nn.Sequential(\n            nn.Linear(1024, 256),\n            nn.ReLU(),\n            nn.Dropout(0.1),\n            nn.Linear(256, num_classes)\n        )\n\n    def forward(self, x):\n        x = self.features(x)\n        # The regressor maps the 16-channel student feature map to the teacher's 32 channels\n        regressor_output = self.regressor(x)\n        x = torch.flatten(x, 1)\n        x = self.classifier(x)\n        return x, regressor_output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After that, we have to update our train loop again. This time, we\nextract the regressor output of the student and the feature map of the\nteacher, calculate the `MSE` on these tensors (they have the exact\nsame shape, so it\\'s properly defined), and backpropagate gradients\nbased on that loss in addition to the regular cross entropy loss of the\nclassification task.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_mse_loss(teacher, student, train_loader, epochs, learning_rate, feature_map_weight, ce_loss_weight, device):\n    ce_loss = nn.CrossEntropyLoss()\n    mse_loss = nn.MSELoss()\n    optimizer = optim.Adam(student.parameters(), lr=learning_rate)\n\n    teacher.to(device)\n    student.to(device)\n    teacher.eval()  # Teacher set to evaluation mode\n    student.train() # Student to train mode\n\n    for epoch in range(epochs):\n        running_loss = 0.0\n        for inputs, labels in train_loader:\n            inputs, labels = inputs.to(device), labels.to(device)\n\n            optimizer.zero_grad()\n\n            # Again ignore teacher logits\n            with torch.no_grad():\n                _, teacher_feature_map = teacher(inputs)\n\n            # Forward pass with the student model\n            student_logits, regressor_feature_map = student(inputs)\n\n            # Calculate the feature map (MSE) loss\n            hidden_rep_loss = mse_loss(regressor_feature_map, teacher_feature_map)\n\n            # Calculate the true label loss\n            label_loss = ce_loss(student_logits, labels)\n\n            # Weighted sum of the two losses\n            loss = feature_map_weight * hidden_rep_loss + ce_loss_weight * label_loss\n\n            loss.backward()\n            optimizer.step()\n\n            running_loss += loss.item()\n\n        print(f\"Epoch {epoch+1}/{epochs}, Loss: {running_loss / len(train_loader)}\")\n\n# Notice how our 
test function remains the same as the one we used in the previous case. We only care about the actual outputs because we measure accuracy.\n\n# Initialize a ModifiedLightNNRegressor\ntorch.manual_seed(42)\nmodified_nn_light_reg = ModifiedLightNNRegressor(num_classes=10).to(device)\n\n# We do not have to train the modified deep network from scratch, of course; we just load its weights from the trained instance\nmodified_nn_deep_reg = ModifiedDeepNNRegressor(num_classes=10).to(device)\nmodified_nn_deep_reg.load_state_dict(nn_deep.state_dict())\n\n# Train and test once again\ntrain_mse_loss(teacher=modified_nn_deep_reg, student=modified_nn_light_reg, train_loader=train_loader, epochs=10, learning_rate=0.001, feature_map_weight=0.25, ce_loss_weight=0.75, device=device)\ntest_accuracy_light_ce_and_mse_loss = test_multiple_outputs(modified_nn_light_reg, test_loader, device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is expected that the final method will work better than `CosineLoss`\nbecause now we have allowed a trainable layer between the teacher and\nthe student, which gives the student some wiggle room when it comes to\nlearning, rather than pushing the student to copy the teacher\\'s\nrepresentation. 
Including the extra network is the idea behind\nhint-based distillation.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(f\"Teacher accuracy: {test_accuracy_deep:.2f}%\")\nprint(f\"Student accuracy without teacher: {test_accuracy_light_ce:.2f}%\")\nprint(f\"Student accuracy with CE + KD: {test_accuracy_light_ce_and_kd:.2f}%\")\nprint(f\"Student accuracy with CE + CosineLoss: {test_accuracy_light_ce_and_cosine_loss:.2f}%\")\nprint(f\"Student accuracy with CE + RegressorMSE: {test_accuracy_light_ce_and_mse_loss:.2f}%\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nNone of the methods above increases the number of parameters for the\nnetwork or inference time, so the performance increase comes at the\nsmall cost of calculating gradients during training. In ML\napplications, we mostly care about inference time because training\nhappens before model deployment. If our lightweight model is still\ntoo heavy for deployment, we can apply different ideas, such as\npost-training quantization. Additional losses can be applied in many\ntasks, not just classification, and you can experiment with quantities\nlike coefficients, temperature, or number of neurons. Feel free to tune\nany numbers in the tutorial above, but keep in mind that if you change the\nnumber of neurons or filters, a shape mismatch is likely to occur.\n\nFor more information, see:\n\n- [Hinton, G., Vinyals, O., Dean, J.: Distilling the knowledge in a\n neural network. In: Neural Information Processing Systems Deep\n Learning Workshop (2015)](https://arxiv.org/abs/1503.02531)\n- [Romero, A., Ballas, N., Kahou, S.E., Chassang, A., Gatta, C.,\n Bengio, Y.: Fitnets: Hints for thin deep nets. 
In: Proceedings of\n the International Conference on Learning\n Representations (2015)](https://arxiv.org/abs/1412.6550)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }