{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)\n==========================================================================================\n\n**Author:** [Driss Guessous](https://github.com/drisspg)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Summary\n=======\n\nIn this tutorial, we want to highlight a new `torch.nn.functional`\nfunction that can be helpful for implementing transformer architectures.\nThe function is named\n`torch.nn.functional.scaled_dot_product_attention`. For detailed\ndescription of the function, see the [PyTorch\ndocumentation](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention).\nThis function has already been incorporated into\n`torch.nn.MultiheadAttention` and `torch.nn.TransformerEncoderLayer`.\n\nOverview\n========\n\nAt a high level, this PyTorch function calculates the scaled dot product\nattention (SDPA) between query, key, and value according to the\ndefinition found in the paper [Attention is all you\nneed](https://arxiv.org/abs/1706.03762). While this function can be\nwritten in PyTorch using existing functions, a fused implementation can\nprovide large performance benefits over a naive implementation.\n\nFused implementations\n=====================\n\nFor CUDA tensor inputs, the function will dispatch into one of the\nfollowing implementations:\n\n- [FlashAttention: Fast and Memory-Efficient Exact Attention with\n IO-Awareness](https://arxiv.org/abs/2205.14135)\n- [Memory-Efficient\n Attention](https://github.com/facebookresearch/xformers)\n- A PyTorch implementation defined in C++\n\n```{=html}\n
<div class="alert alert-info"><h4>NOTE:</h4><p>This tutorial requires PyTorch 2.0.0 or later.</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nimport torch.nn as nn\nimport torch.nn.functional as F\ndevice = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n\n# Example Usage:\nquery, key, value = torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device), torch.randn(2, 3, 8, device=device)\nF.scaled_dot_product_attention(query, key, value)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Explicit Dispatcher Control\n===========================\n\nWhile the function will implicitly dispatch to one of the three\nimplementations, the user can also explicitly control the dispatch via\nthe use of a context manager. This context manager allows users to\nexplicitly disable certain implementations. If a user wants to ensure\nthe function is indeed using the fastest implementation for their\nspecific inputs, the context manager can be used to sweep through\nmeasuring performance.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Lets define a helpful benchmarking function:\nimport torch.utils.benchmark as benchmark\ndef benchmark_torch_function_in_microseconds(f, *args, **kwargs):\n t0 = benchmark.Timer(\n stmt=\"f(*args, **kwargs)\", globals={\"args\": args, \"kwargs\": kwargs, \"f\": f}\n )\n return t0.blocked_autorange().mean * 1e6\n\n# Lets define the hyper-parameters of our input\nbatch_size = 32\nmax_sequence_len = 1024\nnum_heads = 32\nembed_dimension = 32\n\ndtype = torch.float16\n\nquery = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)\nkey = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)\nvalue = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)\n\nprint(f\"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds\")\n\n# Lets explore the speed of each of the 3 implementations\nfrom torch.nn.attention import SDPBackend, sdpa_kernel\n\n\nwith sdpa_kernel(SDPBackend.MATH):\n math_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)\n print(f\"The math implementation runs in {math_time:.3f} microseconds\")\n\nwith sdpa_kernel(SDPBackend.FLASH_ATTENTION):\n try:\n flash_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)\n print(f\"The flash attention implementation runs in {flash_time:.3f} microseconds\")\n except RuntimeError:\n print(\"FlashAttention is not supported. See warnings for reasons.\")\n\nwith sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):\n try:\n efficient_time=benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value)\n print(f\"The memory efficient implementation runs in {efficient_time:.3f} microseconds\")\n except RuntimeError:\n print(\"EfficientAttention is not supported. See warnings for reasons.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Hardware dependence\n===================\n\nDepending on what machine you ran the above cell on and what hardware is\navailable, your results might be different. - If you don't have a GPU\nand are running on CPU then with FP32 the context manager will have no\neffect and all three runs should return similar timings. 
- Depending on\nwhat compute capability your graphics card supports flash attention or\nmemory efficient might have failed.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Causal Self Attention\n=====================\n\nBelow is an example implementation of a multi-headed causal self\nattention block inspired by [Andrej Karpathy\nNanoGPT](https://github.com/karpathy/nanoGPT) repository.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "class CausalSelfAttention(nn.Module):\n\n def __init__(self, num_heads: int, embed_dimension: int, bias: bool=False, is_causal: bool=False, dropout:float=0.0):\n super().__init__()\n assert embed_dimension % num_heads == 0\n # key, query, value projections for all heads, but in a batch\n self.c_attn = nn.Linear(embed_dimension, 3 * embed_dimension, bias=bias)\n # output projection\n self.c_proj = nn.Linear(embed_dimension, embed_dimension, bias=bias)\n # regularization\n self.dropout = dropout\n self.resid_dropout = nn.Dropout(dropout)\n self.num_heads = num_heads\n self.embed_dimension = embed_dimension\n # Perform causal masking\n self.is_causal = is_causal\n\n def forward(self, x):\n # calculate query, key, values for all heads in batch and move head forward to be the batch dim\n query_projected = self.c_attn(x)\n\n batch_size = query_projected.size(0)\n embed_dim = query_projected.size(2)\n head_dim = embed_dim // (self.num_heads * 3)\n\n query, key, value = query_projected.chunk(3, -1)\n query = query.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)\n key = key.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)\n value = value.view(batch_size, -1, self.num_heads, head_dim).transpose(1, 2)\n\n if self.training:\n dropout = self.dropout\n is_causal = self.is_causal\n else:\n dropout = 0.0\n is_causal = False\n\n y = F.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=dropout, is_causal=is_causal)\n y = y.transpose(1, 2).view(batch_size, -1, self.num_heads * head_dim)\n\n y = self.resid_dropout(self.c_proj(y))\n return y\n\n\nnum_heads = 8\nheads_per_dim = 64\nembed_dimension = num_heads * heads_per_dim\ndtype = torch.float16\nmodel = CausalSelfAttention(num_heads=num_heads, embed_dimension=embed_dimension, bias=False, is_causal=True, dropout=0.1).to(\"cuda\").to(dtype).eval()\nprint(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`NestedTensor` and Dense tensor support\n=======================================\n\nSDPA supports both `NestedTensor` and Dense tensor inputs.\n`NestedTensors` handle the case where the input is a batch of variable\nlength sequences without needing to pad each sequence to the maximum\nlength in the batch. 
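As a quick, minimal sketch (not part of the original tutorial; the shapes are arbitrary), a nested batch can be built directly from a list of variable-length tensors, with no padding materialized:

```python
import torch

# A hypothetical "batch" of two sequences with different lengths (5 and 3),
# stored jagged rather than padded out to the longer length.
nt = torch.nested.nested_tensor([torch.randn(5, 8), torch.randn(3, 8)])
print(nt.is_nested)                     # True
print([t.shape for t in nt.unbind()])   # the original, un-padded per-sequence shapes
```
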
For more information about `NestedTensors` see\n[torch.nested](https://pytorch.org/docs/stable/nested.html) and\n[NestedTensors\nTutorial](https://pytorch.org/tutorials/prototype/nestedtensor.html).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import random\ndef generate_rand_batch(\n batch_size,\n max_sequence_len,\n embed_dimension,\n pad_percentage=None,\n dtype=torch.float16,\n device=\"cuda\",\n):\n if not pad_percentage:\n return (\n torch.randn(\n batch_size,\n max_sequence_len,\n embed_dimension,\n dtype=dtype,\n device=device,\n ),\n None,\n )\n # Random sequence lengths\n seq_len_list = [\n int(max_sequence_len * (1 - random.gauss(pad_percentage, 0.01)))\n for _ in range(batch_size)\n ]\n # Make random entry in the batch have max sequence length\n seq_len_list[random.randint(0, batch_size - 1)] = max_sequence_len\n return (\n torch.nested.nested_tensor(\n [\n torch.randn(seq_len, embed_dimension,\n dtype=dtype, device=device)\n for seq_len in seq_len_list\n ]\n ),\n seq_len_list,\n )\n\nrandom_nt, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=0.5, dtype=dtype, device=device)\nrandom_dense, _ = generate_rand_batch(32, 512, embed_dimension, pad_percentage=None, dtype=dtype, device=device)\n\n# Currently the fused implementations don't support ``NestedTensor`` for training\nmodel.eval()\n\nwith sdpa_kernel(SDPBackend.FLASH_ATTENTION):\n try:\n print(f\"Random NT runs in {benchmark_torch_function_in_microseconds(model, random_nt):.3f} microseconds\")\n print(f\"Random Dense runs in {benchmark_torch_function_in_microseconds(model, random_dense):.3f} microseconds\")\n except RuntimeError:\n print(\"FlashAttention is not supported. See warnings for reasons.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using SDPA with `torch.compile`\n===============================\n\nWith the release of PyTorch 2.0, a new feature called `torch.compile()`\nhas been introduced, which can provide significant performance\nimprovements over eager mode. Scaled dot product attention is fully\ncomposable with `torch.compile()`. To demonstrate this, let\\'s compile\nthe `CausalSelfAttention` module using `torch.compile()` and observe the\nresulting performance improvements.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "batch_size = 32\nmax_sequence_len = 256\nx = torch.rand(batch_size, max_sequence_len,\n embed_dimension, device=device, dtype=dtype)\nprint(\n f\"The non compiled module runs in {benchmark_torch_function_in_microseconds(model, x):.3f} microseconds\")\n\n\ncompiled_model = torch.compile(model)\n# Let's compile it\ncompiled_model(x)\nprint(\n f\"The compiled module runs in {benchmark_torch_function_in_microseconds(compiled_model, x):.3f} microseconds\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The exact execution time is dependent on machine, however the results\nfor mine: The non compiled module runs in 166.616 microseconds The\ncompiled module runs in 166.726 microseconds That is not what we were\nexpecting. Let\\'s dig a little deeper. 
PyTorch comes with an amazing\nbuilt-in profiler that you can use to inspect the performance\ncharacteristics of your code.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torch.profiler import profile, record_function, ProfilerActivity\nactivities = [ProfilerActivity.CPU]\nif device == 'cuda':\n activities.append(ProfilerActivity.CUDA)\n\nwith profile(activities=activities, record_shapes=False) as prof:\n with record_function(\" Non-Compilied Causal Attention\"):\n for _ in range(25):\n model(x)\nprint(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n\nwith profile(activities=activities, record_shapes=False) as prof:\n with record_function(\"Compiled Causal Attention\"):\n for _ in range(25):\n compiled_model(x)\nprint(prof.key_averages().table(sort_by=\"cuda_time_total\", row_limit=10))\n\n# For even more insights, you can export the trace and use ``chrome://tracing`` to view the results\n#\n# .. code-block:: python\n#\n# prof.export_chrome_trace(\"compiled_causal_attention_trace.json\")." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The previous code snippet generates a report of the top 10 PyTorch\nfunctions that consumed the most GPU execution time, for both the\ncompiled and non-compiled module. The analysis reveals that the majority\nof time spent on the GPU is concentrated on the same set of functions\nfor both modules. The reason for this here is that `torch.compile` is\nvery good at removing the framework overhead associated with PyTorch. If\nyour model is launching large, efficient CUDA kernels, which in this\ncase `CausalSelfAttention` is, then the overhead of PyTorch can be\nhidden.\n\nIn reality, your module does not normally consist of a singular\n`CausalSelfAttention` block. When experimenting with [Andrej Karpathy\nNanoGPT](https://github.com/karpathy/nanoGPT) repository, compiling the\nmodule took the time per train step from: `6090.49ms` to `3273.17ms`!\nThis was done on commit: `ae3a8d5` of NanoGPT training on the\nShakespeare dataset.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using SDPA with attn\\_bias subclasses\n=====================================\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# As of PyTorch 2.3, we have added a new submodule that contains tensor subclasses.\n# Designed to be used with ``torch.nn.functional.scaled_dot_product_attention``.\n# The module is named ``torch.nn.attention.bias`` and contains the following two\n# utilities for generating causal attention variants:\n#\n# - ``torch.nn.attention.bias.causal_upper_left``\n# - ``torch.nn.attention.bias.causal_lower_right``\n#\n# .. 
note::\n# The current argument ``is_causal`` in ``torch.nn.functional.scaled_dot_product_attention``\n# is the same as using ``torch.nn.attention.bias.causal_upper_left``.\n#\n\nfrom torch.nn.attention.bias import causal_lower_right, causal_upper_left\n\nbatch_size = 32\nsequence_length_q = 2\nsequence_length_kv = 10\nnum_heads = 16\nembed_dimension = 32\n\ndtype = torch.float16\n\nquery = torch.rand(batch_size, num_heads, sequence_length_q, embed_dimension, device=device, dtype=dtype)\nkey = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)\nvalue = torch.rand(batch_size, num_heads, sequence_length_kv, embed_dimension, device=device, dtype=dtype)\n\nupper_left_bias = causal_upper_left(sequence_length_q, sequence_length_kv)\nlower_right_bias = causal_lower_right(sequence_length_q, sequence_length_kv)\n\nprint(type(upper_left_bias))\nprint(type(lower_right_bias))\n\nassert type(upper_left_bias) == type(lower_right_bias)\nassert issubclass(type(upper_left_bias), torch.Tensor)\n\n# As you can see from the previous output, are the same type ``torch.nn.attention.bias.CausalBias``\n# and subclass ``torch.Tensor``\n\n# Lets see what these tensors look like\nprint(upper_left_bias)\nprint(lower_right_bias)\n\n# Upper Left Bias aligns the causal attention mask to the upper left corner of the attention scores matrix.\n# This only has an impact when the attention scores matrix is not square, which is common for decoding use cases.\n# Another way of thinking about this concept is that when you use upper left bias,\n# the 0th token in the query is aligned to the 0th token in the key, while for lower right bias,\n# Assuming the attention score matrix is two dimensional, ``attn_score[0][0]`` is the attention score\n# between the 0th token in the query and the 0th token in the key.\n# For lower right bias, the sequence of q is aligned so that the last token in q is aligned to the last token in k\n# (for example, ``attn_score[-1][-1])`` is all True since the last token in q is at the same position as the last token in k\n# even if the sequence length of q and k are different.\n\n# These objects are intended to be used with sdpa\nout_upper_left = F.scaled_dot_product_attention(query, key, value, upper_left_bias)\nout_lower_right = F.scaled_dot_product_attention(query, key, value, lower_right_bias)\nout_is_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)\n\nassert torch.allclose(out_upper_left, out_is_causal)\nassert not torch.allclose(out_upper_left, out_lower_right)\n\n# These attention biases should also be compatible with torch.compile\ncompiled_sdpa = torch.compile(F.scaled_dot_product_attention, fullgraph=True)\nout_upper_left = compiled_sdpa(query, key, value, upper_left_bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we have demonstrated the basic usage of\n`torch.nn.functional.scaled_dot_product_attention`. We have shown how\nthe `sdpa_kernel` context manager can be used to assert a certain\nimplementation is used on GPU. As well, we built a simple\n`CausalSelfAttention` module that works with `NestedTensor` and is torch\ncompilable. 
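For reference, the core usage pattern demonstrated above boils down to the following recap sketch (it reuses only APIs shown earlier in this tutorial and assumes a CUDA-capable GPU, as in the earlier cells):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Shapes: (batch, num_heads, seq_len, head_dim); half precision on a CUDA device.
query = key = value = torch.randn(2, 8, 16, 64, device="cuda", dtype=torch.float16)

# Restrict dispatch to a single backend; SDPA raises a RuntimeError if that backend
# cannot run for these inputs on the current hardware.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(query, key, value, is_causal=True)
```
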
In the process we have shown how the profiling tools can\nbe used to explore the performance characteristics of a user-defined\nmodule.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }