{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Getting Started with Nested Tensors\n===================================\n\nNested tensors generalize the shape of regular dense tensors, allowing\nfor representation of ragged-sized data.\n\n- for a regular tensor, each dimension is regular and has a size\n- for a nested tensor, not all dimensions have regular sizes; some of\n them are ragged\n\nNested tensors are a natural solution for representing sequential data\nwithin various domains:\n\n- in NLP, sentences can have variable lengths, so a batch of sentences\n forms a nested tensor\n- in CV, images can have variable shapes, so a batch of images forms a\n nested tensor\n\nIn this tutorial, we will demonstrate basic usage of nested tensors and\nmotivate their usefulness for operating on sequential data of varying\nlengths with a real-world example. In particular, they are invaluable\nfor building transformers that can efficiently operate on ragged\nsequential inputs. 
Below, we present an implementation of multi-head\nattention using nested tensors that, combined with `torch.compile`,\noutperforms operating naively on tensors with padding.\n\nNested tensors are currently a prototype feature and are subject to\nchange.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport timeit\nimport torch\nimport torch.nn.functional as F\n\nfrom torch import nn\n\ntorch.manual_seed(1)\nnp.random.seed(1)\n\ndevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nested tensor initialization\n============================\n\nFrom the Python frontend, a nested tensor can be created from a list of\ntensors. We denote nt\\[i\\] as the ith tensor component of a\nnestedtensor.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nt = torch.nested.nested_tensor([torch.arange(12).reshape(\n    2, 6), torch.arange(18).reshape(3, 6)], dtype=torch.float, device=device)\nprint(f\"{nt=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By padding every underlying tensor to the same shape, a nestedtensor can\nbe converted to a regular tensor.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "padded_out_tensor = torch.nested.to_padded_tensor(nt, padding=0.0)\nprint(f\"{padded_out_tensor=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "All tensors possess an attribute for determining if they are nested:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(f\"nt is nested: {nt.is_nested}\")\nprint(f\"padded_out_tensor is nested: {padded_out_tensor.is_nested}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It is common to construct 
nestedtensors from batches of irregularly\nshaped tensors, i.e., dimension 0 is assumed to be the batch dimension.\nIndexing into dimension 0 gives back the corresponding underlying tensor component.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(\"First underlying tensor component:\", nt[0], sep='\\n')\nprint(\"last column of 2nd underlying tensor component:\", nt[1, :, -1], sep='\\n')\n\n# When indexing a nestedtensor's 0th dimension, the result is a regular tensor.\nprint(f\"First underlying tensor component is nested: {nt[0].is_nested}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An important note is that slicing in dimension 0 is not supported\nyet, which means it is not currently possible to construct a view that\ncombines the underlying tensor components.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Nested Tensor Operations\n========================\n\nAs each operation must be explicitly implemented for nestedtensors,\noperation coverage for nestedtensors is currently narrower than that of\nregular tensors. For now, only basic operations such as index, dropout,\nsoftmax, transpose, reshape, linear, bmm are covered. However, coverage\nis being expanded. If you need certain operations, please file an\n[issue](https://github.com/pytorch/pytorch) to help us prioritize\ncoverage.\n\n**reshape**\n\nThe reshape op is for changing the shape of a tensor. Its full semantics\nfor regular tensors can be found\n[here](https://pytorch.org/docs/stable/generated/torch.reshape.html).\nFor regular tensors, when specifying the new shape, a single dimension\nmay be -1, in which case it is inferred from the remaining dimensions\nand the number of elements.\n\nThe semantics for nestedtensors are similar, except that -1 no longer\ninfers. Instead, it inherits the old size (here 2 for `nt[0]` and 3 for\n`nt[1]`). 
-1 is the only legal size to specify for a jagged dimension.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nt_reshaped = nt.reshape(2, -1, 2, 3)\nprint(f\"{nt_reshaped=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**transpose**\n\nThe transpose op is for swapping two dimensions of a tensor. Its full\nsemantics can be found\n[here](https://pytorch.org/docs/stable/generated/torch.transpose.html).\nNote that for nestedtensors dimension 0 is special; it is assumed to be\nthe batch dimension, so transposes involving nestedtensor dimension 0\nare not supported.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nt_transposed = nt_reshaped.transpose(1, 2)\nprint(f\"{nt_transposed=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "**others**\n\nOther operations have the same semantics as for regular tensors.\nApplying the operation on a nestedtensor is equivalent to applying the\noperation to the underlying tensor components, with the result being a\nnestedtensor as well.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "nt_mm = torch.nested.nested_tensor([torch.randn((2, 3, 4)), torch.randn((2, 3, 5))], device=device)\nnt3 = torch.matmul(nt_transposed, nt_mm)\nprint(f\"Result of Matmul:\\n {nt3}\")\n\nnt4 = F.dropout(nt3, 0.1)\nprint(f\"Result of Dropout:\\n {nt4}\")\n\nnt5 = F.softmax(nt4, -1)\nprint(f\"Result of Softmax:\\n {nt5}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Why Nested Tensor\n=================\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When data is sequential, it is often the case that each sample has a\ndifferent length. For example, in a batch of sentences, each sentence\nhas a different number of words. 
A common technique for handling sequences of varying\nlength is to manually pad each data tensor to the same shape in order\nto form a batch. For example, we have 2 sentences with different lengths\nand a vocabulary. In order to represent this as a single tensor, we pad with\n0 to the max length in the batch.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sentences = [[\"goodbye\", \"padding\"],\n             [\"embrace\", \"nested\", \"tensor\"]]\nvocabulary = {\"goodbye\": 1.0, \"padding\": 2.0,\n              \"embrace\": 3.0, \"nested\": 4.0, \"tensor\": 5.0}\npadded_sentences = torch.tensor([[1.0, 2.0, 0.0],\n                                 [3.0, 4.0, 5.0]])\nnested_sentences = torch.nested.nested_tensor([torch.tensor([1.0, 2.0]),\n                                               torch.tensor([3.0, 4.0, 5.0])])\nprint(f\"{padded_sentences=}\")\nprint(f\"{nested_sentences=}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This technique of padding a batch of data to its max length is not\noptimal. The padded data is not needed for computation and wastes memory\nby allocating larger tensors than necessary. Further, not all operations\nhave the same semantics when applied to padded data. For matrix\nmultiplications, in order to ignore the padded entries, one needs to pad\nwith 0, while for softmax one has to pad with -inf to ignore specific\nentries. 
The primary objective of nested tensor is to facilitate\noperations on ragged data using the standard PyTorch tensor UX, thereby\neliminating the need for inefficient and complex padding and masking.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "padded_sentences_for_softmax = torch.tensor([[1.0, 2.0, float(\"-inf\")],\n                                             [3.0, 4.0, 5.0]])\nprint(F.softmax(padded_sentences_for_softmax, -1))\nprint(F.softmax(nested_sentences, -1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us take a look at a practical example: the multi-head attention\ncomponent utilized in\n[Transformers](https://arxiv.org/pdf/1706.03762.pdf). We can implement\nthis in such a way that it can operate on either padded or nested\ntensors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class MultiHeadAttention(nn.Module):\n    \"\"\"\n    Computes multi-head attention. Supports nested or padded tensors.\n\n    Args:\n        E_q (int): Size of embedding dim for query\n        E_k (int): Size of embedding dim for key\n        E_v (int): Size of embedding dim for value\n        E_total (int): Total embedding dim of combined heads post input projection. Each head\n            has dim E_total // nheads\n        nheads (int): Number of heads\n        dropout_p (float, optional): Dropout probability. 
Default: 0.0\n    \"\"\"\n    def __init__(self, E_q: int, E_k: int, E_v: int, E_total: int,\n                 nheads: int, dropout_p: float = 0.0):\n        super().__init__()\n        self.nheads = nheads\n        self.dropout_p = dropout_p\n        self.query_proj = nn.Linear(E_q, E_total)\n        self.key_proj = nn.Linear(E_k, E_total)\n        self.value_proj = nn.Linear(E_v, E_total)\n        E_out = E_q\n        self.out_proj = nn.Linear(E_total, E_out)\n        assert E_total % nheads == 0, \"Embedding dim is not divisible by nheads\"\n        self.E_head = E_total // nheads\n\n    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:\n        \"\"\"\n        Forward pass; runs the following process:\n            1. Apply input projection\n            2. Split heads and prepare for SDPA\n            3. Run SDPA\n            4. Apply output projection\n\n        Args:\n            query (torch.Tensor): query of shape (N, L_t, E_q)\n            key (torch.Tensor): key of shape (N, L_s, E_k)\n            value (torch.Tensor): value of shape (N, L_s, E_v)\n\n        Returns:\n            attn_output (torch.Tensor): output of shape (N, L_t, E_q)\n        \"\"\"\n        # Step 1. Apply input projection\n        # TODO: demonstrate packed projection\n        query = self.query_proj(query)\n        key = self.key_proj(key)\n        value = self.value_proj(value)\n\n        # Step 2. Split heads and prepare for SDPA\n        # reshape query, key, value to separate by head\n        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)\n        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n\n        # Step 3. 
Run SDPA\n        # (N, nheads, L_t, E_head)\n        attn_output = F.scaled_dot_product_attention(\n            query, key, value, dropout_p=self.dropout_p, is_causal=True)\n        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)\n        attn_output = attn_output.transpose(1, 2).flatten(-2)\n\n        # Step 4. Apply output projection\n        # (N, L_t, E_total) -> (N, L_t, E_out)\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set hyperparameters following [the Transformer\npaper](https://arxiv.org/pdf/1706.03762.pdf)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N = 512\nE_q, E_k, E_v, E_total = 512, 512, 512, 512\nE_out = E_q\nnheads = 8" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The exception is the dropout probability, which is set to 0 for the\ncorrectness check.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "dropout_p = 0.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let us generate some realistic fake data from Zipf\\'s law.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:\n    # generate fake corpus by unigram Zipf distribution\n    # from wikitext-2 corpus, we get rank \".\" = 3, \"!\" = 386, \"?\" = 858\n    sentence_lengths = np.empty(batch_size, dtype=int)\n    for ibatch in range(batch_size):\n        sentence_lengths[ibatch] = 1\n        word = np.random.zipf(alpha)\n        while word != 3 and word != 386 and word != 858:\n            sentence_lengths[ibatch] += 1\n            word = np.random.zipf(alpha)\n    return torch.tensor(sentence_lengths)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create nested tensor batch inputs\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": 
[ "def gen_batch(N, E_q, E_k, E_v, device):\n    # generate semi-realistic data using Zipf distribution for sentence lengths\n    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)\n\n    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged\n    # dimension and works with torch.compile. Each nested tensor built below has shape (B, S*, D)\n    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.\n    query = torch.nested.nested_tensor([\n        torch.randn(l.item(), E_q, device=device)\n        for l in sentence_lengths\n    ], layout=torch.jagged)\n\n    key = torch.nested.nested_tensor([\n        torch.randn(s.item(), E_k, device=device)\n        for s in sentence_lengths\n    ], layout=torch.jagged)\n\n    value = torch.nested.nested_tensor([\n        torch.randn(s.item(), E_v, device=device)\n        for s in sentence_lengths\n    ], layout=torch.jagged)\n\n    return query, key, value, sentence_lengths\n\nquery, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate padded forms of query, key, value for comparison\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "def jagged_to_padded(jt, padding_val):\n    # TODO: do jagged -> padded directly when this is supported\n    return torch.nested.to_padded_tensor(\n        torch.nested.nested_tensor(list(jt.unbind())),\n        padding_val)\n\npadded_query, padded_key, padded_value = (\n    jagged_to_padded(t, 0.0) for t in (query, key, value)\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Construct the model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads, dropout_p).to(device=device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check correctness and performance\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { 
"collapsed": false }, "outputs": [], "source": [ "def benchmark(func, *args, **kwargs):\n    # Only synchronize when CUDA is present; `device` falls back to CPU otherwise.\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    begin = timeit.default_timer()\n    output = func(*args, **kwargs)\n    if torch.cuda.is_available():\n        torch.cuda.synchronize()\n    end = timeit.default_timer()\n    return output, (end - begin)\n\noutput_nested, time_nested = benchmark(mha, query, key, value)\noutput_padded, time_padded = benchmark(mha, padded_query, padded_key, padded_value)\n\n# padding-specific step: remove output projection bias from padded entries for fair comparison\nfor i, entry_length in enumerate(sentence_lengths):\n    output_padded[i, entry_length:] = 0.0\n\nprint(\"=== without torch.compile ===\")\nprint(\"nested and padded calculations differ by\", (jagged_to_padded(output_nested, 0.0) - output_padded).abs().max().item())\nprint(\"nested tensor multi-head attention takes\", time_nested, \"seconds\")\nprint(\"padded tensor multi-head attention takes\", time_padded, \"seconds\")\n\n# warm up compile first...\ncompiled_mha = torch.compile(mha)\ncompiled_mha(query, key, value)\n# ...now benchmark\ncompiled_output_nested, compiled_time_nested = benchmark(\n    compiled_mha, query, key, value)\n\n# warm up compile first...\ncompiled_mha(padded_query, padded_key, padded_value)\n# ...now benchmark\ncompiled_output_padded, compiled_time_padded = benchmark(\n    compiled_mha, padded_query, padded_key, padded_value)\n\n# padding-specific step: remove output projection bias from padded entries for fair comparison\nfor i, entry_length in enumerate(sentence_lengths):\n    compiled_output_padded[i, entry_length:] = 0.0\n\nprint(\"=== with torch.compile ===\")\nprint(\"nested and padded calculations differ by\", (jagged_to_padded(compiled_output_nested, 0.0) - compiled_output_padded).abs().max().item())\nprint(\"nested tensor multi-head attention takes\", compiled_time_nested, \"seconds\")\nprint(\"padded tensor multi-head attention takes\", compiled_time_padded, \"seconds\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ 
"Note that without `torch.compile`, the overhead of the Python subclass\nnested tensor can make it slower than the equivalent computation on\npadded tensors. However, once `torch.compile` is enabled, operating on\nnested tensors gives a severalfold speedup. Avoiding wasted computation\non padding becomes only more valuable as the percentage of padding in\nthe batch increases.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "print(f\"Nested speedup: {compiled_time_padded / compiled_time_nested:.3f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we have learned how to perform basic operations with\nnested tensors and how to implement multi-head attention for transformers\nin a way that avoids computation on padding. For more information, check\nout the docs for the\n[torch.nested](https://pytorch.org/docs/stable/nested.html) namespace.\n\nSee Also\n========\n\n- [Accelerating PyTorch Transformers by replacing nn.Transformer with\n    Nested Tensors and\n    torch.compile](https://docs.pytorch.org/tutorials/intermediate/transformer_building_blocks.html)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }