{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRL objectives: Coding a DDPG loss\n======================================\n\n**Author**: [Vincent Moens](https://github.com/vmoens)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overview\n========\n\nTorchRL separates the training of RL algorithms in various pieces that\nwill be assembled in your training script: the environment, the data\ncollection and storage, the model and finally the loss function.\n\nTorchRL losses (or \\\"objectives\\\") are stateful objects that contain the\ntrainable parameters (policy and value models). This tutorial will guide\nyou through the steps to code a loss from the ground up using TorchRL.\n\nTo this aim, we will be focusing on DDPG, which is a relatively\nstraightforward algorithm to code. [Deep Deterministic Policy\nGradient](https://arxiv.org/abs/1509.02971) (DDPG) is a simple\ncontinuous control algorithm. It consists in learning a parametric value\nfunction for an action-observation pair, and then learning a policy that\noutputs actions that maximize this value function given a certain\nobservation.\n\nWhat you will learn:\n\n- how to write a loss module and customize its value estimator;\n- how to build an environment in TorchRL, including transforms (for\n example, data normalization) and parallel execution;\n- how to design a policy and value network;\n- how to collect data from your environment efficiently and store them\n in a replay buffer;\n- how to store trajectories (and not transitions) in your replay\n buffer);\n- how to evaluate your model.\n\nPrerequisites\n-------------\n\nThis tutorial assumes that you have completed the [PPO\ntutorial](reinforcement_ppo.html) which gives an overview of the TorchRL\ncomponents and dependencies, such as\n`tensordict.TensorDict`{.interpreted-text role=\"class\"} and\n`tensordict.nn.TensorDictModules`{.interpreted-text role=\"class\"},\nalthough it should be sufficiently transparent to be understood without\na deep understanding of these classes.\n\n```{=html}\n
We do not aim at giving a SOTA implementation of the algorithm, but rather to provide a high-level illustration of TorchRL's loss implementations and the library features that are to be used in the context of this algorithm.</p></div>\n```\n
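To make the overview above concrete, here is a minimal, self-contained sketch of the two quantities a DDPG loss optimizes: a value loss that regresses the state-action value towards a bootstrapped target, and an actor loss that favors actions with a high predicted value. It uses toy networks and random data, omits target networks and the TensorDict plumbing, and is not the implementation developed in this tutorial.

```python
import torch
from torch import nn

# Toy stand-ins for the tutorial's actor and value networks.
obs_dim, act_dim = 3, 1
actor = nn.Sequential(nn.Linear(obs_dim, act_dim), nn.Tanh())
qvalue = nn.Linear(obs_dim + act_dim, 1)

# A fake batch of transitions.
batch = 8
obs = torch.randn(batch, obs_dim)
action = torch.randn(batch, act_dim)
reward = torch.randn(batch, 1)
next_obs = torch.randn(batch, obs_dim)
done = torch.zeros(batch, 1)
gamma = 0.99

# Value loss: regress Q(s, a) towards a one-step bootstrapped target
# (a complete DDPG implementation would use target networks here).
with torch.no_grad():
    next_action = actor(next_obs)
    next_q = qvalue(torch.cat([next_obs, next_action], dim=-1))
    target = reward + gamma * (1 - done) * next_q
loss_value = nn.functional.mse_loss(qvalue(torch.cat([obs, action], dim=-1)), target)

# Actor loss: the policy should output actions with a high predicted value.
loss_actor = -qvalue(torch.cat([obs, actor(obs)], dim=-1)).mean()
```

The loss module coded in this tutorial computes the same two terms, returned as separate entries of a dictionary-like output, which is what the next note refers to.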
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>The reason we return independent losses is to let the user use a different optimizer for different sets of parameters, for instance. Summing the losses can simply be done via\n\n>>> loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith(\"loss\"))</p></div>\n```\n
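For example, reusing the toy modules and losses from the sketch above, driving each set of parameters with its own optimizer could look like the following (the learning rates are arbitrary; a complete implementation would also make sure that each loss term only updates its own parameters):

```python
from torch.optim import Adam

# One optimizer per set of parameters, e.g. with different learning rates.
optim_actor = Adam(actor.parameters(), lr=1e-4)
optim_value = Adam(qvalue.parameters(), lr=1e-3)

# Sum the individual loss entries, backpropagate once, then step each optimizer.
loss_dict = {"loss_actor": loss_actor, "loss_value": loss_value}
loss_val = sum(loss for key, loss in loss_dict.items() if key.startswith("loss"))
loss_val.backward()
for optim in (optim_actor, optim_value):
    optim.step()
    optim.zero_grad()
```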
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>frame_skip
batches multiple steps together with a single action. If > 1, the other frame counts (for example, frames_per_batch, total_frames) need to be adjusted to have a consistent total number of frames collected across experiments. This is important as raising the frame-skip but keeping the total number of frames unchanged may seem like cheating: all things compared, a dataset of 10M elements collected with a frame-skip of 2 and another with a frame-skip of 1 actually have a ratio of interactions with the environment of 2:1! In a nutshell, one should be cautious about the frame-count of a training script when dealing with frame skipping, as this may lead to biased comparisons between training strategies.</p></div>\n```\n
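As a purely illustrative example of this bookkeeping (placeholder numbers, not the tutorial's hyperparameters), expressing the budgets in environment interactions and dividing by the frame-skip keeps experiments comparable:

```python
# Budgets expressed in environment interactions; each collected transition
# consumes `frame_skip` frames, so the collected counts are divided accordingly.
frame_skip = 2
total_frames = 1_000_000 // frame_skip
frames_per_batch = 1_000 // frame_skip
```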
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>The max_frames_per_traj passed to the collector will have the effect of registering a new torchrl.envs.StepCounter transform with the environment used for inference. We can achieve the same result manually, as we do in this script.</p></div>\n```\n
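A sketch of the manual alternative, assuming a Gym/Gymnasium backend is installed (the environment name and step limit are only examples):

```python
from torchrl.envs import GymEnv, StepCounter, TransformedEnv

# Truncate trajectories after 1000 steps by appending a StepCounter transform,
# instead of passing max_frames_per_traj to the collector.
env = TransformedEnv(GymEnv("Pendulum-v1"), StepCounter(max_steps=1000))
```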
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>Off-policy learning usually dictates a TD(0) estimator. Here, we use a TD(lambda) estimator, which will introduce some bias, as the trajectory that follows a certain state has been collected with an outdated policy. This trick and the multi-step trick that can be used during data collection are both \"hacks\" that we usually find to work well in practice, despite the fact that they introduce some bias in the return estimates.</p></div>\n```\n
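Assuming `loss_module` follows TorchRL's `LossModule` interface (as the loss coded in this tutorial is meant to), the estimator can be selected in one line; the hyperparameter values below are only examples:

```python
from torchrl.objectives import ValueEstimators

# TD(lambda) estimator, as used here; TD(0) would be the unbiased off-policy choice.
loss_module.make_value_estimator(ValueEstimators.TDLambda, gamma=0.99, lmbda=0.95)
# loss_module.make_value_estimator(ValueEstimators.TD0, gamma=0.99)
```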
```{=html}\n<div class=\"alert alert-info\"><h4>Note</h4><p>As already mentioned above, to get more reasonable performance, use a greater value for total_frames, for example, 1M.