{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "::: {.meta description=\"Learn how to optimize transformer models by replacing nn.Transformer with Nested Tensors and torch.compile() for significant performance gains in PyTorch.\"}\n:::\n\nAccelerating PyTorch Transformers by replacing `nn.Transformer` with Nested Tensors and `torch.compile()`\n=========================================================================================================\n\n**Author:** [Mikayla Gawarecki](https://github.com/mikaylagawarecki)\n\n```{=html}\n


\n```\nOver the past few years, the PyTorch team has developed various lower\nlevel features that, when composed, can create a variety of transformer\nvariants. These include:\n\n- Nested Tensors with the `torch.jagged` layout (AKA NJTs)\n- `scaled_dot_product_attention`\n- `torch.compile()`\n- `FlexAttention`\n\nThis tutorial will give a brief overview of the above technologies and\ndemonstrate how they can be composed to yield flexible and performant\ntransformer layers with improved user experience.\n\nOne may observe that the `torch.nn` module currently provides various\n`Transformer`-related layers. In particular, it includes\n`TransformerEncoderLayer`, `TransformerEncoder`,\n`TransformerDecoderLayer`, `TransformerDecoder`, `Transformer` and\n`MultiheadAttention`. This family of layers was initially implemented\nfollowing the [Attention is All You\nNeed](https://arxiv.org/abs/1706.03762) paper. The components discussed\nin this tutorial provide improved user experience, flexibility and\nperformance over the existing `nn` layers.\n\nIs this tutorial for me?\n========================\n\nIf you are wondering about what building blocks the `torch` library\nprovides for writing your own transformer layers and best practices, you\nare in the right place. 
Please keep reading!\n\nIf you are looking for an out-of-the-box implementation of a popular\ntransformer architecture, note that there are many open-source libraries\nthat provide them, including:\n\n- [HuggingFace\n transformers](https://github.com/huggingface/transformers)\n- [xformers](https://github.com/facebookresearch/xformers)\n- [torchtune](https://github.com/pytorch/torchtune)\n\nIf you are only interested in performant attention score modifications,\nplease check out the [FlexAttention\nblog](https://pytorch.org/blog/flexattention/) that contains a [gym of\nmasks](https://github.com/pytorch-labs/attention-gym).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Introducing the Building Blocks\n===============================\n\nFirst, we will briefly introduce the four technologies mentioned in the\nintroduction:\n\n- [torch.nested](https://pytorch.org/tutorials/prototype/nestedtensor.html)\n\nNested tensors generalize the shape of regular dense tensors, allowing\nfor representation of ragged-sized data with the same tensor UX. In the\ncontext of transformers, we can think of nested tensors as a tool for\nrepresenting variable sequence lengths. They eliminate the need for the\nbug-prone practices of explicit padding and masking (think\n`key_padding_mask` in `nn.MultiheadAttention`).\n\n- [scaled\\_dot\\_product\\_attention](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html)\n\n`scaled_dot_product_attention` is a primitive for\n$\\text{softmax}(\\frac{QK^T}{\\sqrt{E}} + B)V$ that dispatches into either\nfused implementations of the operator or a fallback implementation. It\nworks out of the box in eager mode (i.e. the default mode of using\nPyTorch where operations are executed on the fly as they are\nencountered) and also integrates seamlessly with `torch.compile()`. 
As\nof 2.6, it will also offer grouped query attention natively.\n\n- [torch.compile()](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)\n\n`torch.compile()` is a compiler introduced in version 2.0 that is able\nto capture a graph of PyTorch code and perform various optimizations on\nit, such as fusing together sequences of ops. Nested tensors with the\n`torch.jagged` layout and `scaled_dot_product_attention` work seamlessly\nwith compile. In the context of transformers, the value of using\ncompile with nested tensors and SDPA is that compile can remove the framework\noverhead one sees in eager mode and fuse sequences of ops in\ntransformers together, such as projection and activation.\n\n- [FlexAttention](https://pytorch.org/blog/flexattention/)\n\n`FlexAttention` is a primitive that allows users to modify attention\nscores prior to the softmax operation. It generalizes the additive `B`\nterm above for `scaled_dot_product_attention`, allowing for arbitrary\ncalculation. It requires compile to achieve good performance.\n\nThe above building blocks are \\\"All You Need\\\" (as of October 2024)\n===================================================================\n\nThe main premise in this section is that most transformer variations are\nGPT-style, consisting of layers like Embedding, Positional Encoding,\nAttention Blocks and Feed Forward networks. If we were to try to\nclassify the differences in this space, we might land on something like:\n\n1. Layer type (activation functions such as `SwiGLU` and others,\n normalization functions such as `RMSNorm` and others, positional\n encodings such as Sinusoidal and Rotary).\n2. Layer ordering, such as where to apply norms and positional\n encoding.\n3. Modifications to attention score, such as `ALiBi`, Relative\n Positional Bias and so on.\n\nIn a pre-compiler environment, you might write a custom transformer and\nnotice that it functions correctly but is slow. 
To address this, you\nmight develop a custom fused kernel for the specific series of\noperations. In a compiler environment, you can simply write the\nstraightforward implementation, compile it, and benefit from improved performance.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "MultiheadAttention\n==================\n\nRemember that MultiheadAttention takes in a query, key, and value, and\nconsists of an input projection, a `scaled_dot_product_attention`\noperator and an output projection. The main takeaway we want to\ndemonstrate here is the improvement yielded when we replace\npadded/masked inputs with nested tensors. The improvements are\nthreefold:\n\n- **User Experience** Remember that `nn.MultiheadAttention` requires\n `query`, `key`, and `value` to be dense `torch.Tensors`. It also\n provides a `key_padding_mask` that is used to mask out padding\n tokens in the `key` that arise due to different sequence lengths\n within a batch. Since there is no `query_padding_mask` in `nn.MHA`,\n users have to take care to mask/slice the outputs appropriately to\n account for query sequence lengths. `NestedTensor` cleanly removes\n the need for this kind of error-prone padding mask.\n- **Memory** Instead of materializing a dense `[B, S, D]` tensor with\n a `[B, S]` padding mask (where `B` is batch size, `S` is max\n sequence length in the batch and `D` is embedding size), nested\n tensors allow you to cleanly represent the batch of varying sequence\n lengths. 
As a result, the input and intermediate activations will\n use less memory.\n- **Performance** Since padding is not materialized and unnecessary\n computation on padding is skipped, performance and memory usage\n improve.\n\nWe\\'ll demonstrate the above by building upon the `MultiheadAttention`\nlayer in the [Nested Tensor\ntutorial](https://pytorch.org/tutorials/prototype/nestedtensor.html) and\ncomparing it to the `nn.MultiheadAttention` layer.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass MultiHeadAttention(nn.Module):\n    \"\"\"\n    Computes multi-head attention. Supports nested or padded tensors.\n\n    Args:\n        E_q (int): Size of embedding dim for query\n        E_k (int): Size of embedding dim for key\n        E_v (int): Size of embedding dim for value\n        E_total (int): Total embedding dim of combined heads post input projection. Each head\n            has dim E_total // nheads\n        nheads (int): Number of heads\n        dropout (float, optional): Dropout probability. Default: 0.0\n        bias (bool, optional): Whether to add bias to input projection. Default: True\n    \"\"\"\n\n    def __init__(\n        self,\n        E_q: int,\n        E_k: int,\n        E_v: int,\n        E_total: int,\n        nheads: int,\n        dropout: float = 0.0,\n        bias=True,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.nheads = nheads\n        self.dropout = dropout\n        self._qkv_same_embed_dim = E_q == E_k and E_q == E_v\n        if self._qkv_same_embed_dim:\n            self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)\n        else:\n            self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)\n            self.k_proj = nn.Linear(E_k, E_total, bias=bias, **factory_kwargs)\n            self.v_proj = nn.Linear(E_v, E_total, bias=bias, **factory_kwargs)\n        E_out = E_q\n        self.out_proj = nn.Linear(E_total, E_out, bias=bias, **factory_kwargs)\n        assert E_total % nheads == 0, \"Embedding dim is not divisible by nheads\"\n        self.E_head = E_total // nheads\n        self.bias = bias\n\n    def forward(\n        self,\n        query: torch.Tensor,\n        key: torch.Tensor,\n        value: torch.Tensor,\n        attn_mask=None,\n        is_causal=False,\n    ) -> torch.Tensor:\n        \"\"\"\n        Forward pass; runs the following process:\n            1. Apply input projection\n            2. Split heads and prepare for SDPA\n            3. Run SDPA\n            4. Apply output projection\n\n        Args:\n            query (torch.Tensor): query of shape (``N``, ``L_q``, ``E_qk``)\n            key (torch.Tensor): key of shape (``N``, ``L_kv``, ``E_qk``)\n            value (torch.Tensor): value of shape (``N``, ``L_kv``, ``E_v``)\n            attn_mask (torch.Tensor, optional): attention mask of shape (``N``, ``L_q``, ``L_kv``).\n                Accepted for API parity but not forwarded to SDPA in this layer. Default: None\n            is_causal (bool, optional): Whether to apply causal mask. Default: False\n\n        Returns:\n            attn_output (torch.Tensor): output of shape (``N``, ``L_q``, ``E_q``)\n        \"\"\"\n        # Step 1. Apply input projection\n        if self._qkv_same_embed_dim:\n            if query is key and key is value:\n                result = self.packed_proj(query)\n                query, key, value = torch.chunk(result, 3, dim=-1)\n            else:\n                q_weight, k_weight, v_weight = torch.chunk(\n                    self.packed_proj.weight, 3, dim=0\n                )\n                if self.bias:\n                    q_bias, k_bias, v_bias = torch.chunk(\n                        self.packed_proj.bias, 3, dim=0\n                    )\n                else:\n                    q_bias, k_bias, v_bias = None, None, None\n                query, key, value = (\n                    F.linear(query, q_weight, q_bias),\n                    F.linear(key, k_weight, k_bias),\n                    F.linear(value, v_weight, v_bias),\n                )\n\n        else:\n            query = self.q_proj(query)\n            key = self.k_proj(key)\n            value = self.v_proj(value)\n\n        # Step 2. Split heads and prepare for SDPA\n        # reshape query, key, value to separate by head\n        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)\n        query = query.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n        key = key.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)\n        value = value.unflatten(-1, [self.nheads, self.E_head]).transpose(1, 2)\n\n        # Step 3. Run SDPA\n        # (N, nheads, L_t, E_head)\n        attn_output = F.scaled_dot_product_attention(\n            query, key, value, dropout_p=self.dropout, is_causal=is_causal\n        )\n        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)\n        attn_output = attn_output.transpose(1, 2).flatten(-2)\n\n        # Step 4. Apply output projection\n        # (N, L_t, E_total) -> (N, L_t, E_out)\n        attn_output = self.out_proj(attn_output)\n\n        return attn_output" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Utilities\n=========\n\nIn this section, we include a utility to generate semi-realistic data\nusing a `Zipf` distribution for sentence lengths. This is used to generate\nthe nested query, key, and value tensors. 
We also include a benchmark\nutility.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\n\n\ndef zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:\n    # generate fake corpus by unigram Zipf distribution\n    # from wikitext-2 corpus, we get rank \".\" = 3, \"!\" = 386, \"?\" = 858\n    sentence_lengths = np.empty(batch_size, dtype=int)\n    for ibatch in range(batch_size):\n        sentence_lengths[ibatch] = 1\n        word = np.random.zipf(alpha)\n        while word != 3 and word != 386 and word != 858:\n            sentence_lengths[ibatch] += 1\n            word = np.random.zipf(alpha)\n    return torch.tensor(sentence_lengths)\n\n\n# Generate a batch of semi-realistic data using Zipf distribution for sentence lengths\n# in the form of nested tensors with the jagged layout.\ndef gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):\n    # generate semi-realistic data using Zipf distribution for sentence lengths\n    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)\n\n    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged\n    # dimension and works with torch.compile. The resulting batch has shape (B, S*, D)\n    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.\n    if query_seq_len_1:\n        query = torch.nested.nested_tensor(\n            [torch.randn(1, E_q, dtype=dtype, device=device) for _ in sentence_lengths],\n            layout=torch.jagged,\n        )\n    else:\n        query = torch.nested.nested_tensor(\n            [\n                torch.randn(l.item(), E_q, dtype=dtype, device=device)\n                for l in sentence_lengths\n            ],\n            layout=torch.jagged,\n        )\n\n    key = torch.nested.nested_tensor(\n        [\n            torch.randn(s.item(), E_k, dtype=dtype, device=device)\n            for s in sentence_lengths\n        ],\n        layout=torch.jagged,\n    )\n\n    value = torch.nested.nested_tensor(\n        [\n            torch.randn(s.item(), E_v, dtype=dtype, device=device)\n            for s in sentence_lengths\n        ],\n        layout=torch.jagged,\n    )\n\n    return query, key, value, sentence_lengths\n\n\nimport math\nimport timeit\n\n\ndef benchmark(func, *args, **kwargs):\n    torch.cuda.synchronize()\n    torch.cuda.reset_peak_memory_stats()\n    begin = timeit.default_timer()\n    output = func(*args, **kwargs)\n    torch.cuda.synchronize()\n    end = timeit.default_timer()\n    return output, (end - begin), torch.cuda.max_memory_allocated()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will now demonstrate the performance improvements of using nested\ntensors in the `MultiheadAttention` layer + compile for self attention.\nWe compare this against the traditional `nn.MultiheadAttention` +\ncompile with padding and masking.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "N, E_q, E_k, E_v, E_total = 512, 512, 512, 512, 512\nE_out = E_q\nd_model = E_q\nnheads = 8\ndropout = 0.0\nbias = True\ndevice = \"cuda\"\ntorch.manual_seed(6)\nquery, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v, device)\nS = sentence_lengths.max().item()\nprint(\n    f\"Total sequence length in nested query {sentence_lengths.sum().item()}, max sequence length {S}\"\n)\npadded_query, padded_key, 
padded_value = (\n    t.to_padded_tensor(0.0) for t in (query, key, value)\n)\n\ntorch.manual_seed(6)\nmha_layer = MultiHeadAttention(\n    E_q, E_k, E_v, E_total, nheads, dropout=dropout, bias=bias, device=\"cuda\"\n)\ntorch.manual_seed(6)\nvanilla_mha_layer = nn.MultiheadAttention(\n    E_q, nheads, dropout=dropout, batch_first=True, bias=bias, device=\"cuda\"\n)\n\n# ``nn.MultiheadAttention`` uses a non-conventional initialization for layers, so do this for exact parity :(\nmha_layer.out_proj.weight = nn.Parameter(\n    vanilla_mha_layer.out_proj.weight.clone().detach()\n)\nmha_layer.packed_proj.weight = nn.Parameter(\n    vanilla_mha_layer.in_proj_weight.clone().detach()\n)\nmha_layer.out_proj.bias = nn.Parameter(vanilla_mha_layer.out_proj.bias.clone().detach())\nmha_layer.packed_proj.bias = nn.Parameter(\n    vanilla_mha_layer.in_proj_bias.clone().detach()\n)\n\nnew_mha_layer = torch.compile(mha_layer)\n# warmup compile\nnested_result_warmup = new_mha_layer(query, query, query, is_causal=True)\n\n# benchmark\nnested_result, nested_time, nested_peak_memory = benchmark(\n    new_mha_layer, query, query, query, is_causal=True\n)\npadded_nested_result = nested_result.to_padded_tensor(0.0)\n\n# For the vanilla ``nn.MultiheadAttention``, we need to construct the ``key_padding_mask``\n# Further, ``nn.MultiheadAttention`` forces one to materialize the ``attn_mask`` even if using ``is_causal``\nsrc_key_padding_mask = torch.where(padded_query == 0.0, -math.inf, 0)[:, :, 0]\nattn_mask = torch.empty((N, S, S), device=device).fill_(float(\"-inf\"))\nfor i, s in enumerate(sentence_lengths):\n    attn_mask[i, :s, :s] = nn.Transformer.generate_square_subsequent_mask(s)\nattn_mask = attn_mask.unsqueeze(1).expand(N, nheads, S, S).reshape(N * nheads, S, S)\n\nvanilla_mha_layer = torch.compile(vanilla_mha_layer)\n# warmup compile\nwarmup_vanilla_result = vanilla_mha_layer(\n    padded_query,\n    padded_query,\n    padded_query,\n    attn_mask=attn_mask,\n    key_padding_mask=src_key_padding_mask,\n    need_weights=False,\n    is_causal=True,\n)\n\n# benchmark\n(padded_result, _), padded_time, padded_peak_memory = benchmark(\n    vanilla_mha_layer,\n    padded_query,\n    padded_query,\n    padded_query,\n    key_padding_mask=src_key_padding_mask,\n    need_weights=False,\n    attn_mask=attn_mask,\n    is_causal=True,\n)\n\nprint(f\"{padded_time=:.5f}, padded_peak_memory={padded_peak_memory/1e9:.2f} GB\")\nprint(f\"{nested_time=:.5f}, nested_peak_memory={nested_peak_memory/1e9:.2f} GB\")\nprint(\n    \"Max difference between vanilla and nested result\",\n    (padded_result - padded_nested_result).abs().max().item(),\n)\nprint(f\"Nested speedup: {(padded_time/nested_time):.2f}\")\nprint(\n    f\"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For reference, here are some sample outputs on A100:\n\n``` {.}\npadded_time=0.03454, padded_peak_memory=4.14 GB\nnested_time=0.00612, nested_peak_memory=0.76 GB\nMax difference between vanilla and nested result 0.0\nNested speedup: 5.65\nNested peak memory reduction 3.39 GB\n```\n\nWe can also see the same for the backward pass.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for i, entry_length in enumerate(sentence_lengths):\n    # padding-specific step: remove output projection bias from padded entries for fair comparison\n    padded_result[i, entry_length:, :] = 0.0\n\n_, padded_bw_time, padded_bw_peak_mem = benchmark(\n    lambda: padded_result.sum().backward()\n)\n_, nested_bw_time, nested_bw_peak_mem = benchmark(\n    lambda: padded_nested_result.sum().backward()\n)\n\nprint(f\"{padded_bw_time=:.5f}, padded_bw_peak_mem={padded_bw_peak_mem/1e9:.2f} GB\")\nprint(f\"{nested_bw_time=:.5f}, nested_bw_peak_mem={nested_bw_peak_mem/1e9:.2f} GB\")\nprint(f\"Nested backward speedup: {(padded_bw_time/nested_bw_time):.2f}\")\nprint(\n    f\"Nested backward peak memory reduction {((padded_bw_peak_mem - 
nested_bw_peak_mem)/1e9):.2f} GB\"\n)\n\nprint(\n    \"Difference in out_proj.weight.grad\",\n    (mha_layer.out_proj.weight.grad - vanilla_mha_layer.out_proj.weight.grad)\n    .abs()\n    .max()\n    .item(),\n)\nprint(\n    \"Difference in packed_proj.weight.grad\",\n    (mha_layer.packed_proj.weight.grad - vanilla_mha_layer.in_proj_weight.grad)\n    .abs()\n    .max()\n    .item(),\n)\nprint(\n    \"Difference in out_proj.bias.grad\",\n    (mha_layer.out_proj.bias.grad - vanilla_mha_layer.out_proj.bias.grad)\n    .abs()\n    .max()\n    .item(),\n)\nprint(\n    \"Difference in packed_proj.bias.grad\",\n    (mha_layer.packed_proj.bias.grad - vanilla_mha_layer.in_proj_bias.grad)\n    .abs()\n    .max()\n    .item(),\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample outputs on A100:\n\n``` {.}\npadded_bw_time=2.09337, padded_bw_peak_mem=5.10 GB\nnested_bw_time=0.01452, nested_bw_peak_mem=3.24 GB\nNested backward speedup: 144.13\nNested backward peak memory reduction 1.86 GB\nDifference in out_proj.weight.grad 0.000244140625\nDifference in packed_proj.weight.grad 0.001556396484375\nDifference in out_proj.bias.grad 0.0\nDifference in packed_proj.bias.grad 0.001953125\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "GPT-style layer\n===============\n\nA basic GPT-style transformer layer consists of a causal self-attention\nlayer followed by a feed-forward network (FFN) with skip connections.\nImplementing this is fairly straightforward using the\n`MultiheadAttention` layer above and gives equivalent results to an\n`nn.TransformerEncoderLayer` with `is_causal=True`.\n\nWe demonstrate examples of implementing the rest of the `nn` layers\n[here](https://github.com/mikaylagawarecki/transformer_tutorial_accompaniment)\nbut omit that from this tutorial for brevity.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Going one step further\n======================\n\nSo far, we have demonstrated how to implement a performant\n`MultiheadAttention` layer that follows the 
traditional\n`nn.MultiheadAttention`. Going back to our classification of\nmodifications to the transformer architecture, remember that we\nclassified the modifications into layer type, layer ordering, and\nmodifications to the attention score. We trust that changing layer type\nand layer ordering (such as swapping `LayerNorm` for `RMSNorm`) is\nfairly straightforward.\n\nIn this section, we will discuss various functionalities using the\naforementioned building blocks, including the following:\n\n- Cross Attention\n- Fully masked rows no longer cause NaNs\n- Modifying attention score: ALiBi with FlexAttention and NJT\n- Packed Projection\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Cross Attention\n===============\n\nCross attention is a form of attention where the query and key/value\ntensors are from different sequences.\n\nOne example of this is in `nn.TransformerDecoderLayer` where the query\ncomes from the decoder and the key/value come from the encoder.\n\nThe above MultiheadAttention layer nicely generalizes to this case with\nnested tensors for both query and key/value.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "query, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)\n_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)\n\nprint(\n    f\"Total sequence length in nested query {q_len.sum().item()}, max sequence length {q_len.max().item()}\"\n)\nprint(\n    f\"Total sequence length in nested key/value {kv_len.sum().item()}, max sequence length {kv_len.max().item()}\"\n)\nout = new_mha_layer(query, key, value, is_causal=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As above, we can compare this against the vanilla compiled\n`nn.MultiheadAttention`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch.manual_seed(6)\nquery, _, _, q_len = gen_batch(N, E_q, E_k, E_v, device)\n_, key, value, kv_len = gen_batch(N, E_q, E_k, E_v, device)\npadded_query, padded_key, padded_value = (\n    t.to_padded_tensor(0.0) for t in (query, key, value)\n)\n\nkey_padding_mask = torch.where(padded_key == 0.0, -math.inf, 0)[:, :, 0]\n\n# warmup compile\nwarmup_nested_result = new_mha_layer(query, key, value, is_causal=False)\nwarmup_vanilla_result = vanilla_mha_layer(\n    padded_query,\n    padded_key,\n    padded_value,\n    key_padding_mask=key_padding_mask,\n    need_weights=False,\n    is_causal=False,\n)\n\nnested_result, nested_time, nested_peak_memory = benchmark(\n    new_mha_layer, query, key, value, is_causal=False\n)\n(padded_result, _), padded_time, padded_peak_memory = benchmark(\n    vanilla_mha_layer,\n    padded_query,\n    padded_key,\n    padded_value,\n    key_padding_mask=key_padding_mask,\n    need_weights=False,\n    is_causal=False,\n)\npadded_nested_result = nested_result.to_padded_tensor(0.0)\nfor i, entry_length in enumerate(q_len):\n    # padding-specific step: remove output projection bias from padded entries for fair comparison\n    padded_result[i, entry_length:, :] = 0.0\n\nprint(\n    \"Max difference between vanilla and nested result\",\n    (padded_result - padded_nested_result).abs().max().item(),\n)\nprint(f\"Nested speedup: {(padded_time/nested_time):.2f}\")\nprint(\n    f\"Nested peak memory reduction {((padded_peak_memory - nested_peak_memory)/1e9):.2f} GB\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Sample outputs on A100:\n\n``` {.}\nMax difference between vanilla and nested result 0.0\nNested speedup: 4.01\nNested peak memory reduction 1.40 GB\n```\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fully masked rows no longer cause NaNs\n======================================\n\nThere has been a long-standing issue with `nn.MultiheadAttention` and\n`scaled_dot_product_attention` where if a row was fully masked out, the\noutput of the attention layer would be NaN. See\n[issue](https://github.com/pytorch/pytorch/issues/41508). 
This is\nbecause the softmax over an empty set is undefined.\n\nThanks to [this PR](https://github.com/pytorch/pytorch/pull/133882) this\nis no longer the case. Instead, the output corresponding to fully masked\nrows in `scaled_dot_product_attention` will be 0. For cases where\n`nn.MHA` does not employ the \\\"fast-path\\\", this will also apply.\n\nUsing a custom MHA layer with NJTs is strongly recommended over the\nexisting \\\"fast-path\\\" in `nn.MultiheadAttention` as NJT\\'s ability to\nmodel raggedness appropriately makes it possible to properly express\nempty sequences.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "FlexAttention + NJT\n===================\n\nNJT also composes with the `FlexAttention` module. This is a\ngeneralization of the `MultiheadAttention` layer that allows for\narbitrary modifications to the attention score. The example below takes\nthe `alibi_mod` that implements\n[ALiBi](https://arxiv.org/abs/2108.12409) from [attention\ngym](https://github.com/pytorch-labs/attention-gym) and uses it with\nnested input tensors.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.nn.attention.flex_attention import flex_attention\n\n\ndef generate_alibi_bias(H: int):\n    \"\"\"Returns an alibi bias score_mod given the number of heads H\n    Args:\n        H: number of heads\n    Returns:\n        alibi_bias: alibi bias score_mod\n    \"\"\"\n\n    def alibi_mod(score, b, h, q_idx, kv_idx):\n        scale = torch.exp2(-((h + 1) * 8.0 / H))\n        bias = (q_idx - kv_idx) * scale\n        return score + bias\n\n    return alibi_mod\n\n\nquery, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)\nn_heads, D = 8, E_q // 8\nalibi_score_mod = generate_alibi_bias(n_heads)\nquery = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nkey = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nvalue = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nout_flex2 = flex_attention(query, key, value, score_mod=alibi_score_mod)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In addition, one can also use the `block_mask` utility of\n`FlexAttention` with NJTs via the `create_nested_block_mask` function.\nThis is useful for taking advantage of the sparsity of the mask to speed\nup the attention computation. In particular, the function creates a\nsparse block mask for a \\\"stacked sequence\\\" of all the variable length\nsequences in the NJT combined into one, while properly masking out\ninter-sequence attention. In the following example, we show how to\ncreate a causal block mask using this utility.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.nn.attention.flex_attention import create_nested_block_mask\n\n\ndef causal_mask(b, h, q_idx, kv_idx):\n    return q_idx >= kv_idx\n\n\nquery, key, value, _ = gen_batch(N, E_q, E_k, E_v, device)\nblock_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)\nquery = query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nkey = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nvalue = value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()\nout_flex = flex_attention(query, key, value, block_mask=block_mask)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Packed Projection\n=================\n\nPacked projection is a technique that makes use of the fact that when\nthe inputs to several projections (matrix multiplications) are the same\n(as in self-attention), we can pack the projection weights and biases into\nsingle tensors. It is especially useful when the individual projections\nare memory bound rather than compute bound. 
There are two examples that\nwe will demonstrate here:\n\n- Input projection for MultiheadAttention\n- SwiGLU activation in feed-forward network of Transformer Layer\n\nInput projection for MultiheadAttention\n---------------------------------------\n\nWhen doing self-attention, the `query`, `key`, and `value` are the same\ntensor. Each of these tensors is projected with a `Linear(E_q, E_total)`\nlayer. Instead, we can pack this into one layer, which is what we do in\nthe MultiheadAttention layer above.\n\nLet us compare the performance of the packed projection against the\nusual method:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class InputProjection(nn.Module):\n    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.q_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)\n        self.k_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)\n        self.v_proj = nn.Linear(E_q, E_total, bias=bias, **factory_kwargs)\n\n    def forward(self, x):\n        return self.q_proj(x), self.k_proj(x), self.v_proj(x)\n\n\nclass PackedInputProjection(nn.Module):\n    def __init__(self, E_q, E_total, bias=False, device=None, dtype=None):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        self.packed_proj = nn.Linear(E_q, E_total * 3, bias=bias, **factory_kwargs)\n\n    def forward(self, query):\n        return torch.chunk(self.packed_proj(query), 3, dim=-1)\n\n\nB, D, dtype = 256, 8192, torch.bfloat16\n\ntorch.set_float32_matmul_precision(\"high\")\nin_proj = torch.compile(InputProjection(D, D, device=\"cuda\", dtype=torch.bfloat16))\npacked_in_proj = torch.compile(\n    PackedInputProjection(D, D, device=\"cuda\", dtype=torch.bfloat16)\n)\n\nq, _, _, sequence_lengths = gen_batch(B, D, D, D, device=\"cuda\", dtype=torch.bfloat16)\n\n# warmup\nin_proj(q)\npacked_in_proj(q)\n\n# benchmark\n(q_out, k_out, v_out), time, _ = benchmark(in_proj, q)\n(q_out, k_out, v_out), time_packed, _ = benchmark(packed_in_proj, q)\n# On my A100 prints 1.05x speedup\nprint(\n    f\"InputProjection: {time:5f} s, PackedInputProjection: {time_packed:5f} s, speedup: {time/time_packed:.2f}x\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "SwiGLU feed forward network of Transformer Layer\n================================================\n\nSwish-Gated Linear Unit (SwiGLU) is a non-linear activation function\nthat is increasingly popular in the feed-forward network of the\ntransformer layer (e.g. Llama). A feed-forward network with SwiGLU\nactivation is defined as:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class SwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        dim,\n        hidden_dim,\n        multiple_of,\n        ffn_dim_multiplier=None,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": dtype}\n        super().__init__()\n        hidden_dim = int(2 * hidden_dim / 3)\n        # custom dim factor multiplier\n        if ffn_dim_multiplier is not None:\n            hidden_dim = int(ffn_dim_multiplier * hidden_dim)\n        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)\n\n        self.w1 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)\n        self.w3 = nn.Linear(dim, hidden_dim, bias=False, **factory_kwargs)\n\n    def forward(self, x):\n        return self.w2(F.silu(self.w1(x)) * self.w3(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "An alternative way of implementing this that uses packed projection is\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class PackedSwiGLUFFN(nn.Module):\n    def __init__(\n        self,\n        dim,\n        hidden_dim,\n        multiple_of,\n        ffn_dim_multiplier=None,\n        device=None,\n        dtype=None,\n    ):\n        factory_kwargs = {\"device\": device, \"dtype\": 
dtype}\n        super().__init__()\n        hidden_dim = int(2 * hidden_dim / 3)\n        # custom dim factor multiplier\n        if ffn_dim_multiplier is not None:\n            hidden_dim = int(ffn_dim_multiplier * hidden_dim)\n        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)\n\n        self.w13 = nn.Linear(dim, 2 * hidden_dim, bias=False, **factory_kwargs)\n        self.w2 = nn.Linear(hidden_dim, dim, bias=False, **factory_kwargs)\n\n    def forward(self, x):\n        x1, x3 = torch.chunk(self.w13(x), 2, dim=-1)\n        return self.w2(F.silu(x1) * x3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can compare the performance of the two implementations as follows.\nDepending on your hardware, you might see different results. On an A100\nI see 1.12x speedup for D=128.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "D = 128\n\nswigluffn = torch.compile(SwiGLUFFN(D, D * 4, 256, device=\"cuda\", dtype=torch.bfloat16))\npacked_swigluffn = torch.compile(\n    PackedSwiGLUFFN(D, D * 4, 256, device=\"cuda\", dtype=torch.bfloat16)\n)\n\nq, _, _, sentence_lengths = gen_batch(D, D, D, D, device=\"cuda\", dtype=torch.bfloat16)\n\n# warmup\nswigluffn(q)\npacked_swigluffn(q)\n\n# benchmark\n_, time, _ = benchmark(swigluffn, q)\n_, time_packed, _ = benchmark(packed_swigluffn, q)\n# On my A100 prints 1.08x speedup\nprint(\n    f\"SwiGLUFFN: {time} s, PackedSwiGLUFFN: {time_packed} s, speedup: {time/time_packed:.2f}x\"\n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extended examples\n=================\n\nWe intend to update this tutorial with more examples of how to\nuse performant building blocks such as KV-Caching and Grouped\nQuery Attention. Further, there are several good examples of using\nthese building blocks to implement various transformer\narchitectures. 
Some examples include:\n\n- [gpt-fast](https://github.com/pytorch-labs/gpt-fast)\n- [segment-anything-fast](https://github.com/pytorch-labs/segment-anything-fast)\n- [lucidrains' implementation of NaViT with nested\n  tensors](https://github.com/lucidrains/vit-pytorch/blob/73199ab486e0fad9eced2e3350a11681db08b61b/vit_pytorch/na_vit_nested_tensor.py)\n- [torchtune's implementation of\n  VisionTransformer](https://github.com/pytorch/torchtune/blob/a8a64ec6a99a6ea2be4fdaf0cd5797b03a2567cf/torchtune/modules/vision_transformer.py#L16)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we have introduced the low-level building blocks\nPyTorch provides for writing transformer layers and demonstrated\nexamples of how to compose them. We hope this tutorial has shown\nhow easily PyTorch users can implement flexible and performant\ntransformer layers.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }