{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hyperparameter tuning with Ray Tune\n===================================\n\nHyperparameter tuning can make the difference between an average model\nand a highly accurate one. Often, simple things like choosing a different\nlearning rate or changing a network layer size can have a dramatic\nimpact on your model's performance.\n\nFortunately, there are tools that help with finding the best combination\nof parameters. [Ray Tune](https://docs.ray.io/en/latest/tune.html) is an\nindustry-standard tool for distributed hyperparameter tuning. Ray Tune\nincludes the latest hyperparameter search algorithms, integrates with\nvarious analysis libraries, and natively supports distributed training\nthrough [Ray's distributed machine learning engine](https://ray.io/).\n\nIn this tutorial, we will show you how to integrate Ray Tune into your\nPyTorch training workflow. We will extend [this tutorial from the\nPyTorch\ndocumentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html)\nfor training a CIFAR10 image classifier.\n\nAs you will see, we only need to make some slight modifications. In\nparticular, we need to\n\n1. wrap data loading and training in functions,\n2. make some network parameters configurable,\n3. add checkpointing (optional),\n4. 
and define the search space for the model tuning.\n\nTo run this tutorial, please make sure the following packages are\ninstalled:\n\n- `ray[tune]`: Distributed hyperparameter tuning library\n- `torchvision`: For the data transformers\n\nSetup / Imports\n---------------\n\nLet's start with the imports:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from functools import partial\nimport os\nimport tempfile\nfrom pathlib import Path\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nfrom torch.utils.data import random_split\nimport torchvision\nimport torchvision.transforms as transforms\nfrom ray import tune\nfrom ray import train\nfrom ray.train import Checkpoint, get_checkpoint\nfrom ray.tune.schedulers import ASHAScheduler\nimport ray.cloudpickle as pickle" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Most of the imports are needed for building the PyTorch model. Only the\nlast imports are for Ray Tune.\n\nData loaders\n============\n\nWe wrap the data loaders in their own function and pass a global data\ndirectory. This way we can share a data directory between different\ntrials.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def load_data(data_dir=\"./data\"):\n    transform = transforms.Compose(\n        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]\n    )\n\n    trainset = torchvision.datasets.CIFAR10(\n        root=data_dir, train=True, download=True, transform=transform\n    )\n\n    testset = torchvision.datasets.CIFAR10(\n        root=data_dir, train=False, download=True, transform=transform\n    )\n\n    return trainset, testset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Configurable neural network\n===========================\n\nWe can only tune those parameters that are configurable. 
In this\nexample, we can specify the layer sizes of the fully connected layers:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class Net(nn.Module):\n    def __init__(self, l1=120, l2=84):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(3, 6, 5)\n        self.pool = nn.MaxPool2d(2, 2)\n        self.conv2 = nn.Conv2d(6, 16, 5)\n        self.fc1 = nn.Linear(16 * 5 * 5, l1)\n        self.fc2 = nn.Linear(l1, l2)\n        self.fc3 = nn.Linear(l2, 10)\n\n    def forward(self, x):\n        x = self.pool(F.relu(self.conv1(x)))\n        x = self.pool(F.relu(self.conv2(x)))\n        x = torch.flatten(x, 1)  # flatten all dimensions except batch\n        x = F.relu(self.fc1(x))\n        x = F.relu(self.fc2(x))\n        x = self.fc3(x)\n        return x" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The train function\n==================\n\nNow it gets interesting, because we introduce some changes to the\nexample [from the PyTorch\ndocumentation](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html).\n\nWe wrap the training script in a function\n`train_cifar(config, data_dir=None)`. The `config` parameter will\nreceive the hyperparameters we would like to train with. The `data_dir`\nspecifies the directory where we load and store the data, so that\nmultiple runs can share the same data source. 
We also load the model and\noptimizer state at the start of the run, if a checkpoint is provided.\nFurther down in this tutorial you will find information on how to save\nthe checkpoint and what it is used for.\n\n``` {.python}\nnet = Net(config[\"l1\"], config[\"l2\"])\n\ncheckpoint = get_checkpoint()\nif checkpoint:\n    with checkpoint.as_directory() as checkpoint_dir:\n        data_path = Path(checkpoint_dir) / \"data.pkl\"\n        with open(data_path, \"rb\") as fp:\n            checkpoint_state = pickle.load(fp)\n        start_epoch = checkpoint_state[\"epoch\"]\n        net.load_state_dict(checkpoint_state[\"net_state_dict\"])\n        optimizer.load_state_dict(checkpoint_state[\"optimizer_state_dict\"])\nelse:\n    start_epoch = 0\n```\n\nThe learning rate of the optimizer is made configurable, too:\n\n``` {.python}\noptimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n```\n\nWe also split the training data into a training and validation subset.\nWe thus train on 80% of the data and calculate the validation loss on\nthe remaining 20%. The batch sizes with which we iterate through the\ntraining and validation sets are configurable as well.\n\nAdding (multi) GPU support with DataParallel\n--------------------------------------------\n\nImage classification benefits greatly from GPUs. Luckily, we can\ncontinue to use PyTorch's abstractions in Ray Tune. Thus, we can wrap\nour model in `nn.DataParallel` to support data parallel training on\nmultiple GPUs:\n\n``` {.python}\ndevice = \"cpu\"\nif torch.cuda.is_available():\n    device = \"cuda:0\"\n    if torch.cuda.device_count() > 1:\n        net = nn.DataParallel(net)\nnet.to(device)\n```\n\nBy using a `device` variable we make sure that training also works when\nwe have no GPUs available. 
PyTorch requires us to send our data to the\nGPU memory explicitly, like this:\n\n``` {.python}\nfor i, data in enumerate(trainloader, 0):\n    inputs, labels = data\n    inputs, labels = inputs.to(device), labels.to(device)\n```\n\nThe code now supports training on CPUs, on a single GPU, and on multiple\nGPUs. Notably, Ray also supports [fractional\nGPUs](https://docs.ray.io/en/master/using-ray-with-gpus.html#fractional-gpus)\nso we can share GPUs among trials, as long as the model still fits in\nthe GPU memory. We'll come back to that later.\n\nCommunicating with Ray Tune\n---------------------------\n\nThe most interesting part is the communication with Ray Tune:\n\n``` {.python}\ncheckpoint_data = {\n    \"epoch\": epoch,\n    \"net_state_dict\": net.state_dict(),\n    \"optimizer_state_dict\": optimizer.state_dict(),\n}\nwith tempfile.TemporaryDirectory() as checkpoint_dir:\n    data_path = Path(checkpoint_dir) / \"data.pkl\"\n    with open(data_path, \"wb\") as fp:\n        pickle.dump(checkpoint_data, fp)\n\n    checkpoint = Checkpoint.from_directory(checkpoint_dir)\n    train.report(\n        {\"loss\": val_loss / val_steps, \"accuracy\": correct / total},\n        checkpoint=checkpoint,\n    )\n```\n\nHere we first save a checkpoint and then report some metrics back to Ray\nTune. Specifically, we send the validation loss and accuracy back to Ray\nTune. Ray Tune can then use these metrics to decide which hyperparameter\nconfiguration leads to the best results. These metrics can also be used\nto stop poorly performing trials early in order to avoid wasting resources\non those trials.\n\nSaving the checkpoint is optional; however, it is necessary if we want\nto use advanced schedulers like [Population Based\nTraining](https://docs.ray.io/en/latest/tune/examples/pbt_guide.html).\nAlso, by saving the checkpoint we can later load the trained models and\nvalidate them on a test set. 
Lastly, saving checkpoints is useful for\nfault tolerance, and it allows us to interrupt training and continue\ntraining later.\n\nFull training function\n----------------------\n\nThe full code example looks like this:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def train_cifar(config, data_dir=None):\n    net = Net(config[\"l1\"], config[\"l2\"])\n\n    device = \"cpu\"\n    if torch.cuda.is_available():\n        device = \"cuda:0\"\n        if torch.cuda.device_count() > 1:\n            net = nn.DataParallel(net)\n    net.to(device)\n\n    criterion = nn.CrossEntropyLoss()\n    optimizer = optim.SGD(net.parameters(), lr=config[\"lr\"], momentum=0.9)\n\n    checkpoint = get_checkpoint()\n    if checkpoint:\n        with checkpoint.as_directory() as checkpoint_dir:\n            data_path = Path(checkpoint_dir) / \"data.pkl\"\n            with open(data_path, \"rb\") as fp:\n                checkpoint_state = pickle.load(fp)\n            start_epoch = checkpoint_state[\"epoch\"]\n            net.load_state_dict(checkpoint_state[\"net_state_dict\"])\n            optimizer.load_state_dict(checkpoint_state[\"optimizer_state_dict\"])\n    else:\n        start_epoch = 0\n\n    trainset, testset = load_data(data_dir)\n\n    test_abs = int(len(trainset) * 0.8)\n    train_subset, val_subset = random_split(\n        trainset, [test_abs, len(trainset) - test_abs]\n    )\n\n    trainloader = torch.utils.data.DataLoader(\n        train_subset, batch_size=int(config[\"batch_size\"]), shuffle=True, num_workers=8\n    )\n    valloader = torch.utils.data.DataLoader(\n        val_subset, batch_size=int(config[\"batch_size\"]), shuffle=True, num_workers=8\n    )\n\n    for epoch in range(start_epoch, 10):  # loop over the dataset multiple times\n        running_loss = 0.0\n        epoch_steps = 0\n        for i, data in enumerate(trainloader, 0):\n            # get the inputs; data is a list of [inputs, labels]\n            inputs, labels = data\n            inputs, labels = inputs.to(device), labels.to(device)\n\n            # zero the parameter gradients\n            optimizer.zero_grad()\n\n            # forward + backward + optimize\n            outputs = net(inputs)\n            loss = criterion(outputs, labels)\n            loss.backward()\n            optimizer.step()\n\n            # print statistics\n            running_loss += loss.item()\n            epoch_steps += 1\n            if i % 2000 == 1999:  # print every 2000 mini-batches\n                print(\n                    \"[%d, %5d] loss: %.3f\"\n                    % (epoch + 1, i + 1, running_loss / epoch_steps)\n                )\n                running_loss = 0.0\n\n        # Validation loss\n        val_loss = 0.0\n        val_steps = 0\n        total = 0\n        correct = 0\n        for i, data in enumerate(valloader, 0):\n            with torch.no_grad():\n                inputs, labels = data\n                inputs, labels = inputs.to(device), labels.to(device)\n\n                outputs = net(inputs)\n                _, predicted = torch.max(outputs.data, 1)\n                total += labels.size(0)\n                correct += (predicted == labels).sum().item()\n\n                loss = criterion(outputs, labels)\n                val_loss += loss.cpu().numpy()\n                val_steps += 1\n\n        checkpoint_data = {\n            \"epoch\": epoch,\n            \"net_state_dict\": net.state_dict(),\n            \"optimizer_state_dict\": optimizer.state_dict(),\n        }\n        with tempfile.TemporaryDirectory() as checkpoint_dir:\n            data_path = Path(checkpoint_dir) / \"data.pkl\"\n            with open(data_path, \"wb\") as fp:\n                pickle.dump(checkpoint_data, fp)\n\n            checkpoint = Checkpoint.from_directory(checkpoint_dir)\n            train.report(\n                {\"loss\": val_loss / val_steps, \"accuracy\": correct / total},\n                checkpoint=checkpoint,\n            )\n\n    print(\"Finished Training\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, most of the code is adapted directly from the original\nexample.\n\nTest set accuracy\n=================\n\nCommonly the performance of a machine learning model is tested on a\nhold-out test set with data that has not been used for training the\nmodel. 
We also wrap this in a function:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def test_accuracy(net, device=\"cpu\"):\n    trainset, testset = load_data()\n\n    testloader = torch.utils.data.DataLoader(\n        testset, batch_size=4, shuffle=False, num_workers=2\n    )\n\n    correct = 0\n    total = 0\n    with torch.no_grad():\n        for data in testloader:\n            images, labels = data\n            images, labels = images.to(device), labels.to(device)\n            outputs = net(images)\n            _, predicted = torch.max(outputs.data, 1)\n            total += labels.size(0)\n            correct += (predicted == labels).sum().item()\n\n    return correct / total" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The function also expects a `device` parameter, so we can do the test\nset validation on a GPU.\n\nConfiguring the search space\n============================\n\nLastly, we need to define Ray Tune's search space. Here is an example:\n\n``` {.python}\nconfig = {\n    \"l1\": tune.choice([2 ** i for i in range(9)]),\n    \"l2\": tune.choice([2 ** i for i in range(9)]),\n    \"lr\": tune.loguniform(1e-4, 1e-1),\n    \"batch_size\": tune.choice([2, 4, 8, 16])\n}\n```\n\n`tune.choice()` accepts a list of values from which it samples\nuniformly. In this example, the `l1` and `l2` parameters should be powers of\n2 between 1 and 256, so either 1, 2, 4, 8, 16, 32, 64, 128, or 256. The `lr`\n(learning rate) should be sampled log-uniformly between 0.0001 and 0.1.\nLastly, the batch size is a choice between 2, 4, 8, and 16.\n\nAt each trial, Ray Tune will now randomly sample a combination of\nparameters from these search spaces. It will then train a number of\nmodels in parallel and find the best performing one among these. We also\nuse the `ASHAScheduler` which will terminate poorly performing trials\nearly.\n\nWe wrap the `train_cifar` function with `functools.partial` to set the\nconstant `data_dir` parameter. 
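A minimal sketch of that wrapping, using a hypothetical `train_stub` in place of `train_cifar` and a placeholder path, so it runs without Ray or PyTorch installed:

```python
from functools import partial


def train_stub(config, data_dir=None):
    # Hypothetical stand-in for train_cifar: it only echoes its arguments.
    return config['lr'], data_dir


# Fix data_dir once; the result is a callable that takes only the config.
trainable = partial(train_stub, data_dir='/tmp/data')

lr, data_dir = trainable({'lr': 0.01})
print(lr, data_dir)
```

The resulting callable takes only `config`, while `data_dir` stays fixed for every trial.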
We can also tell Ray Tune what resources\nshould be available for each trial:\n\n``` {.python}\ngpus_per_trial = 2\n# ...\nresult = tune.run(\n    partial(train_cifar, data_dir=data_dir),\n    resources_per_trial={\"cpu\": 8, \"gpu\": gpus_per_trial},\n    config=config,\n    num_samples=num_samples,\n    scheduler=scheduler,\n    checkpoint_at_end=True)\n```\n\nYou can specify the number of CPUs, which are then available e.g. to\nincrease the `num_workers` of the PyTorch `DataLoader` instances. The\nselected number of GPUs is made visible to PyTorch in each trial.\nTrials do not have access to GPUs that haven't been requested for them,\nso you don't have to worry about two trials using the same set of\nresources.\n\nHere we can also specify fractional GPUs, so something like\n`gpus_per_trial=0.5` is completely valid. The trials will then share\nGPUs among each other. You just have to make sure that the models still\nfit in the GPU memory.\n\nAfter training the models, we will find the best performing one and load\nthe trained network from the checkpoint file. 
We then obtain the test\nset accuracy and report everything by printing.\n\nThe full main function looks like this:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):\n    data_dir = os.path.abspath(\"./data\")\n    load_data(data_dir)\n    config = {\n        \"l1\": tune.choice([2**i for i in range(9)]),\n        \"l2\": tune.choice([2**i for i in range(9)]),\n        \"lr\": tune.loguniform(1e-4, 1e-1),\n        \"batch_size\": tune.choice([2, 4, 8, 16]),\n    }\n    scheduler = ASHAScheduler(\n        metric=\"loss\",\n        mode=\"min\",\n        max_t=max_num_epochs,\n        grace_period=1,\n        reduction_factor=2,\n    )\n    result = tune.run(\n        partial(train_cifar, data_dir=data_dir),\n        resources_per_trial={\"cpu\": 2, \"gpu\": gpus_per_trial},\n        config=config,\n        num_samples=num_samples,\n        scheduler=scheduler,\n    )\n\n    best_trial = result.get_best_trial(\"loss\", \"min\", \"last\")\n    print(f\"Best trial config: {best_trial.config}\")\n    print(f\"Best trial final validation loss: {best_trial.last_result['loss']}\")\n    print(f\"Best trial final validation accuracy: {best_trial.last_result['accuracy']}\")\n\n    best_trained_model = Net(best_trial.config[\"l1\"], best_trial.config[\"l2\"])\n    device = \"cpu\"\n    if torch.cuda.is_available():\n        device = \"cuda:0\"\n        if gpus_per_trial > 1:\n            best_trained_model = nn.DataParallel(best_trained_model)\n    best_trained_model.to(device)\n\n    best_checkpoint = result.get_best_checkpoint(trial=best_trial, metric=\"accuracy\", mode=\"max\")\n    with best_checkpoint.as_directory() as checkpoint_dir:\n        data_path = Path(checkpoint_dir) / \"data.pkl\"\n        with open(data_path, \"rb\") as fp:\n            best_checkpoint_data = pickle.load(fp)\n\n        best_trained_model.load_state_dict(best_checkpoint_data[\"net_state_dict\"])\n        test_acc = test_accuracy(best_trained_model, device)\n        print(\"Best trial test set accuracy: {}\".format(test_acc))\n\n\nif __name__ == \"__main__\":\n    # You can change the number of GPUs per trial here:\n    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you run the code, an example output could look like this:\n\n``` {.sh}\nNumber of trials: 10/10 (10 TERMINATED)\n+-----+--------------+------+------+-------------+--------+---------+------------+\n| ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |\n|-----+--------------+------+------+-------------+--------+---------+------------|\n| ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |\n| ... |            4 |   64 |    8 |   0.0331514 |      1 | 2.31605 |     0.0983 |\n| ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |\n| ... |           16 |   32 |   32 |   0.0128248 |     10 | 1.66912 |     0.4391 |\n| ... |            4 |    8 |  128 |  0.00464561 |      2 |  1.7316 |     0.3463 |\n| ... |            8 |  256 |    8 |  0.00031556 |      1 | 2.19409 |     0.1736 |\n| ... |            4 |   16 |  256 |  0.00574329 |      2 | 1.85679 |     0.3368 |\n| ... |            8 |    2 |    2 |  0.00325652 |      1 | 2.30272 |     0.0984 |\n| ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |      0.292 |\n| ... |            4 |   64 |   32 |    0.003734 |      8 | 1.53101 |     0.4761 |\n+-----+--------------+------+------+-------------+--------+---------+------------+\n\nBest trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}\nBest trial final validation loss: 1.5310075663924216\nBest trial final validation accuracy: 0.4761\nBest trial test set accuracy: 0.4737\n```\n\nMost trials have been stopped early in order to avoid wasting resources.\nThe best performing trial achieved a validation accuracy of about 47%,\nwhich could be confirmed on the test set.\n\nSo that's it! 
You can now tune the parameters of your PyTorch models.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }