{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRL objectives: Coding a DDPG loss\n======================================\n\n**Author**: [Vincent Moens](https://github.com/vmoens)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nTorchRL separates the training of RL algorithms in various pieces that\nwill be assembled in your training script: the environment, the data\ncollection and storage, the model and finally the loss function.\n\nTorchRL losses (or \\\"objectives\\\") are stateful objects that contain the\ntrainable parameters (policy and value models). This tutorial will guide\nyou through the steps to code a loss from the ground up using TorchRL.\n\nTo this aim, we will be focusing on DDPG, which is a relatively\nstraightforward algorithm to code. [Deep Deterministic Policy\nGradient](https://arxiv.org/abs/1509.02971) (DDPG) is a simple\ncontinuous control algorithm. It consists in learning a parametric value\nfunction for an action-observation pair, and then learning a policy that\noutputs actions that maximize this value function given a certain\nobservation.\n\nWhat you will learn:\n\n- how to write a loss module and customize its value estimator;\n- how to build an environment in TorchRL, including transforms (for\n example, data normalization) and parallel execution;\n- how to design a policy and value network;\n- how to collect data from your environment efficiently and store them\n in a replay buffer;\n- how to store trajectories (and not transitions) in your replay\n buffer);\n- how to evaluate your model.\n\nPrerequisites\n-------------\n\nThis tutorial assumes that you have completed the [PPO\ntutorial](reinforcement_ppo.html) which gives an overview of the TorchRL\ncomponents and dependencies, such as\n`tensordict.TensorDict`{.interpreted-text role=\"class\"} and\n`tensordict.nn.TensorDictModules`{.interpreted-text role=\"class\"},\nalthough it should be sufficiently transparent to be understood without\na deep understanding of these classes.\n\n```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p>We do not aim at giving a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL's loss implementations and the library features that are to be used in the context of this algorithm.</p></div>
\n```\nImports and setup\n=================\n\n> ``` {.bash}\n> %%bash\n> pip3 install torchrl mujoco glfw\n> ```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport tqdm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will execute the policy on CUDA if available\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "is_fork = multiprocessing.get_start_method() == \"fork\"\ndevice = (\n torch.device(0)\n if torch.cuda.is_available() and not is_fork\n else torch.device(\"cpu\")\n)\ncollector_device = torch.device(\"cpu\") # Change the device to ``cuda`` to use CUDA" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRL `~torchrl.objectives.LossModule`{.interpreted-text role=\"class\"}\n========================================================================\n\nTorchRL provides a series of losses to use in your training scripts. The\naim is to have losses that are easily reusable/swappable and that have a\nsimple signature.\n\nThe main characteristics of TorchRL losses are:\n\n- They are stateful objects: they contain a copy of the trainable\n parameters such that `loss_module.parameters()` gives whatever is\n needed to train the algorithm.\n\n- They follow the `TensorDict` convention: the\n `torch.nn.Module.forward`{.interpreted-text role=\"meth\"} method will\n receive a TensorDict as input that contains all the necessary\n information to return a loss value.\n\n > \\>\\>\\> data = replay\\_buffer.sample() \\>\\>\\> loss\\_dict =\n > loss\\_module(data)\n\n- They output a `tensordict.TensorDict`{.interpreted-text\n role=\"class\"} instance with the loss values written under a\n `\"loss_\"` where `smth` is a string describing the loss.\n Additional keys in the `TensorDict` may be useful metrics to log\n during training time.\n\n ```{=html}\n
    <div class=\"alert alert-info\"><h4>Note</h4><p>The reason we return independent losses is to let the user use a different optimizer for different sets of parameters, for instance. Summing the losses can be simply done via <code>loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith(\"loss\"))</code>.</p></div>
\n ```\n\nThe `__init__` method\n---------------------\n\nThe parent class of all losses is\n`~torchrl.objectives.LossModule`{.interpreted-text role=\"class\"}. As\nmany other components of the library, its\n`~torchrl.objectives.LossModule.forward`{.interpreted-text role=\"meth\"}\nmethod expects as input a `tensordict.TensorDict`{.interpreted-text\nrole=\"class\"} instance sampled from an experience replay buffer, or any\nsimilar data structure. Using this format makes it possible to re-use\nthe module across modalities, or in complex settings where the model\nneeds to read multiple entries for instance. In other words, it allows\nus to code a loss module that is oblivious to the data type that is\nbeing given to is and that focuses on running the elementary steps of\nthe loss function and only those.\n\nTo keep the tutorial as didactic as we can, we\\'ll be displaying each\nmethod of the class independently and we\\'ll be populating the class at\na later stage.\n\nLet us start with the\n`~torchrl.objectives.LossModule.__init__`{.interpreted-text role=\"meth\"}\nmethod. DDPG aims at solving a control task with a simple strategy:\ntraining a policy to output actions that maximize the value predicted by\na value network. Hence, our loss module needs to receive two networks in\nits constructor: an actor and a value networks. We expect both of these\nto be TensorDict-compatible objects, such as\n`tensordict.nn.TensorDictModule`{.interpreted-text role=\"class\"}. Our\nloss function will need to compute a target value and fit the value\nnetwork to this, and generate an action and fit the policy such that its\nvalue estimate is maximized.\n\nThe crucial step of the `LossModule.__init__`{.interpreted-text\nrole=\"meth\"} method is the call to\n`~torchrl.LossModule.convert_to_functional`{.interpreted-text\nrole=\"meth\"}. This method will extract the parameters from the module\nand convert it to a functional module. Strictly speaking, this is not\nnecessary and one may perfectly code all the losses without it. However,\nwe encourage its usage for the following reason.\n\nThe reason TorchRL does this is that RL algorithms often execute the\nsame model with different sets of parameters, called \\\"trainable\\\" and\n\\\"target\\\" parameters. The \\\"trainable\\\" parameters are those that the\noptimizer needs to fit. The \\\"target\\\" parameters are usually a copy of\nthe former\\'s with some time lag (absolute or diluted through a moving\naverage). These target parameters are used to compute the value\nassociated with the next observation. One the advantages of using a set\nof target parameters for the value model that do not match exactly the\ncurrent configuration is that they provide a pessimistic bound on the\nvalue function being computed. Pay attention to the\n`create_target_params` keyword argument below: this argument tells the\n`~torchrl.objectives.LossModule.convert_to_functional`{.interpreted-text\nrole=\"meth\"} method to create a set of target parameters in the loss\nmodule to be used for target value computation. 
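The snippet below is a minimal, illustrative preview of what this gives you once the loss module is assembled at the end of this section (the `DDPGLoss` class, `actor` and `qnet` are only defined later in this tutorial):\n\n``` {.python}\n# preview only: both parameter sets live on the loss module itself\nloss_module = DDPGLoss(actor, qnet)\ntrainable_params = loss_module.value_network_params  # fitted by the optimizer\ntarget_params = loss_module.target_value_network_params  # lagged copy used for target values\n```\n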
If this is set to\n`False` (see the actor network for instance) the\n`target_actor_network_params` attribute will still be accessible but\nthis will just return a **detached** version of the actor parameters.\n\nLater, we will see how the target parameters should be updated in\nTorchRL.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from tensordict.nn import TensorDictModule, TensorDictSequential\n\n\ndef _init(\n self,\n actor_network: TensorDictModule,\n value_network: TensorDictModule,\n) -> None:\n super(type(self), self).__init__()\n\n self.convert_to_functional(\n actor_network,\n \"actor_network\",\n create_target_params=True,\n )\n self.convert_to_functional(\n value_network,\n \"value_network\",\n create_target_params=True,\n compare_against=list(actor_network.parameters()),\n )\n\n self.actor_in_keys = actor_network.in_keys\n\n # Since the value we'll be using is based on the actor and value network,\n # we put them together in a single actor-critic container.\n actor_critic = ActorCriticWrapper(actor_network, value_network)\n self.actor_critic = actor_critic\n self.loss_function = \"l2\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The value estimator loss method\n===============================\n\nIn many RL algorithm, the value network (or Q-value network) is trained\nbased on an empirical value estimate. This can be bootstrapped (TD(0),\nlow variance, high bias), meaning that the target value is obtained\nusing the next reward and nothing else, or a Monte-Carlo estimate can be\nobtained (TD(1)) in which case the whole sequence of upcoming rewards\nwill be used (high variance, low bias). An intermediate estimator\n(TD($\\lambda$)) can also be used to compromise bias and variance.\nTorchRL makes it easy to use one or the other estimator via the\n`~torchrl.objectives.utils.ValueEstimators`{.interpreted-text\nrole=\"class\"} Enum class, which contains pointers to all the value\nestimators implemented. Let us define the default value function here.\nWe will take the simplest version (TD(0)), and show later on how this\ncan be changed.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.objectives.utils import ValueEstimators\n\ndefault_value_estimator = ValueEstimators.TD0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need to give some instructions to DDPG on how to build the value\nestimator, depending on the user query. 
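For instance, the user-facing call will eventually look like the following (a preview of the call made toward the end of this tutorial, with illustrative hyperparameter values):\n\n``` {.python}\nloss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.9)\n```\n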
Depending on the estimator\nprovided, we will build the corresponding module to be used at train\ntime:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.objectives.utils import default_value_kwargs\nfrom torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator\n\n\ndef make_value_estimator(self, value_type: ValueEstimators, **hyperparams):\n hp = dict(default_value_kwargs(value_type))\n if hasattr(self, \"gamma\"):\n hp[\"gamma\"] = self.gamma\n hp.update(hyperparams)\n value_key = \"state_action_value\"\n if value_type == ValueEstimators.TD1:\n self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)\n elif value_type == ValueEstimators.TD0:\n self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)\n elif value_type == ValueEstimators.GAE:\n raise NotImplementedError(\n f\"Value type {value_type} it not implemented for loss {type(self)}.\"\n )\n elif value_type == ValueEstimators.TDLambda:\n self._value_estimator = TDLambdaEstimator(value_network=self.actor_critic, **hp)\n else:\n raise NotImplementedError(f\"Unknown value type {value_type}\")\n self._value_estimator.set_keys(value=value_key)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `make_value_estimator` method can but does not need to be called: if\nnot, the `~torchrl.objectives.LossModule`{.interpreted-text\nrole=\"class\"} will query this method with its default estimator.\n\nThe actor loss method\n=====================\n\nThe central piece of an RL algorithm is the training loss for the actor.\nIn the case of DDPG, this function is quite simple: we just need to\ncompute the value associated with an action computed using the policy\nand optimize the actor weights to maximize this value.\n\nWhen computing this value, we must make sure to take the value\nparameters out of the graph, otherwise the actor and value loss will be\nmixed up. For this, the\n`~torchrl.objectives.utils.hold_out_params`{.interpreted-text\nrole=\"func\"} function can be used.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def _loss_actor(\n self,\n tensordict,\n) -> torch.Tensor:\n td_copy = tensordict.select(*self.actor_in_keys)\n # Get an action from the actor network: since we made it functional, we need to pass the params\n with self.actor_network_params.to_module(self.actor_network):\n td_copy = self.actor_network(td_copy)\n # get the value associated with that action\n with self.value_network_params.detach().to_module(self.value_network):\n td_copy = self.value_network(td_copy)\n return -td_copy.get(\"state_action_value\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The value loss method\n=====================\n\nWe now need to optimize our value network parameters. 
To do this, we\nwill rely on the value estimator of our class:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.objectives.utils import distance_loss\n\n\ndef _loss_value(\n self,\n tensordict,\n):\n td_copy = tensordict.clone()\n\n # V(s, a)\n with self.value_network_params.to_module(self.value_network):\n self.value_network(td_copy)\n pred_val = td_copy.get(\"state_action_value\").squeeze(-1)\n\n # we manually reconstruct the parameters of the actor-critic, where the first\n # set of parameters belongs to the actor and the second to the value function.\n target_params = TensorDict(\n {\n \"module\": {\n \"0\": self.target_actor_network_params,\n \"1\": self.target_value_network_params,\n }\n },\n batch_size=self.target_actor_network_params.batch_size,\n device=self.target_actor_network_params.device,\n )\n with target_params.to_module(self.actor_critic):\n target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)\n\n # Computes the value loss: L2, L1 or smooth L1 depending on `self.loss_function`\n loss_value = distance_loss(pred_val, target_value, loss_function=self.loss_function)\n td_error = (pred_val - target_value).pow(2)\n\n return loss_value, td_error, pred_val, target_value" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Putting things together in a forward call\n=========================================\n\nThe only missing piece is the forward method, which will glue together\nthe value and actor loss, collect the cost values and write them in a\n`TensorDict` delivered to the user.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from tensordict import TensorDict, TensorDictBase\n\n\ndef _forward(self, input_tensordict: TensorDictBase) -> TensorDict:\n loss_value, td_error, pred_val, target_value = self.loss_value(\n input_tensordict,\n )\n td_error = td_error.detach()\n td_error = td_error.unsqueeze(input_tensordict.ndimension())\n if input_tensordict.device is not None:\n td_error = td_error.to(input_tensordict.device)\n input_tensordict.set(\n \"td_error\",\n td_error,\n inplace=True,\n )\n loss_actor = self.loss_actor(input_tensordict)\n return TensorDict(\n source={\n \"loss_actor\": loss_actor.mean(),\n \"loss_value\": loss_value.mean(),\n \"pred_value\": pred_val.mean().detach(),\n \"target_value\": target_value.mean().detach(),\n \"pred_value_max\": pred_val.max().detach(),\n \"target_value_max\": target_value.max().detach(),\n },\n batch_size=[],\n )\n\n\nfrom torchrl.objectives import LossModule\n\n\nclass DDPGLoss(LossModule):\n default_value_estimator = default_value_estimator\n make_value_estimator = make_value_estimator\n\n __init__ = _init\n forward = _forward\n loss_value = _loss_value\n loss_actor = _loss_actor" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we have our loss, we can use it to train a policy to solve a\ncontrol task.\n\nEnvironment\n===========\n\nIn most algorithms, the first thing that needs to be taken care of is\nthe construction of the environment as it conditions the remainder of\nthe training script.\n\nFor this example, we will be using the `\"cheetah\"` task. 
The goal is to\nmake a half-cheetah run as fast as possible.\n\nIn TorchRL, one can create such a task by relying on `dm_control` or\n`gym`:\n\n``` {.python}\nenv = GymEnv(\"HalfCheetah-v4\")\n```\n\nor\n\n``` {.python}\nenv = DMControlEnv(\"cheetah\", \"run\")\n```\n\nBy default, these environment disable rendering. Training from states is\nusually easier than training from images. To keep things simple, we\nfocus on learning from states only. To pass the pixels to the\n`tensordicts` that are collected by `env.step()`{.interpreted-text\nrole=\"func\"}, simply pass the `from_pixels=True` argument to the\nconstructor:\n\n``` {.python}\nenv = GymEnv(\"HalfCheetah-v4\", from_pixels=True, pixels_only=True)\n```\n\nWe write a `make_env`{.interpreted-text role=\"func\"} helper function\nthat will create an environment with either one of the two backends\nconsidered above (`dm-control` or `gym`).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.envs.libs.dm_control import DMControlEnv\nfrom torchrl.envs.libs.gym import GymEnv\n\nenv_library = None\nenv_name = None\n\n\ndef make_env(from_pixels=False):\n \"\"\"Create a base ``env``.\"\"\"\n global env_library\n global env_name\n\n if backend == \"dm_control\":\n env_name = \"cheetah\"\n env_task = \"run\"\n env_args = (env_name, env_task)\n env_library = DMControlEnv\n elif backend == \"gym\":\n env_name = \"HalfCheetah-v4\"\n env_args = (env_name,)\n env_library = GymEnv\n else:\n raise NotImplementedError\n\n env_kwargs = {\n \"device\": device,\n \"from_pixels\": from_pixels,\n \"pixels_only\": from_pixels,\n \"frame_skip\": 2,\n }\n env = env_library(*env_args, **env_kwargs)\n return env" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Transforms\n==========\n\nNow that we have a base environment, we may want to modify its\nrepresentation to make it more policy-friendly. In TorchRL, transforms\nare appended to the base environment in a specialized\n`torchr.envs.TransformedEnv`{.interpreted-text role=\"class\"} class.\n\n- It is common in DDPG to rescale the reward using some heuristic\n value. We will multiply the reward by 5 in this example.\n- If we are using `dm_control`{.interpreted-text role=\"mod\"}, it is\n also important to build an interface between the simulator which\n works with double precision numbers, and our script which presumably\n uses single precision ones. This transformation goes both ways: when\n calling `env.step`{.interpreted-text role=\"func\"}, our actions will\n need to be represented in double precision, and the output will need\n to be transformed to single precision. 
The\n `~torchrl.envs.DoubleToFloat`{.interpreted-text role=\"class\"}\n transform does exactly this: the `in_keys` list refers to the keys\n that will need to be transformed from double to float, while the\n `in_keys_inv` refers to those that need to be transformed to double\n before being passed to the environment.\n- We concatenate the state keys together using the\n `~torchrl.envs.CatTensors`{.interpreted-text role=\"class\"}\n transform.\n- Finally, we also leave the possibility of normalizing the states: we\n will take care of computing the normalizing constants later on.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.envs import (\n CatTensors,\n DoubleToFloat,\n EnvCreator,\n InitTracker,\n ObservationNorm,\n ParallelEnv,\n RewardScaling,\n StepCounter,\n TransformedEnv,\n)\n\n\ndef make_transformed_env(\n env,\n):\n \"\"\"Apply transforms to the ``env`` (such as reward scaling and state normalization).\"\"\"\n\n env = TransformedEnv(env)\n\n # we append transforms one by one, although we might as well create the\n # transformed environment using the `env = TransformedEnv(base_env, transforms)`\n # syntax.\n env.append_transform(RewardScaling(loc=0.0, scale=reward_scaling))\n\n # We concatenate all states into a single \"observation_vector\"\n # even if there is a single tensor, it'll be renamed in \"observation_vector\".\n # This facilitates the downstream operations as we know the name of the\n # output tensor.\n # In some environments (not half-cheetah), there may be more than one\n # observation vector: in this case this code snippet will concatenate them\n # all.\n selected_keys = list(env.observation_spec.keys())\n out_key = \"observation_vector\"\n env.append_transform(CatTensors(in_keys=selected_keys, out_key=out_key))\n\n # we normalize the states, but for now let's just instantiate a stateless\n # version of the transform\n env.append_transform(ObservationNorm(in_keys=[out_key], standard_normal=True))\n\n env.append_transform(DoubleToFloat())\n\n env.append_transform(StepCounter(max_frames_per_traj))\n\n # We need a marker for the start of trajectories for our Ornstein-Uhlenbeck (OU)\n # exploration:\n env.append_transform(InitTracker())\n\n return env" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Parallel execution\n==================\n\nThe following helper function allows us to run environments in parallel.\nRunning environments in parallel can significantly speed up the\ncollection throughput. When using transformed environment, we need to\nchoose whether we want to execute the transform individually for each\nenvironment, or centralize the data and transform it in batch. 
Both\napproaches are easy to code:\n\n``` {.python}\nenv = ParallelEnv(\n lambda: TransformedEnv(GymEnv(\"HalfCheetah-v4\"), transforms),\n num_workers=4\n)\nenv = TransformedEnv(\n ParallelEnv(lambda: GymEnv(\"HalfCheetah-v4\"), num_workers=4),\n transforms\n)\n```\n\nTo leverage the vectorization capabilities of PyTorch, we adopt the\nfirst method:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def parallel_env_constructor(\n env_per_collector,\n transform_state_dict,\n):\n if env_per_collector == 1:\n\n def make_t_env():\n env = make_transformed_env(make_env())\n env.transform[2].init_stats(3)\n env.transform[2].loc.copy_(transform_state_dict[\"loc\"])\n env.transform[2].scale.copy_(transform_state_dict[\"scale\"])\n return env\n\n env_creator = EnvCreator(make_t_env)\n return env_creator\n\n parallel_env = ParallelEnv(\n num_workers=env_per_collector,\n create_env_fn=EnvCreator(lambda: make_env()),\n create_env_kwargs=None,\n pin_memory=False,\n )\n env = make_transformed_env(parallel_env)\n # we call `init_stats` for a limited number of steps, just to instantiate\n # the lazy buffers.\n env.transform[2].init_stats(3, cat_dim=1, reduce_dim=[0, 1])\n env.transform[2].load_state_dict(transform_state_dict)\n return env\n\n\n# The backend can be ``gym`` or ``dm_control``\nbackend = \"gym\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p><code>frame_skip</code> batches multiple steps together with a single action. If > 1, the other frame counts (for example, frames_per_batch, total_frames) need to be adjusted to have a consistent total number of frames collected across experiments. This is important as raising the frame-skip but keeping the total number of frames unchanged may seem like cheating: all things compared, a dataset of 10M elements collected with a frame-skip of 2 and another with a frame-skip of 1 actually have a ratio of interactions with the environment of 2:1! In a nutshell, one should be cautious about the frame count of a training script when dealing with frame skipping, as this may lead to biased comparisons between training strategies.</p></div>
\n```\nScaling the reward helps us control the signal magnitude for a more\nefficient learning.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "reward_scaling = 5.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also define when a trajectory will be truncated. A thousand steps\n(500 if frame-skip = 2) is a good number to use for the cheetah task:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "max_frames_per_traj = 500" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normalization of the observations\n=================================\n\nTo compute the normalizing statistics, we run an arbitrary number of\nrandom steps in the environment and compute the mean and standard\ndeviation of the collected observations. The\n`ObservationNorm.init_stats()`{.interpreted-text role=\"func\"} method can\nbe used for this purpose. To get the summary statistics, we create a\ndummy environment and run it for a given number of steps, collect data\nover a given number of steps and compute its summary statistics.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def get_env_stats():\n \"\"\"Gets the stats of an environment.\"\"\"\n proof_env = make_transformed_env(make_env())\n t = proof_env.transform[2]\n t.init_stats(init_env_steps)\n transform_state_dict = t.state_dict()\n proof_env.close()\n return transform_state_dict" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normalization stats\n===================\n\nNumber of random steps used as for stats computation using\n`ObservationNorm`\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "init_env_steps = 5000\n\ntransform_state_dict = get_env_stats()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Number of environments in each data collector\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "env_per_collector = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We pass the stats computed earlier to normalize the output of our\nenvironment:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "parallel_env = parallel_env_constructor(\n env_per_collector=env_per_collector,\n transform_state_dict=transform_state_dict,\n)\n\n\nfrom torchrl.data import CompositeSpec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Building the model\n==================\n\nWe now turn to the setup of the model. 
As we have seen, DDPG requires a\nvalue network, trained to estimate the value of a state-action pair, and\na parametric actor that learns how to select actions that maximize this\nvalue.\n\nRecall that building a TorchRL module requires two steps:\n\n- writing the `torch.nn.Module`{.interpreted-text role=\"class\"} that\n will be used as network,\n- wrapping the network in a\n `tensordict.nn.TensorDictModule`{.interpreted-text role=\"class\"}\n where the data flow is handled by specifying the input and output\n keys.\n\nIn more complex scenarios,\n`tensordict.nn.TensorDictSequential`{.interpreted-text role=\"class\"} can\nalso be used.\n\nThe Q-Value network is wrapped in a\n`~torchrl.modules.ValueOperator`{.interpreted-text role=\"class\"} that\nautomatically sets the `out_keys` to `\"state_action_value` for q-value\nnetworks and `state_value` for other value networks.\n\nTorchRL provides a built-in version of the DDPG networks as presented in\nthe original paper. These can be found under\n`~torchrl.modules.DdpgMlpActor`{.interpreted-text role=\"class\"} and\n`~torchrl.modules.DdpgMlpQNet`{.interpreted-text role=\"class\"}.\n\nSince we use lazy modules, it is necessary to materialize the lazy\nmodules before being able to move the policy from device to device and\nachieve other operations. Hence, it is good practice to run the modules\nwith a small sample of data. For this purpose, we generate fake data\nfrom the environment specs.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.modules import (\n ActorCriticWrapper,\n DdpgMlpActor,\n DdpgMlpQNet,\n OrnsteinUhlenbeckProcessModule,\n ProbabilisticActor,\n TanhDelta,\n ValueOperator,\n)\n\n\ndef make_ddpg_actor(\n transform_state_dict,\n device=\"cpu\",\n):\n proof_environment = make_transformed_env(make_env())\n proof_environment.transform[2].init_stats(3)\n proof_environment.transform[2].load_state_dict(transform_state_dict)\n\n out_features = proof_environment.action_spec.shape[-1]\n\n actor_net = DdpgMlpActor(\n action_dim=out_features,\n )\n\n in_keys = [\"observation_vector\"]\n out_keys = [\"param\"]\n\n actor = TensorDictModule(\n actor_net,\n in_keys=in_keys,\n out_keys=out_keys,\n )\n\n actor = ProbabilisticActor(\n actor,\n distribution_class=TanhDelta,\n in_keys=[\"param\"],\n spec=CompositeSpec(action=proof_environment.action_spec),\n ).to(device)\n\n q_net = DdpgMlpQNet()\n\n in_keys = in_keys + [\"action\"]\n qnet = ValueOperator(\n in_keys=in_keys,\n module=q_net,\n ).to(device)\n\n # initialize lazy modules\n qnet(actor(proof_environment.reset().to(device)))\n return actor, qnet\n\n\nactor, qnet = make_ddpg_actor(\n transform_state_dict=transform_state_dict,\n device=device,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Exploration\n===========\n\nThe policy is passed into a\n`~torchrl.modules.OrnsteinUhlenbeckProcessModule`{.interpreted-text\nrole=\"class\"} exploration module, as suggested in the original paper.\nLet\\'s define the number of frames before OU noise reaches its minimum\nvalue\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "annealing_frames = 1_000_000\n\nactor_model_explore = TensorDictSequential(\n actor,\n OrnsteinUhlenbeckProcessModule(\n spec=actor.spec.clone(),\n annealing_num_steps=annealing_frames,\n ).to(device),\n)\nif device == torch.device(\"cpu\"):\n actor_model_explore.share_memory()" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "Data collector\n==============\n\nTorchRL provides specialized classes to help you collect data by\nexecuting the policy in the environment. These \\\"data collectors\\\"\niteratively compute the action to be executed at a given time, then\nexecute a step in the environment and reset it when required. Data\ncollectors are designed to help developers have a tight control on the\nnumber of frames per batch of data, on the (a)sync nature of this\ncollection and on the resources allocated to the data collection (for\nexample GPU, number of workers, and so on).\n\nHere we will use\n`~torchrl.collectors.SyncDataCollector`{.interpreted-text role=\"class\"},\na simple, single-process data collector. TorchRL offers other\ncollectors, such as\n`~torchrl.collectors.MultiaSyncDataCollector`{.interpreted-text\nrole=\"class\"}, which executed the rollouts in an asynchronous manner\n(for example, data will be collected while the policy is being\noptimized, thereby decoupling the training and data collection).\n\nThe parameters to specify are:\n\n- an environment factory or an environment,\n\n- the policy,\n\n- the total number of frames before the collector is considered empty,\n\n- the maximum number of frames per trajectory (useful for\n non-terminating environments, like `dm_control` ones).\n\n ```{=html}\n
    <div class=\"alert alert-info\"><h4>Note</h4><p>The <code>max_frames_per_traj</code> passed to the collector will have the effect of registering a new <code>torchrl.envs.StepCounter</code> transform with the environment used for inference. We can achieve the same result manually, as we do in this script.</p></div>
\n ```\n\nOne should also pass:\n\n- the number of frames in each batch collected,\n- the number of random steps executed independently from the policy,\n- the devices used for policy execution\n- the devices used to store data before the data is passed to the main\n process.\n\nThe total frames we will use during training should be around 1M.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "total_frames = 10_000 # 1_000_000" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The number of frames returned by the collector at each iteration of the\nouter loop is equal to the length of each sub-trajectories times the\nnumber of environments run in parallel in each collector.\n\nIn other words, we expect batches from the collector to have a shape\n`[env_per_collector, traj_len]` where\n`traj_len=frames_per_batch/env_per_collector`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "traj_len = 200\nframes_per_batch = env_per_collector * traj_len\ninit_random_frames = 5000\nnum_collectors = 2\n\nfrom torchrl.collectors import SyncDataCollector\nfrom torchrl.envs import ExplorationType\n\ncollector = SyncDataCollector(\n parallel_env,\n policy=actor_model_explore,\n total_frames=total_frames,\n frames_per_batch=frames_per_batch,\n init_random_frames=init_random_frames,\n reset_at_each_iter=False,\n split_trajs=False,\n device=collector_device,\n exploration_type=ExplorationType.RANDOM,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Evaluator: building your recorder object\n========================================\n\nAs the training data is obtained using some exploration strategy, the\ntrue performance of our algorithm needs to be assessed in deterministic\nmode. 
We do this using a dedicated class, `Recorder`, which executes the\npolicy in the environment at a given frequency and returns some\nstatistics obtained from these simulations.\n\nThe following helper function builds this object:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.trainers import Recorder\n\n\ndef make_recorder(actor_model_explore, transform_state_dict, record_interval):\n base_env = make_env()\n environment = make_transformed_env(base_env)\n environment.transform[2].init_stats(\n 3\n ) # must be instantiated to load the state dict\n environment.transform[2].load_state_dict(transform_state_dict)\n\n recorder_obj = Recorder(\n record_frames=1000,\n policy_exploration=actor_model_explore,\n environment=environment,\n exploration_type=ExplorationType.DETERMINISTIC,\n record_interval=record_interval,\n )\n return recorder_obj" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will be recording the performance every 10 batch collected\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "record_interval = 10\n\nrecorder = make_recorder(\n actor_model_explore, transform_state_dict, record_interval=record_interval\n)\n\nfrom torchrl.data.replay_buffers import (\n LazyMemmapStorage,\n PrioritizedSampler,\n RandomSampler,\n TensorDictReplayBuffer,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Replay buffer\n=============\n\nReplay buffers come in two flavors: prioritized (where some error signal\nis used to give a higher likelihood of sampling to some items than\nothers) and regular, circular experience replay.\n\nTorchRL replay buffers are composable: one can pick up the storage,\nsampling and writing strategies. It is also possible to store tensors on\nphysical memory using a memory-mapped array. The following function\ntakes care of creating the replay buffer with the desired\nhyperparameters:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.envs import RandomCropTensorDict\n\n\ndef make_replay_buffer(buffer_size, batch_size, random_crop_len, prefetch=3, prb=False):\n if prb:\n sampler = PrioritizedSampler(\n max_capacity=buffer_size,\n alpha=0.7,\n beta=0.5,\n )\n else:\n sampler = RandomSampler()\n replay_buffer = TensorDictReplayBuffer(\n storage=LazyMemmapStorage(\n buffer_size,\n scratch_dir=buffer_scratch_dir,\n ),\n batch_size=batch_size,\n sampler=sampler,\n pin_memory=False,\n prefetch=prefetch,\n transform=RandomCropTensorDict(random_crop_len, sample_dim=1),\n )\n return replay_buffer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We\\'ll store the replay buffer in a temporary directory on disk\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import tempfile\n\ntmpdir = tempfile.TemporaryDirectory()\nbuffer_scratch_dir = tmpdir.name" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Replay buffer storage and batch size\n====================================\n\nTorchRL replay buffer counts the number of elements along the first\ndimension. Since we\\'ll be feeding trajectories to our buffer, we need\nto adapt the buffer size by dividing it by the length of the\nsub-trajectories yielded by our data collector. 
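As a quick, purely illustrative sanity check with the numbers used below (a 1M-frame budget and 200-step sub-trajectories):\n\n``` {.python}\n# ceiling division, equivalent to the ceil_div helper defined below\nprint(-(-1_000_000 // 200))  # 5000 trajectories can be stored\n```\n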
Regarding the\nbatch-size, our sampling strategy will consist in sampling trajectories\nof length `traj_len=200` before selecting sub-trajectories or length\n`random_crop_len=25` on which the loss will be computed. This strategy\nbalances the choice of storing whole trajectories of a certain length\nwith the need for providing samples with a sufficient heterogeneity to\nour loss. The following figure shows the dataflow from a collector that\ngets 8 frames in each batch with 2 environments run in parallel, feeds\nthem to a replay buffer that contains 1000 trajectories and samples\nsub-trajectories of 2 time steps each.\n\n![](https://pytorch.org/tutorials/_static/img/replaybuffer_traj.png)\n\nLet\\'s start with the number of frames stored in the buffer\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def ceil_div(x, y):\n return -x // (-y)\n\n\nbuffer_size = 1_000_000\nbuffer_size = ceil_div(buffer_size, traj_len)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prioritized replay buffer is disabled by default\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "prb = False" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We also need to define how many updates we\\'ll be doing per batch of\ndata collected. This is known as the update-to-data or `UTD` ratio:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "update_to_data = 64" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We\\'ll be feeding the loss with trajectories of length 25:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "random_crop_len = 25" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the original paper, the authors perform one update with a batch of 64\nelements for each frame collected. Here, we reproduce the same ratio but\nwhile realizing several updates at each batch collection. We adapt our\nbatch-size to achieve the same number of update-per-frame ratio:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_size = ceil_div(64 * frames_per_batch, update_to_data * random_crop_len)\n\nreplay_buffer = make_replay_buffer(\n buffer_size=buffer_size,\n batch_size=batch_size,\n random_crop_len=random_crop_len,\n prefetch=3,\n prb=prb,\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Loss module construction\n========================\n\nWe build our loss module with the actor and `qnet` we\\'ve just created.\nBecause we have target parameters to update, we \\_[must]() create a\ntarget network updater.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "gamma = 0.99\nlmbda = 0.9\ntau = 0.001 # Decay factor for the target network\n\nloss_module = DDPGLoss(actor, qnet)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "let\\'s use the TD(lambda) estimator!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=gamma, lmbda=lmbda, device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p>Off-policy learning usually dictates a TD(0) estimator. Here, we use a TD(λ) estimator, which will introduce some bias as the trajectory that follows a certain state has been collected with an outdated policy. This trick, like the multi-step trick that can be used during data collection, is one of the \"hacks\" that we usually find to work well in practice despite the fact that they introduce some bias in the return estimates.</p></div>
\n```\nTarget network updater\n======================\n\nTarget networks are a crucial part of off-policy RL algorithms. Updating\nthe target network parameters is made easy thanks to the\n`~torchrl.objectives.HardUpdate`{.interpreted-text role=\"class\"} and\n`~torchrl.objectives.SoftUpdate`{.interpreted-text role=\"class\"}\nclasses. They\\'re built with the loss module as argument, and the update\nis achieved via a call to [updater.step()]{.title-ref} at the\nappropriate location in the training loop.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrl.objectives.utils import SoftUpdate\n\ntarget_net_updater = SoftUpdate(loss_module, eps=1 - tau)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizer\n=========\n\nFinally, we will use the Adam optimizer for the policy and value\nnetwork:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch import optim\n\noptimizer_actor = optim.Adam(\n loss_module.actor_network_params.values(True, True), lr=1e-4, weight_decay=0.0\n)\noptimizer_value = optim.Adam(\n loss_module.value_network_params.values(True, True), lr=1e-3, weight_decay=1e-2\n)\ntotal_collection_steps = total_frames // frames_per_batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Time to train the policy\n========================\n\nThe training loop is pretty straightforward now that we have built all\nthe modules we need.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rewards = []\nrewards_eval = []\n\n# Main loop\n\ncollected_frames = 0\npbar = tqdm.tqdm(total=total_frames)\nr0 = None\nfor i, tensordict in enumerate(collector):\n\n # update weights of the inference policy\n collector.update_policy_weights_()\n\n if r0 is None:\n r0 = tensordict[\"next\", \"reward\"].mean().item()\n pbar.update(tensordict.numel())\n\n # extend the replay buffer with the new data\n current_frames = tensordict.numel()\n collected_frames += current_frames\n replay_buffer.extend(tensordict.cpu())\n\n # optimization steps\n if collected_frames >= init_random_frames:\n for _ in range(update_to_data):\n # sample from replay buffer\n sampled_tensordict = replay_buffer.sample().to(device)\n\n # Compute loss\n loss_dict = loss_module(sampled_tensordict)\n\n # optimize\n loss_dict[\"loss_actor\"].backward()\n gn1 = torch.nn.utils.clip_grad_norm_(\n loss_module.actor_network_params.values(True, True), 10.0\n )\n optimizer_actor.step()\n optimizer_actor.zero_grad()\n\n loss_dict[\"loss_value\"].backward()\n gn2 = torch.nn.utils.clip_grad_norm_(\n loss_module.value_network_params.values(True, True), 10.0\n )\n optimizer_value.step()\n optimizer_value.zero_grad()\n\n gn = (gn1**2 + gn2**2) ** 0.5\n\n # update priority\n if prb:\n replay_buffer.update_tensordict_priority(sampled_tensordict)\n # update target network\n target_net_updater.step()\n\n rewards.append(\n (\n i,\n tensordict[\"next\", \"reward\"].mean().item(),\n )\n )\n td_record = recorder(None)\n if td_record is not None:\n rewards_eval.append((i, td_record[\"r_evaluation\"].item()))\n if len(rewards_eval) and collected_frames >= init_random_frames:\n target_value = loss_dict[\"target_value\"].item()\n loss_value = loss_dict[\"loss_value\"].item()\n loss_actor = loss_dict[\"loss_actor\"].item()\n rn = sampled_tensordict[\"next\", \"reward\"].mean().item()\n rs = sampled_tensordict[\"next\", 
\"reward\"].std().item()\n pbar.set_description(\n f\"reward: {rewards[-1][1]: 4.2f} (r0 = {r0: 4.2f}), \"\n f\"reward eval: reward: {rewards_eval[-1][1]: 4.2f}, \"\n f\"reward normalized={rn :4.2f}/{rs :4.2f}, \"\n f\"grad norm={gn: 4.2f}, \"\n f\"loss_value={loss_value: 4.2f}, \"\n f\"loss_actor={loss_actor: 4.2f}, \"\n f\"target value: {target_value: 4.2f}\"\n )\n\n # update the exploration strategy\n actor_model_explore[1].step(current_frames)\n\ncollector.shutdown()\ndel collector" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Experiment results\n==================\n\nWe make a simple plot of the average rewards during training. We can\nobserve that our policy learned quite well to solve the task.\n\n```{=html}\n
<div class=\"alert alert-info\"><h4>Note</h4><p>As already mentioned above, to get a more reasonable performance, use a greater value for <code>total_frames</code>, for example, 1M.</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from matplotlib import pyplot as plt\n\nplt.figure()\nplt.plot(*zip(*rewards), label=\"training\")\nplt.plot(*zip(*rewards_eval), label=\"eval\")\nplt.legend()\nplt.xlabel(\"iter\")\nplt.ylabel(\"reward\")\nplt.tight_layout()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we have learned how to code a loss module in TorchRL\ngiven the concrete example of DDPG.\n\nThe key takeaways are:\n\n- How to use the `~torchrl.objectives.LossModule`{.interpreted-text\n role=\"class\"} class to code up a new loss component;\n- How to use (or not) a target network, and how to update its\n parameters;\n- How to create an optimizer associated with a loss module.\n\nNext Steps\n==========\n\nTo iterate further on this loss module we might consider:\n\n- Using [\\@dispatch]{.title-ref} (see [\\[Feature\\] Distpatch IQL loss\n module](https://github.com/pytorch/rl/pull/1230).)\n- Allowing flexible TensorDict keys.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }