{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://pytorch.org/tutorials/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Recurrent DQN: Training recurrent policies\n==========================================\n\n**Author**: [Vincent Moens](https://github.com/vmoens)\n\n```{=html}\n

<p><strong>What you will learn</strong></p>
<ul>
<li>How to incorporate an RNN in an actor in TorchRL</li>
<li>How to use that memory-based policy with a replay buffer and a loss module</li>
</ul>
<p><strong>Prerequisites</strong>: torchrl, gym[mujoco], tqdm (see the installation cell below)</p>

\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nMemory-based policies are crucial not only when the observations are\npartially observable but also when the time dimension must be taken into\naccount to make informed decisions.\n\nRecurrent neural network have long been a popular tool for memory-based\npolicies. The idea is to keep a recurrent state in memory between two\nconsecutive steps, and use this as an input to the policy along with the\ncurrent observation.\n\nThis tutorial shows how to incorporate an RNN in a policy using TorchRL.\n\nKey learnings:\n\n- Incorporating an RNN in an actor in TorchRL;\n- Using that memory-based policy with a replay buffer and a loss\n module.\n\nThe core idea of using RNNs in TorchRL is to use TensorDict as a data\ncarrier for the hidden states from one step to another. We\\'ll build a\npolicy that reads the previous recurrent state from the current\nTensorDict, and writes the current recurrent states in the TensorDict of\nthe next state:\n\n![](https://pytorch.org/tutorials/_static/img/rollout_recurrent.png)\n\nAs this figure shows, our environment populates the TensorDict with\nzeroed recurrent states which are read by the policy together with the\nobservation to produce an action, and recurrent states that will be used\nfor the next step. When the\n`~torchrl.envs.utils.step_mdp`{.interpreted-text role=\"func\"} function\nis called, the recurrent states from the next state are brought to the\ncurrent TensorDict. Let\\'s see how this is implemented in practice.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If you are running this in Google Colab, make sure you install the\nfollowing dependencies:\n\n``` {.bash}\n!pip3 install torchrl\n!pip3 install gym[mujoco]\n!pip3 install tqdm\n```\n\nSetup\n=====\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport tqdm\nfrom tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq\nfrom torch import nn\nfrom torchrl.collectors import SyncDataCollector\nfrom torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer\nfrom torchrl.envs import (\n Compose,\n ExplorationType,\n GrayScale,\n InitTracker,\n ObservationNorm,\n Resize,\n RewardScaling,\n set_exploration_type,\n StepCounter,\n ToTensorImage,\n TransformedEnv,\n)\nfrom torchrl.envs.libs.gym import GymEnv\nfrom torchrl.modules import ConvNet, EGreedyModule, LSTMModule, MLP, QValueModule\nfrom torchrl.objectives import DQNLoss, SoftUpdate\n\nis_fork = multiprocessing.get_start_method() == \"fork\"\ndevice = (\n torch.device(0)\n if torch.cuda.is_available() and not is_fork\n else torch.device(\"cpu\")\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Environment\n===========\n\nAs usual, the first step is to build our environment: it helps us define\nthe problem and build the policy network accordingly. For this tutorial,\nwe\\'ll be running a single pixel-based instance of the CartPole gym\nenvironment with some custom transforms: turning to grayscale, resizing\nto 84x84, scaling down the rewards and normalizing the observations.\n\n```{=html}\n
<div class="alert alert-info"><h4>Note</h4><p>The <code>~torchrl.envs.transforms.StepCounter</code> transform is accessory. Since the CartPole task goal is to make trajectories as long as possible, counting the steps can help us track the performance of our policy.</p></div>
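```

To make this concrete, here is a minimal, self-contained sketch (our own
illustration, not part of the tutorial's pipeline; it assumes the classic
state-based `CartPole-v1` is available and the `tmp_env` name is ours) showing
the `step_count` entry that `StepCounter` writes during a rollout:

``` {.python}
from torchrl.envs import StepCounter, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

# throwaway state-based CartPole, used only to inspect the "step_count" entry
tmp_env = TransformedEnv(GymEnv("CartPole-v1"), StepCounter())
tmp_rollout = tmp_env.rollout(3)
# the counter increases along the trajectory and restarts after a reset
print(tmp_rollout["next", "step_count"])
```

```{=html}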
\n```\nTwo transforms are important for the purpose of this tutorial:\n\n- `~torchrl.envs.transforms.InitTracker`{.interpreted-text\n role=\"class\"} will stamp the calls to\n `~torchrl.envs.EnvBase.reset`{.interpreted-text role=\"meth\"} by\n adding a `\"is_init\"` boolean mask in the TensorDict that will track\n which steps require a reset of the RNN hidden states.\n- The `~torchrl.envs.transforms.TensorDictPrimer`{.interpreted-text\n role=\"class\"} transform is a bit more technical. It is not required\n to use RNN policies. However, it instructs the environment (and\n subsequently the collector) that some extra keys are to be expected.\n Once added, a call to [env.reset()]{.title-ref} will populate the\n entries indicated in the primer with zeroed tensors. Knowing that\n these tensors are expected by the policy, the collector will pass\n them on during collection. Eventually, we\\'ll be storing our hidden\n states in the replay buffer, which will help us bootstrap the\n computation of the RNN operations in the loss module (which would\n otherwise be initiated with 0s). In summary: not including this\n transform will not impact hugely the training of our policy, but it\n will make the recurrent keys disappear from the collected data and\n the replay buffer, which will in turn lead to a slightly less\n optimal training. Fortunately, the\n `~torchrl.modules.LSTMModule`{.interpreted-text role=\"class\"} we\n propose is equipped with a helper method to build just that\n transform for us, so we can wait until we build it!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "env = TransformedEnv(\n GymEnv(\"CartPole-v1\", from_pixels=True, device=device),\n Compose(\n ToTensorImage(),\n GrayScale(),\n Resize(84, 84),\n StepCounter(),\n InitTracker(),\n RewardScaling(loc=0.0, scale=0.1),\n ObservationNorm(standard_normal=True, in_keys=[\"pixels\"]),\n ),\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As always, we need to initialize manually our normalization constants:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "env.transform[-1].init_stats(1000, reduce_dim=[0, 1, 2], cat_dim=0, keep_dims=[0])\ntd = env.reset()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Policy\n======\n\nOur policy will have 3 components: a\n`~torchrl.modules.ConvNet`{.interpreted-text role=\"class\"} backbone, an\n`~torchrl.modules.LSTMModule`{.interpreted-text role=\"class\"} memory\nlayer and a shallow `~torchrl.modules.MLP`{.interpreted-text\nrole=\"class\"} block that will map the LSTM output onto the action\nvalues.\n\nConvolutional network\n---------------------\n\nWe build a convolutional network flanked with a\n`torch.nn.AdaptiveAvgPool2d`{.interpreted-text role=\"class\"} that will\nsquash the output in a vector of size 64. 
The\n`~torchrl.modules.ConvNet`{.interpreted-text role=\"class\"} can assist us\nwith this:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "feature = Mod(\n ConvNet(\n num_cells=[32, 32, 64],\n squeeze_output=True,\n aggregator_class=nn.AdaptiveAvgPool2d,\n aggregator_kwargs={\"output_size\": (1, 1)},\n device=device,\n ),\n in_keys=[\"pixels\"],\n out_keys=[\"embed\"],\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "we execute the first module on a batch of data to gather the size of the\noutput vector:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "n_cells = feature(env.reset())[\"embed\"].shape[-1]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "LSTM Module\n===========\n\nTorchRL provides a specialized\n`~torchrl.modules.LSTMModule`{.interpreted-text role=\"class\"} class to\nincorporate LSTMs in your code-base. It is a\n`~tensordict.nn.TensorDictModuleBase`{.interpreted-text role=\"class\"}\nsubclass: as such, it has a set of `in_keys` and `out_keys` that\nindicate what values should be expected to be read and written/updated\nduring the execution of the module. The class comes with customizable\npredefined values for these attributes to facilitate its construction.\n\n```{=html}\n
<div class="alert alert-info"><h4>Note</h4><p>The class supports almost all LSTM features such as dropout or multi-layered LSTMs. However, to respect TorchRL's conventions, this LSTM must have the <code>batch_first</code> attribute set to <code>True</code>, which is <strong>not</strong> the default in PyTorch. However, our <code>~torchrl.modules.LSTMModule</code> changes this default behavior, so we're good with a native call. Also, the LSTM cannot have a <code>bidirectional</code> attribute set to <code>True</code>, as this wouldn't be usable in online settings. In this case, the default value is the correct one.</p></div>
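```

To see what this means in practice, here is a tiny sketch (our own, independent
of the rest of the tutorial) contrasting PyTorch's default layout with the
batch-first layout that TorchRL's convention corresponds to:

``` {.python}
from torch import nn

# a plain PyTorch LSTM defaults to batch_first=False: inputs are [seq, batch, feature]
print(nn.LSTM(input_size=8, hidden_size=16).batch_first)
# TorchRL's convention corresponds to batch_first=True: inputs are [batch, seq, feature]
print(nn.LSTM(input_size=8, hidden_size=16, batch_first=True).batch_first)
```

```{=html}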
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "lstm = LSTMModule(\n input_size=n_cells,\n hidden_size=128,\n device=device,\n in_key=\"embed\",\n out_key=\"embed\",\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us look at the LSTM Module class, specifically its in and out\\_keys:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"in_keys\", lstm.in_keys)\nprint(\"out_keys\", lstm.out_keys)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can see that these values contain the key we indicated as the in\\_key\n(and out\\_key) as well as recurrent key names. The out\\_keys are\npreceded by a \\\"next\\\" prefix that indicates that they will need to be\nwritten in the \\\"next\\\" TensorDict. We use this convention (which can be\noverridden by passing the in\\_keys/out\\_keys arguments) to make sure\nthat a call to `~torchrl.envs.utils.step_mdp`{.interpreted-text\nrole=\"func\"} will move the recurrent state to the root TensorDict,\nmaking it available to the RNN during the following call (see figure in\nthe intro).\n\nAs mentioned earlier, we have one more optional transform to add to our\nenvironment to make sure that the recurrent states are passed to the\nbuffer. The\n`~torchrl.modules.LSTMModule.make_tensordict_primer`{.interpreted-text\nrole=\"meth\"} method does exactly that:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "env.append_transform(lstm.make_tensordict_primer())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and that\\'s it! We can print the environment to check that everything\nlooks good now that we have added the primer:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(env)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MLP\n===\n\nWe use a single-layer MLP to represent the action values we\\'ll be using\nfor our policy.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mlp = MLP(\n out_features=2,\n num_cells=[\n 64,\n ],\n device=device,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "and fill the bias with zeros:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mlp[-1].bias.data.fill_(0.0)\nmlp = Mod(mlp, in_keys=[\"embed\"], out_keys=[\"action_value\"])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the Q-Values to select an action\n======================================\n\nThe last part of our policy is the Q-Value Module. The Q-Value module\n`~torchrl.modules.tensordict_module.QValueModule`{.interpreted-text\nrole=\"class\"} will read the `\"action_values\"` key that is produced by\nour MLP and from it, gather the action that has the maximum value. The\nonly thing we need to do is to specify the action space, which can be\ndone either by passing a string or an action-spec. This allows us to use\nCategorical (sometimes called \\\"sparse\\\") encoding or the one-hot\nversion of it.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "qval = QValueModule(spec=env.action_spec)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<div class="alert alert-info"><h4>Note</h4><p>TorchRL also provides a wrapper class <code>torchrl.modules.QValueActor</code> that wraps a module in a Sequential together with a <code>~torchrl.modules.tensordict_module.QValueModule</code>, like we are doing explicitly here. There is little advantage to doing this and the process is less transparent, but the end results will be similar to what we do here.</p></div>
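```

For reference only, here is a hedged sketch of the wrapper route mentioned in
the note (we do not use it in this tutorial and keep the explicit sequence
below; the `wrapped_policy` name is ours, and we assume the wrapper accepts a
module that already writes an `"action_value"` entry):

``` {.python}
from torchrl.modules import QValueActor

# wraps the value network together with a QValueModule into a single actor;
# roughly equivalent to the explicit Seq(feature, lstm, mlp, qval) built below
wrapped_policy = QValueActor(Seq(feature, lstm, mlp), spec=env.action_spec)
```

```{=html}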
\n```\nWe can now put things together in a\n`~tensordict.nn.TensorDictSequential`{.interpreted-text role=\"class\"}\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "stoch_policy = Seq(feature, lstm, mlp, qval)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DQN being a deterministic algorithm, exploration is a crucial part of\nit. We\\'ll be using an $\\epsilon$-greedy policy with an epsilon of 0.2\ndecaying progressively to 0. This decay is achieved via a call to\n`~torchrl.modules.EGreedyModule.step`{.interpreted-text role=\"meth\"}\n(see training loop below).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "exploration_module = EGreedyModule(\n annealing_num_steps=1_000_000, spec=env.action_spec, eps_init=0.2\n)\nstoch_policy = Seq(\n stoch_policy,\n exploration_module,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using the model for the loss\n============================\n\nThe model as we\\'ve built it is well equipped to be used in sequential\nsettings. However, the class `torch.nn.LSTM`{.interpreted-text\nrole=\"class\"} can use a cuDNN-optimized backend to run the RNN sequence\nfaster on GPU device. We would not want to miss such an opportunity to\nspeed up our training loop! To use it, we just need to tell the LSTM\nmodule to run on \\\"recurrent-mode\\\" when used by the loss. As we\\'ll\nusually want to have two copies of the LSTM module, we do this by\ncalling a\n`~torchrl.modules.LSTMModule.set_recurrent_mode`{.interpreted-text\nrole=\"meth\"} method that will return a new instance of the LSTM (with\nshared weights) that will assume that the input data is sequential in\nnature.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "policy = Seq(feature, lstm.set_recurrent_mode(True), mlp, qval)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Because we still have a couple of uninitialized parameters we should\ninitialize them before creating an optimizer and such.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "policy(env.reset())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "DQN Loss\n========\n\nOut DQN loss requires us to pass the policy and, again, the\naction-space. While this may seem redundant, it is important as we want\nto make sure that the `~torchrl.objectives.DQNLoss`{.interpreted-text\nrole=\"class\"} and the\n`~torchrl.modules.tensordict_module.QValueModule`{.interpreted-text\nrole=\"class\"} classes are compatible, but aren\\'t strongly dependent on\neach other.\n\nTo use the Double-DQN, we ask for a `delay_value` argument that will\ncreate a non-differentiable copy of the network parameters to be used as\na target network.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "loss_fn = DQNLoss(policy, action_space=env.action_spec, delay_value=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since we are using a double DQN, we need to update the target\nparameters. 
We\\'ll use a\n`~torchrl.objectives.SoftUpdate`{.interpreted-text role=\"class\"}\ninstance to carry out this work.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "updater = SoftUpdate(loss_fn, eps=0.95)\n\noptim = torch.optim.Adam(policy.parameters(), lr=3e-4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Collector and replay buffer\n===========================\n\nWe build the simplest data collector there is. We\\'ll try to train our\nalgorithm with a million frames, extending the buffer with 50 frames at\na time. The buffer will be designed to store 20 thousands trajectories\nof 50 steps each. At each optimization step (16 per data collection),\nwe\\'ll collect 4 items from our buffer, for a total of 200 transitions.\nWe\\'ll use a\n`~torchrl.data.replay_buffers.LazyMemmapStorage`{.interpreted-text\nrole=\"class\"} storage to keep the data on disk.\n\n```{=html}\n
<div class="alert alert-info"><h4>Note</h4><p>For the sake of efficiency, we're only running a few thousand iterations here. In a real setting, the total number of frames should be set to 1M.</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "collector = SyncDataCollector(env, stoch_policy, frames_per_batch=50, total_frames=200, device=device)\nrb = TensorDictReplayBuffer(\n storage=LazyMemmapStorage(20_000), batch_size=4, prefetch=10\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Training loop\n=============\n\nTo keep track of the progress, we will run the policy in the environment\nonce every 50 data collection, and plot the results after training.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "utd = 16\npbar = tqdm.tqdm(total=1_000_000)\nlongest = 0\n\ntraj_lens = []\nfor i, data in enumerate(collector):\n if i == 0:\n print(\n \"Let us print the first batch of data.\\nPay attention to the key names \"\n \"which will reflect what can be found in this data structure, in particular: \"\n \"the output of the QValueModule (action_values, action and chosen_action_value),\"\n \"the 'is_init' key that will tell us if a step is initial or not, and the \"\n \"recurrent_state keys.\\n\",\n data,\n )\n pbar.update(data.numel())\n # it is important to pass data that is not flattened\n rb.extend(data.unsqueeze(0).to_tensordict().cpu())\n for _ in range(utd):\n s = rb.sample().to(device, non_blocking=True)\n loss_vals = loss_fn(s)\n loss_vals[\"loss\"].backward()\n optim.step()\n optim.zero_grad()\n longest = max(longest, data[\"step_count\"].max().item())\n pbar.set_description(\n f\"steps: {longest}, loss_val: {loss_vals['loss'].item(): 4.4f}, action_spread: {data['action'].sum(0)}\"\n )\n exploration_module.step(data.numel())\n updater.step()\n\n with set_exploration_type(ExplorationType.DETERMINISTIC), torch.no_grad():\n rollout = env.rollout(10000, stoch_policy)\n traj_lens.append(rollout.get((\"next\", \"step_count\")).max().item())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let\\'s plot our results:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "if traj_lens:\n from matplotlib import pyplot as plt\n\n plt.plot(traj_lens)\n plt.xlabel(\"Test collection\")\n plt.title(\"Test trajectory lengths\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nWe have seen how an RNN can be incorporated in a policy in TorchRL. You\nshould now be able:\n\n- Create an LSTM module that acts as a\n `~tensordict.nn.TensorDictModule`{.interpreted-text role=\"class\"}\n- Indicate to the LSTM module that a reset is needed via an\n `~torchrl.envs.transforms.InitTracker`{.interpreted-text\n role=\"class\"} transform\n- Incorporate this module in a policy and in a loss module\n- Make sure that the collector is made aware of the recurrent state\n entries such that they can be stored in the replay buffer along with\n the rest of the data\n\nFurther Reading\n===============\n\n- The TorchRL documentation can be found\n [here](https://pytorch.org/rl/).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }