{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Introduction to TorchRec\n========================\n\n**TorchRec** is a PyTorch library tailored for building scalable and\nefficient recommendation systems using embeddings. This tutorial guides\nyou through the installation process, introduces the concept of\nembeddings, and highlights their importance in recommendation systems.\nIt offers practical demonstrations on implementing embeddings with\nPyTorch and TorchRec, focusing on handling large embedding tables\nthrough distributed training and advanced optimizations.\n\n```{=html}\n

<p><strong>What you will learn</strong></p>
<p><strong>Prerequisites</strong></p>
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install Dependencies\n====================\n\nBefore running this tutorial in Google Colab or other environment,\ninstall the following dependencies:\n\n``` {.sh}\n!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U\n!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121\n!pip3 install torchmetrics==1.0.3\n!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121\n```\n\n```{=html}\n
<div class=\"admonition note\"><p class=\"admonition-title\">Note</p><p>If you are running this in Google Colab, make sure to switch to a GPU runtime type. For more information, see Enabling CUDA.</p></div>
\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Embeddings\n==========\n\nWhen building recommendation systems, categorical features typically\nhave massive cardinality, posts, users, ads, and so on.\n\nIn order to represent these entities and model these relationships,\n**embeddings** are used. In machine learning, **embeddings are a vectors\nof real numbers in a high-dimensional space used to represent meaning in\ncomplex data like words, images, or users**.\n\nEmbeddings in RecSys\n====================\n\nNow you might wonder, how are these embeddings generated in the first\nplace? Well, embeddings are represented as individual rows in an\n**Embedding Table**, also referred to as embedding weights. The reason\nfor this is that embeddings or embedding table weights are trained just\nlike all of the other weights of the model via gradient descent!\n\nEmbedding tables are simply a large matrix for storing embeddings, with\ntwo dimensions (B, N), where:\n\n- B is the number of embeddings stored by the table\n- N is the number of dimensions per embedding (N-dimensional\n embedding).\n\nThe inputs to embedding tables represent embedding lookups to retrieve\nthe embedding for a specific index or row. In recommendation systems,\nsuch as those used in many large systems, unique IDs are not only used\nfor specific users, but also across entities like posts and ads to serve\nas lookup indices to respective embedding tables!\n\nEmbeddings are trained in RecSys through the following process:\n\n- **Input/lookup indices are fed into the model, as unique IDs**. IDs\n are hashed to the total size of the embedding table to prevent\n issues when the ID \\> number of rows\n- Embeddings are then retrieved and **pooled, such as taking the sum\n or mean of the embeddings**. This is required as there can be a\n variable number of embeddings per example while the model expects\n consistent shapes.\n- The **embeddings are used in conjunction with the rest of the model\n to produce a prediction**, such as [Click-Through Rate\n (CTR)](https://support.google.com/google-ads/answer/2615875?hl=en)\n for an ad.\n- The loss is calculated with the prediction and the label for an\n example, and **all weights of the model are updated through gradient\n descent and backpropagation, including the embedding weights** that\n were associated with the example.\n\nThese embeddings are crucial for representing categorical features, such\nas users, posts, and ads, in order to capture relationships and make\ngood recommendations. 
The [Deep learning recommendation\nmodel](https://arxiv.org/abs/1906.00091) (DLRM) paper talks more about\nthe technical details of using embedding tables in RecSys.\n\nThis tutorial introduces the concept of embeddings, showcase TorchRec\nspecific modules and data types, and depict how distributed training\nworks with TorchRec.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Embeddings in PyTorch\n=====================\n\nIn PyTorch, we have the following types of embeddings:\n\n- `torch.nn.Embedding`{.interpreted-text role=\"class\"}: An embedding\n table where forward pass returns the embeddings themselves as is.\n- `torch.nn.EmbeddingBag`{.interpreted-text role=\"class\"}: Embedding\n table where forward pass returns embeddings that are then pooled,\n for example, sum or mean, otherwise known as **Pooled Embeddings**.\n\nIn this section, we will go over a very brief introduction to performing\nembedding lookups by passing in indices into the table.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "num_embeddings, embedding_dim = 10, 4\n\n# Initialize our embedding table\nweights = torch.rand(num_embeddings, embedding_dim)\nprint(\"Weights:\", weights)\n\n# Pass in pre-generated weights just for example, typically weights are randomly initialized\nembedding_collection = torch.nn.Embedding(\n num_embeddings, embedding_dim, _weight=weights\n)\nembedding_bag_collection = torch.nn.EmbeddingBag(\n num_embeddings, embedding_dim, _weight=weights\n)\n\n# Print out the tables, we should see the same weights as above\nprint(\"Embedding Collection Table: \", embedding_collection.weight)\nprint(\"Embedding Bag Collection Table: \", embedding_bag_collection.weight)\n\n# Lookup rows (ids for embedding ids) from the embedding tables\n# 2D tensor with shape (batch_size, ids for each batch)\nids = torch.tensor([[1, 3]])\nprint(\"Input row IDS: \", ids)\n\nembeddings = embedding_collection(ids)\n\n# Print out the embedding lookups\n# You should see the specific embeddings be the same as the rows (ids) of the embedding tables above\nprint(\"Embedding Collection Results: \")\nprint(embeddings)\nprint(\"Shape: \", embeddings.shape)\n\n# ``nn.EmbeddingBag`` default pooling is mean, so should be mean of batch dimension of values above\npooled_embeddings = embedding_bag_collection(ids)\n\nprint(\"Embedding Bag Collection Results: \")\nprint(pooled_embeddings)\nprint(\"Shape: \", pooled_embeddings.shape)\n\n# ``nn.EmbeddingBag`` is the same as ``nn.Embedding`` but just with pooling (mean, sum, and so on)\n# We can see that the mean of the embeddings of embedding_collection is the same as the output of the embedding_bag_collection\nprint(\"Mean: \", torch.mean(embedding_collection(ids), dim=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Congratulations! Now you have a basic understanding of how to use\nembedding tables \\-\\-- one of the foundations of modern recommendation\nsystems! These tables represent entities and their relationships. For\nexample, the relationship between a given user and the pages and posts\nthey have liked.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRec Features Overview\n==========================\n\nIn the section above we\\'ve learned how to use embedding tables, one of\nthe foundations of modern recommendation systems! 
These tables represent\nentities and relationships, such as users, pages, posts, etc. Given that\nthese entities are always increasing, a **hash** function is typically\napplied to make sure the IDs are within the bounds of a certain\nembedding table. However, in order to represent a vast amount of\nentities and reduce hash collisions, these tables can become quite\nmassive (think about the number of ads for example). In fact, these\ntables can become so massive that they won\\'t be able to fit on 1 GPU,\neven with 80G of memory.\n\nIn order to train models with massive embedding tables, sharding these\ntables across GPUs is required, which then introduces a whole new set of\nproblems and opportunities in parallelism and optimization. Luckily, we\nhave the TorchRec library that has encountered, consolidated, and\naddressed many of these concerns. TorchRec serves as a **library that\nprovides primitives for large scale distributed embeddings**.\n\nNext, we will explore the major features of the TorchRec library. We\nwill start with `torch.nn.Embedding` and will extend that to custom\nTorchRec modules, explore distributed training environment with\ngenerating a sharding plan for embeddings, look at inherent TorchRec\noptimizations, and extend the model to be ready for inference in C++.\nBelow is a quick outline of what this section consists of:\n\n- TorchRec Modules and Data Types\n- Distributed Training, Sharding, and Optimizations\n- Inference\n\nLet\\'s begin with importing TorchRec:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torchrec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRec Modules and Data Types\n===============================\n\nThis section goes over TorchRec Modules and data types including such\nentities as `EmbeddingCollection` and `EmbeddingBagCollection`,\n`JaggedTensor`, `KeyedJaggedTensor`, `KeyedTensor` and more.\n\nFrom `EmbeddingBag` to `EmbeddingBagCollection`\n-----------------------------------------------\n\nWe have already explored `torch.nn.Embedding`{.interpreted-text\nrole=\"class\"} and `torch.nn.EmbeddingBag`{.interpreted-text\nrole=\"class\"}. TorchRec extends these modules by creating collections of\nembeddings, in other words modules that can have multiple embedding\ntables, with `EmbeddingCollection` and `EmbeddingBagCollection` We will\nuse `EmbeddingBagCollection` to represent a group of embedding bags.\n\nIn the example code below, we create an `EmbeddingBagCollection` (EBC)\nwith two embedding bags, 1 representing **products** and 1 representing\n**users**. 
Each table, `product_table` and `user_table`, is represented\nby a 64 dimension embedding of size 4096.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ebc = torchrec.EmbeddingBagCollection(\n device=\"cpu\",\n tables=[\n torchrec.EmbeddingBagConfig(\n name=\"product_table\",\n embedding_dim=64,\n num_embeddings=4096,\n feature_names=[\"product\"],\n pooling=torchrec.PoolingType.SUM,\n ),\n torchrec.EmbeddingBagConfig(\n name=\"user_table\",\n embedding_dim=64,\n num_embeddings=4096,\n feature_names=[\"user\"],\n pooling=torchrec.PoolingType.SUM,\n )\n ]\n)\nprint(ebc.embedding_bags)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's inspect the forward method for `EmbeddingBagCollection` and the\nmodule's inputs and outputs:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import inspect\n\n# Let's look at the ``EmbeddingBagCollection`` forward method\n# What is a ``KeyedJaggedTensor`` and ``KeyedTensor``?\nprint(inspect.getsource(ebc.forward))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "TorchRec Input/Output Data Types\n================================\n\nTorchRec has distinct data types for input and output of its modules:\n`JaggedTensor`, `KeyedJaggedTensor`, and `KeyedTensor`. Now you might\nask, why create new data types to represent sparse features? To answer\nthat question, we must understand how sparse features are represented in\ncode.\n\nSparse features are otherwise known as `id_list_feature` and\n`id_score_list_feature`, and are the **IDs** that will be used as\nindices to an embedding table to retrieve the embedding for that ID. To\ngive a very simple example, imagine a single sparse feature being Ads\nthat a user interacted with. The input itself would be a set of Ad IDs\nthat a user interacted with, and the embeddings retrieved would be a\nsemantic representation of those Ads. The tricky part of representing\nthese features in code is that in each input example, **the number of\nIDs is variable**. 
One day a user might have interacted with only one ad\nwhile the next day they interact with three.\n\nA simple representation is shown below, where we have a `lengths` tensor\ndenoting how many indices are in an example for a batch and a `values`\ntensor containing the indices themselves.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Batch Size 2\n# 1 ID in example 1, 2 IDs in example 2\nid_list_feature_lengths = torch.tensor([1, 2])\n\n# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2\nid_list_feature_values = torch.tensor([5, 7, 1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, let\\'s look at the offsets as well as what is contained in each\nbatch\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Lengths can be converted to offsets for easy indexing of values\nid_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)\n\nprint(\"Offsets: \", id_list_feature_offsets)\nprint(\"First Batch: \", id_list_feature_values[: id_list_feature_offsets[0]])\nprint(\n \"Second Batch: \",\n id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]],\n)\n\nfrom torchrec import JaggedTensor\n\n# ``JaggedTensor`` is just a wrapper around lengths/offsets and values tensors!\njt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)\n\n# Automatically compute offsets from lengths\nprint(\"Offsets: \", jt.offsets())\n\n# Convert to list of values\nprint(\"List of Values: \", jt.to_dense())\n\n# ``__str__`` representation\nprint(jt)\n\nfrom torchrec import KeyedJaggedTensor\n\n# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``\n# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets\n# From before, we have our two features \"product\" and \"user\". Let's create ``JaggedTensors`` for both!\n\nproduct_jt = JaggedTensor(\n values=torch.tensor([1, 2, 1, 5]), lengths=torch.tensor([3, 1])\n)\nuser_jt = JaggedTensor(values=torch.tensor([2, 3, 4, 1]), lengths=torch.tensor([2, 2]))\n\n# Q1: How many batches are there, and which values are in the first batch for ``product_jt`` and ``user_jt``?\nkjt = KeyedJaggedTensor.from_jt_dict({\"product\": product_jt, \"user\": user_jt})\n\n# Look at our feature keys for the ``KeyedJaggedTensor``\nprint(\"Keys: \", kjt.keys())\n\n# Look at the overall lengths for the ``KeyedJaggedTensor``\nprint(\"Lengths: \", kjt.lengths())\n\n# Look at all values for ``KeyedJaggedTensor``\nprint(\"Values: \", kjt.values())\n\n# Can convert ``KeyedJaggedTensor`` to dictionary representation\nprint(\"to_dict: \", kjt.to_dict())\n\n# ``KeyedJaggedTensor`` string representation\nprint(kjt)\n\n# Q2: What are the offsets for the ``KeyedJaggedTensor``?\n\n# Now we can run a forward pass on our ``EmbeddingBagCollection`` from before\nresult = ebc(kjt)\nresult\n\n# Result is a ``KeyedTensor``, which contains a list of the feature names and the embedding results\nprint(result.keys())\n\n# The results shape is [2, 128], as batch size of 2. Reread previous section if you need a refresher on how the batch size is determined\n# 128 for dimension of embedding. 
If you look at where we initialized the ``EmbeddingBagCollection``, we have two tables \"product\" and \"user\" of dimension 64 each\n# meaning embeddings for both features are of size 64. 64 + 64 = 128\nprint(result.values().shape)\n\n# Nice to_dict method to determine the embeddings that belong to each feature\nresult_dict = result.to_dict()\nfor key, embedding in result_dict.items():\n print(key, embedding.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Congrats! You now understand TorchRec modules and data types. Give\nyourself a pat on the back for making it this far. Next, we will learn\nabout distributed training and sharding.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Distributed Training and Sharding\n=================================\n\nNow that we have a grasp on TorchRec modules and data types, it\\'s time\nto take it to the next level.\n\nRemember, the main purpose of TorchRec is to provide primitives for\ndistributed embeddings. So far, we\\'ve only worked with embedding tables\non a single device. This has been possible given how small the embedding\ntables have been, but in a production setting this isn\\'t generally the\ncase. Embedding tables often get massive, where one table can\\'t fit on\na single GPU, creating the requirement for multiple devices and a\ndistributed environment.\n\nIn this section, we will explore setting up a distributed environment,\nexactly how actual production training is done, and explore sharding\nembedding tables, all with TorchRec.\n\n**This section will also only use 1 GPU, though it will be treated in a\ndistributed fashion. This is only a limitation for training, as training\nhas a process per GPU. Inference does not run into this requirement**\n\nIn the example code below, we set up our PyTorch distributed\nenvironment.\n\n```{=html}\n
<div class=\"admonition warning\"><p class=\"admonition-title\">Warning</p><p>If you are running this in Google Colab, you can only call this cell once. Calling it again will cause an error, as you can only initialize the process group once.</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\n\nimport torch.distributed as dist\n\n# Set up environment variables for distributed training\n# RANK is which GPU we are on, default 0\nos.environ[\"RANK\"] = \"0\"\n# How many devices in our \"world\", colab notebook can only handle 1 process\nos.environ[\"WORLD_SIZE\"] = \"1\"\n# Localhost as we are training locally\nos.environ[\"MASTER_ADDR\"] = \"localhost\"\n# Port for distributed training\nos.environ[\"MASTER_PORT\"] = \"29500\"\n\n# nccl backend is for GPUs, gloo is for CPUs\ndist.init_process_group(backend=\"gloo\")\n\nprint(f\"Distributed environment initialized: {dist}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Distributed Embeddings\n======================\n\nWe have already worked with the main TorchRec module:\n`EmbeddingBagCollection`. We have examined how it works along with how\ndata is represented in TorchRec. However, we have not yet explored one\nof the main parts of TorchRec, which is **distributed embeddings**.\n\nGPUs are the most popular choice for ML workloads by far today, as they\nare able to do magnitudes more floating point operations/s\n([FLOPs](https://en.wikipedia.org/wiki/FLOPS)) than CPU. However, GPUs\ncome with the limitation of scarce fast memory (HBM which is analogous\nto RAM for CPU), typically, \\~10s of GBs.\n\nA RecSys model can contain embedding tables that far exceed the memory\nlimit for 1 GPU, hence the need for distribution of the embedding tables\nacross multiple GPUs, otherwise known as **model parallel**. On the\nother hand, **data parallel** is where the entire model is replicated on\neach GPU, which each GPU taking in a distinct batch of data for\ntraining, syncing gradients on the backwards pass.\n\nParts of the model that **require less compute but more memory\n(embeddings) are distributed with model parallel** while parts that\n**require more compute and less memory (dense layers, MLP, etc.) are\ndistributed with data parallel**.\n\nSharding\n========\n\nIn order to distribute an embedding table, we split up the embedding\ntable into parts and place those parts onto different devices, also\nknown as \"sharding\".\n\nThere are many ways to shard embedding tables. The most common ways are:\n\n- Table-Wise: the table is placed entirely onto one device\n- Column-Wise: columns of embedding tables are sharded\n- Row-Wise: rows of embedding tables are sharded\n\nSharded Modules\n===============\n\nWhile all of this seems like a lot to deal with and implement, you\\'re\nin luck. **TorchRec provides all the primitives for easy distributed\ntraining and inference**! In fact, TorchRec modules have two\ncorresponding classes for working with any TorchRec module in a\ndistributed environment:\n\n- **The module sharder**: This class exposes a `shard` API that\n handles sharding a TorchRec Module, producing a sharded module.\n - For `EmbeddingBagCollection`, the sharder is\n [EmbeddingBagCollectionSharder](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.EmbeddingBagCollectionSharder)\n- **Sharded module**: This class is a sharded variant of a TorchRec\n module. 
It has the same input/output as the regular TorchRec\n module, but it is much more optimized and works in a distributed\n environment.\n - For `EmbeddingBagCollection`, the sharded variant is\n [ShardedEmbeddingBagCollection](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.embeddingbag.ShardedEmbeddingBagCollection)\n\nEvery TorchRec module has an unsharded and sharded variant.\n\n- The unsharded version is meant to be prototyped and experimented\n with.\n- The sharded version is meant to be used in a distributed environment\n for distributed training and inference.\n\nThe sharded versions of TorchRec modules, for example\n`EmbeddingBagCollection`, will handle everything that is needed for\nmodel parallelism, such as communication between GPUs for distributing\nembeddings to the correct GPUs.\n\nHere is a refresher of our `EmbeddingBagCollection` module:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ebc\n\nfrom torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\nfrom torchrec.distributed.planner import EmbeddingShardingPlanner, Topology\nfrom torchrec.distributed.types import ShardingEnv\n\n# Corresponding sharder for ``EmbeddingBagCollection`` module\nsharder = EmbeddingBagCollectionSharder()\n\n# ``ProcessGroup`` from torch.distributed initialized 2 cells above\npg = dist.GroupMember.WORLD\nassert pg is not None, \"Process group is not initialized\"\n\nprint(f\"Process Group: {pg}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Planner\n=======\n\nBefore we can show how sharding works, we must know about the\n**planner**, which helps us determine the best sharding configuration.\n\nGiven a number of embedding tables and a number of ranks, there are many\ndifferent sharding configurations that are possible. For example, given\n2 embedding tables and 2 GPUs, you can:\n\n- Place 1 table on each GPU\n- Place both tables on a single GPU and no tables on the other\n- Place certain rows and columns on each GPU\n\nGiven all of these possibilities, we typically want a sharding\nconfiguration that is optimal for performance.\n\nThat is where the planner comes in. Given the number of embedding tables\nand the number of GPUs, the planner determines the optimal\nconfiguration. It turns out that this is incredibly difficult to do\nmanually, with many factors that engineers have to consider to ensure\nan optimal sharding plan. Luckily, TorchRec provides an automated\nplanner that handles this for you. 
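\nIf you want to steer the plan yourself, the planner can also take optional\nper-table constraints. Here is a hedged sketch (the import path and the\nconstraint fields are assumptions about the planner API); it pins\n`product_table` to table-wise sharding:\n\n``` {.python}\n# Assumed location of ``ParameterConstraints``; adjust if your version differs\nfrom torchrec.distributed.planner.types import ParameterConstraints\nfrom torchrec.distributed.types import ShardingType\n\nconstrained_planner = EmbeddingShardingPlanner(\n    topology=Topology(world_size=1, compute_device=\"cuda\"),\n    constraints={\n        # Keyed by table name; restricts the sharding types the planner may pick\n        \"product_table\": ParameterConstraints(\n            sharding_types=[ShardingType.TABLE_WISE.value],\n        ),\n    },\n)\n# constrained_plan = constrained_planner.collective_plan(ebc, [sharder], pg)\n```\n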
The TorchRec planner:\n\n- Assesses memory constraints of hardware\n- Estimates compute based on the memory fetches required for embedding lookups\n- Addresses data-specific factors\n- Considers other hardware specifics like bandwidth to generate an\n optimal sharding plan\n\nTo take all of these variables into consideration, the TorchRec\nplanner can take in [various amounts of data for embedding tables,\nconstraints, hardware information, and\ntopology](https://github.com/pytorch/torchrec/blob/main/torchrec/distributed/planner/planners.py#L147-L155)\nto aid in generating the optimal sharding plan for a model, which is\nroutinely provided across stacks.\n\nTo learn more about sharding, see our [sharding\ntutorial](https://pytorch.org/tutorials/advanced/sharding.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# In our case, 1 GPU and compute on CUDA device\nplanner = EmbeddingShardingPlanner(\n topology=Topology(\n world_size=1,\n compute_device=\"cuda\",\n )\n)\n\n# Run planner to get plan for sharding\nplan = planner.collective_plan(ebc, [sharder], pg)\n\nprint(f\"Sharding Plan generated: {plan}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Planner Result\n==============\n\nAs you can see above, when running the planner there is quite a bit of\noutput. We can see a lot of stats being calculated along with where our\ntables end up being placed.\n\nThe result of running the planner is a static plan, which can be reused\nfor sharding! This allows sharding to be static for production models\ninstead of determining a new sharding plan every time. Below, we use the\nsharding plan to finally generate our `ShardedEmbeddingBagCollection`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# The static plan that was generated\nplan\n\nenv = ShardingEnv.from_process_group(pg)\n\n# Shard the ``EmbeddingBagCollection`` module using the ``EmbeddingBagCollectionSharder``\nsharded_ebc = sharder.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n\nprint(f\"Sharded EBC Module: {sharded_ebc}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GPU Training with `LazyAwaitable`\n=================================\n\nRemember that TorchRec is a highly optimized library for distributed\nembeddings. A concept that TorchRec introduces to enable higher\nperformance for training on GPU is a\n[LazyAwaitable](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.types.LazyAwaitable).\nYou will see `LazyAwaitable` types as outputs of various sharded\nTorchRec modules. 
All a `LazyAwaitable` type does is delay calculating\nsome result as long as possible, and it does it by acting like an async\ntype.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from typing import List\n\nfrom torchrec.distributed.types import LazyAwaitable\n\n\n# Demonstrate a ``LazyAwaitable`` type:\nclass ExampleAwaitable(LazyAwaitable[torch.Tensor]):\n def __init__(self, size: List[int]) -> None:\n super().__init__()\n self._size = size\n\n def _wait_impl(self) -> torch.Tensor:\n return torch.ones(self._size)\n\n\nawaitable = ExampleAwaitable([3, 2])\nawaitable.wait()\n\nkjt = kjt.to(\"cuda\")\noutput = sharded_ebc(kjt)\n# The output of our sharded ``EmbeddingBagCollection`` module is an `Awaitable`?\nprint(output)\n\nkt = output.wait()\n# Now we have our ``KeyedTensor`` after calling ``.wait()``\n# If you are confused as to why we have a ``KeyedTensor ``output,\n# give yourself a refresher on the unsharded ``EmbeddingBagCollection`` module\nprint(type(kt))\n\nprint(kt.keys())\n\nprint(kt.values().shape)\n\n# Same output format as unsharded ``EmbeddingBagCollection``\nresult_dict = kt.to_dict()\nfor key, embedding in result_dict.items():\n print(key, embedding.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Anatomy of Sharded TorchRec modules\n===================================\n\nWe have now successfully sharded an `EmbeddingBagCollection` given a\nsharding plan that we generated! The sharded module has common APIs from\nTorchRec which abstract away distributed communication/compute amongst\nmultiple GPUs. In fact, these APIs are highly optimized for performance\nin training and inference. **Below are the three common APIs for\ndistributed training/inference** that are provided by TorchRec:\n\n- `input_dist`: Handles distributing inputs from GPU to GPU.\n- `lookups`: Does the actual embedding lookup in an optimized, batched\n manner using FBGEMM TBE (more on this later).\n- `output_dist`: Handles distributing outputs from GPU to GPU.\n\nThe distribution of inputs and outputs is done through [NCCL\nCollectives](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/overview.html),\nnamely\n[All-to-Alls](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html#all-to-all),\nwhich is where all GPUs send and receive data to and from one another.\nTorchRec interfaces with PyTorch distributed for collectives and\nprovides clean abstractions to the end users, removing the concern for\nthe lower level details.\n\nThe backwards pass does all of these collectives but in the reverse\norder for distribution of gradients. `input_dist`, `lookup`, and\n`output_dist` all depend on the sharding scheme. 
Since we sharded in a\ntable-wise fashion, these APIs are modules that are constructed by\n[TwPooledEmbeddingSharding](https://pytorch.org/torchrec/torchrec.distributed.sharding.html#torchrec.distributed.sharding.tw_sharding.TwPooledEmbeddingSharding).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sharded_ebc\n\n# Distribute input KJTs to all other GPUs and receive KJTs\nsharded_ebc._input_dists\n\n# Distribute output embeddings to all other GPUs and receive embeddings\nsharded_ebc._output_dists" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Optimizing Embedding Lookups\n============================\n\nIn performing lookups for a collection of embedding tables, a trivial\nsolution would be to iterate through all the `nn.EmbeddingBags` and do a\nlookup per table. This is exactly what the standard, unsharded\n`EmbeddingBagCollection` does. However, while this solution is simple,\nit is extremely slow.\n\n[FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) is a\nlibrary that provides GPU operators (otherwise known as kernels) that\nare very optimized. One of these operators is known as **Table Batched\nEmbedding** (TBE), provides two major optimizations:\n\n- Table batching, which allows you to look up multiple embeddings with\n one kernel call.\n- Optimizer Fusion, which allows the module to update itself given the\n canonical pytorch optimizers and arguments.\n\nThe `ShardedEmbeddingBagCollection` uses the FBGEMM TBE as the lookup\ninstead of traditional `nn.EmbeddingBags` for optimized embedding\nlookups.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sharded_ebc._lookups" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`DistributedModelParallel`\n==========================\n\nWe have now explored sharding a single `EmbeddingBagCollection`! We were\nable to take the `EmbeddingBagCollectionSharder` and use the unsharded\n`EmbeddingBagCollection` to generate a `ShardedEmbeddingBagCollection`\nmodule. This workflow is fine, but typically when implementing model\nparallel,\n[DistributedModelParallel](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel)\n(DMP) is used as the standard interface. When wrapping your model (in\nour case `ebc`), with DMP, the following will occur:\n\n1. Decide how to shard the model. DMP will collect the available\n sharders and come up with a plan of the optimal way to shard the\n embedding table(s) (for example, `EmbeddingBagCollection`)\n2. Actually shard the model. This includes allocating memory for each\n embedding table on the appropriate device(s).\n\nDMP takes in everything that we\\'ve just experimented with, like a\nstatic sharding plan, a list of sharders, etc. However, it also has some\nnice defaults to seamlessly shard a TorchRec model. 
In this toy example,\nsince we have two embedding tables and one GPU, TorchRec will place both\non the single GPU.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ebc\n\nmodel = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))\n\nout = model(kjt)\nout.wait()\n\nmodel" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sharding Best Practices\n=======================\n\nCurrently, our configuration is only sharding on 1 GPU (or rank), which\nis trivial: just place all the tables in the memory of a single GPU.\nHowever, in real production use cases, embedding tables are **typically\nsharded on hundreds of GPUs**, with different sharding methods such as\ntable-wise, row-wise, and column-wise. It is incredibly important to\ndetermine a proper sharding configuration (to prevent out of memory\nissues) while keeping it balanced not only in terms of memory but also\ncompute for optimal performance.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Adding in the Optimizer\n=======================\n\nRemember that TorchRec modules are highly optimized for large-scale\ndistributed training. An important optimization concerns the\noptimizer.\n\nTorchRec modules provide a seamless API to fuse the backward pass and\noptimizer step in training, providing a significant performance\noptimization and decreasing the memory used, alongside granularity in\nassigning distinct optimizers to distinct model parameters.\n\nOptimizer Classes\n-----------------\n\nTorchRec uses `CombinedOptimizer`, which contains a collection of\n`KeyedOptimizers`. A `CombinedOptimizer` effectively makes it easy to\nhandle multiple optimizers for various sub-groups in the model. A\n`KeyedOptimizer` extends `torch.optim.Optimizer` and is initialized\nthrough a dictionary of parameters that it exposes by key. Each `TBE`\nmodule in an `EmbeddingBagCollection` will have its own\n`KeyedOptimizer`, and these combine into one `CombinedOptimizer`.\n\nFused optimizer in TorchRec\n---------------------------\n\nUsing `DistributedModelParallel`, the **optimizer is fused, which means\nthat the optimizer update is done in the backward**. This is an\noptimization in TorchRec and FBGEMM, where the embedding gradients are\nnot materialized; instead, the update is applied directly to the\nparameters. This brings significant memory savings as embedding\ngradients are typically the size of the parameters themselves.\n\nYou can, however, choose to make the optimizer `dense`, which does not\napply this optimization and lets you inspect the embedding gradients\nor apply computations to them as you wish. A dense optimizer in this\ncase would be your [canonical PyTorch model training loop with\noptimizer](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html).\n\nOnce the optimizer is created through `DistributedModelParallel`, you\nstill need to manage an optimizer for the other parameters not\nassociated with TorchRec embedding modules. To find the other\nparameters, use\n`in_backward_optimizer_filter(model.named_parameters())`, as sketched\nbelow. 
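\nHere is a minimal sketch of that pattern (hedged: it assumes your model also\nowns dense, non-TorchRec parameters; in this toy notebook the filter below\nwould come back empty, so treat it purely as a template):\n\n``` {.python}\nfrom torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizerWrapper\nfrom torchrec.optim.optimizers import in_backward_optimizer_filter\n\n# Parameters that are not handled by the fused TorchRec optimizer\ndense_params = dict(in_backward_optimizer_filter(model.named_parameters()))\ndense_optimizer = KeyedOptimizerWrapper(\n    dense_params,\n    lambda params: torch.optim.Adam(params, lr=1e-3),\n)\n\n# One optimizer to ``zero_grad()`` and ``step()`` in the training loop\noptimizer = CombinedOptimizer([model.fused_optimizer, dense_optimizer])\n```\n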
Apply an\noptimizer to those parameters as you would a normal Torch optimizer and\ncombine this and the `model.fused_optimizer` into one\n`CombinedOptimizer` that you can use in your training loop to\n`zero_grad` and `step` through.\n\nAdding an Optimizer to `EmbeddingBagCollection`\n-----------------------------------------------\n\nWe will do this in two ways, which are equivalent, but give you options\ndepending on your preferences:\n\n1. Passing optimizer kwargs through `fused_params` in sharder.\n2. Through `apply_optimizer_in_backward`, which converts the optimizer\n parameters to `fused_params` to pass to the `TBE` in the\n `EmbeddingBagCollection` or `EmbeddingCollection`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Option 1: Passing optimizer kwargs through fused parameters\nfrom torchrec.optim.optimizers import in_backward_optimizer_filter\nfrom fbgemm_gpu.split_embedding_configs import EmbOptimType\n\n\n# We initialize the sharder with\nfused_params = {\n \"optimizer\": EmbOptimType.EXACT_ROWWISE_ADAGRAD,\n \"learning_rate\": 0.02,\n \"eps\": 0.002,\n}\n\n# Initialize sharder with ``fused_params``\nsharder_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)\n\n# We'll use same plan and unsharded EBC as before but this time with our new sharder\nsharded_ebc_fused_params = sharder_with_fused_params.shard(ebc, plan.plan[\"\"], env, torch.device(\"cuda\"))\n\n# Looking at the optimizer of each, we can see that the learning rate changed, which indicates our optimizer has been applied correctly.\n# If seen, we can also look at the TBE logs of the cell to see that our new optimizer is indeed being applied\nprint(f\"Original Sharded EBC fused optimizer: {sharded_ebc.fused_optimizer}\")\nprint(f\"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}\")\n\nprint(f\"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}\")\n\nfrom torch.distributed.optim import _apply_optimizer_in_backward as apply_optimizer_in_backward\nimport copy\n# Option 2: Applying optimizer through apply_optimizer_in_backward\n# Note: we need to call apply_optimizer_in_backward on unsharded model first and then shard it\n\n# We can achieve the same result as we did in the previous\nebc_apply_opt = copy.deepcopy(ebc)\noptimizer_kwargs = {\"lr\": 0.5}\n\nfor name, param in ebc_apply_opt.named_parameters():\n print(f\"{name=}\")\n apply_optimizer_in_backward(torch.optim.SGD, [param], optimizer_kwargs)\n\nsharded_ebc_apply_opt = sharder.shard(ebc_apply_opt, plan.plan[\"\"], env, torch.device(\"cuda\"))\n\n# Now when we print the optimizer, we will see our new learning rate, you can verify momentum through the TBE logs as well if outputted\nprint(sharded_ebc_apply_opt.fused_optimizer)\nprint(type(sharded_ebc_apply_opt.fused_optimizer))\n\n# We can also check through the filter other parameters that aren't associated with the \"fused\" optimizer(s)\n# Practically, just non TorchRec module parameters. 
Since our module is just a TorchRec EBC\n# there are no other parameters that aren't associated with TorchRec\nprint(\"Non Fused Model Parameters:\")\nprint(dict(in_backward_optimizer_filter(sharded_ebc_fused_params.named_parameters())).keys())\n\n# Here we do a dummy backwards call and see that parameter updates for fused\n# optimizers happen as a result of the backward pass\n\nebc_output = sharded_ebc_fused_params(kjt).wait().values()\nloss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\nprint(f\"First Iteration Loss: {loss}\")\n\nloss.backward()\n\nebc_output = sharded_ebc_fused_params(kjt).wait().values()\nloss = torch.sum(torch.ones_like(ebc_output) - ebc_output)\n# We don't call an optimizer.step(), so for the loss to have changed here,\n# that means that the gradients were somehow updated, which is what the\n# fused optimizer automatically handles for us\nprint(f\"Second Iteration Loss: {loss}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Inference\n=========\n\nNow that we are able to train distributed embeddings, how can we take\nthe trained model and optimize it for inference? Inference is typically\nvery sensitive to **performance and size of the model**. Running just\nthe trained model in a Python environment is incredibly inefficient.\nThere are two key differences between inference and training\nenvironments:\n\n- **Quantization**: Inference models are typically quantized, where\n model parameters lose precision for lower latency in predictions and\n reduced model size. For example FP32 (4 bytes) in trained model to\n INT8 (1 byte) for each embedding weight. This is also necessary\n given the vast scale of embedding tables, as we want to use as few\n devices as possible for inference to minimize latency.\n- **C++ environment**: Inference latency is very important, so in\n order to ensure ample performance, the model is typically ran in a\n C++ environment, along with the situations where we don\\'t have a\n Python runtime, like on device.\n\nTorchRec provides primitives for converting a TorchRec model into being\ninference ready with:\n\n- APIs for quantizing the model, introducing optimizations\n automatically with FBGEMM TBE\n- Sharding embeddings for distributed inference\n- Compiling the model to\n [TorchScript](https://pytorch.org/docs/stable/jit.html) (compatible\n in C++)\n\nIn this section, we will go over this entire workflow of:\n\n- Quantizing the model\n- Sharding the quantized model\n- Compiling the sharded quantized model into TorchScript\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "ebc\n\nclass InferenceModule(torch.nn.Module):\n def __init__(self, ebc: torchrec.EmbeddingBagCollection):\n super().__init__()\n self.ebc_ = ebc\n\n def forward(self, kjt: KeyedJaggedTensor):\n return self.ebc_(kjt)\n\nmodule = InferenceModule(ebc)\nfor name, param in module.named_parameters():\n # Here, the parameters should still be FP32, as we are using a standard EBC\n # FP32 is default, regularly used for training\n print(name, param.shape, param.dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Quantization\n============\n\nAs you can see above, the normal EBC contains embedding table weights as\nFP32 precision (32 bits for each weight). 
Here, we will use the TorchRec\ninference library to quantize the embedding weights of the model to INT8\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch import quantization as quant\nfrom torchrec.modules.embedding_configs import QuantConfig\nfrom torchrec.quant.embedding_modules import (\n EmbeddingBagCollection as QuantEmbeddingBagCollection,\n)\n\n\nquant_dtype = torch.int8\n\n\nqconfig = QuantConfig(\n # dtype of the result of the embedding lookup, post activation\n # torch.float generally for compatibility with rest of the model\n # as rest of the model here usually isn't quantized\n activation=quant.PlaceholderObserver.with_args(dtype=torch.float),\n # quantized type for embedding weights, aka parameters to actually quantize\n weight=quant.PlaceholderObserver.with_args(dtype=quant_dtype),\n)\nqconfig_spec = {\n # Map of module type to qconfig\n torchrec.EmbeddingBagCollection: qconfig,\n}\nmapping = {\n # Map of module type to quantized module type\n torchrec.EmbeddingBagCollection: QuantEmbeddingBagCollection,\n}\n\n\nmodule = InferenceModule(ebc)\n\n# Quantize the module\nqebc = quant.quantize_dynamic(\n module,\n qconfig_spec=qconfig_spec,\n mapping=mapping,\n inplace=False,\n)\n\n\nprint(f\"Quantized EBC: {qebc}\")\n\nkjt = kjt.to(\"cpu\")\n\nqebc(kjt)\n\n# Once quantized, goes from parameters -> buffers, as no longer trainable\nfor name, buffer in qebc.named_buffers():\n # The shapes of the tables should be the same but the dtype should be int8 now\n # post quantization\n print(name, buffer.shape, buffer.dtype)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Shard\n=====\n\nHere we perform sharding of the TorchRec quantized model. This is to\nensure we are using the performant module through FBGEMM TBE. Here we\nare using one device to be consistent with training (1 TBE).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrec import distributed as trec_dist\nfrom torchrec.distributed.shard import _shard_modules\n\n\nsharded_qebc = _shard_modules(\n module=qebc,\n device=torch.device(\"cpu\"),\n env=trec_dist.ShardingEnv.from_local(\n 1,\n 0,\n ),\n)\n\n\nprint(f\"Sharded Quantized EBC: {sharded_qebc}\")\n\nsharded_qebc(kjt)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compilation\n===========\n\nNow we have the optimized eager TorchRec inference model. The next step\nis to ensure that this model is loadable in C++, as currently it is only\nrunnable in a Python runtime.\n\nThe recommended method of compilation at Meta is two fold: [torch.fx\ntracing](https://pytorch.org/docs/stable/fx.html) (generate intermediate\nrepresentation of model) and converting the result to TorchScript, where\nTorchScript is C++ compatible.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchrec.fx import Tracer\n\n\ntracer = Tracer(leaf_modules=[\"IntNBitTableBatchedEmbeddingBagsCodegen\"])\n\ngraph = tracer.trace(sharded_qebc)\ngm = torch.fx.GraphModule(sharded_qebc, graph)\n\nprint(\"Graph Module Created!\")\n\nprint(gm.code)\n\nscripted_gm = torch.jit.script(gm)\nprint(\"Scripted Graph Module Created!\")\n\nprint(scripted_gm.code)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, you have gone from training a distributed RecSys model\nall the way to making it inference ready. 
The [TorchRec\nrepo](https://github.com/pytorch/torchrec/tree/main/torchrec/inference)\nhas a full example of how to load a TorchRec TorchScript model into C++\nfor inference.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "See Also\n========\n\nFor more information, please see our\n[dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/)\nexample, which includes multinode training on the Criteo 1TB dataset\nusing the methods described in [Deep Learning Recommendation Model for\nPersonalization and Recommendation\nSystems](https://arxiv.org/abs/1906.00091).\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }