{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(Beta) Implementing High-Performance Transformers with Scaled Dot Product Attention (SDPA)\n==========================================================================================\n\n**Author:** [Driss Guessous](https://github.com/drisspg)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Summary\n=======\n\nIn this tutorial, we want to highlight a new `torch.nn.functional`\nfunction that can be helpful for implementing transformer architectures.\nThe function is named\n`torch.nn.functional.scaled_dot_product_attention`. For detailed\ndescription of the function, see the [PyTorch\ndocumentation](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention).\nThis function has already been incorporated into\n`torch.nn.MultiheadAttention` and `torch.nn.TransformerEncoderLayer`.\n\nOverview\n========\n\nAt a high level, this PyTorch function calculates the scaled dot product\nattention (SDPA) between query, key, and value according to the\ndefinition found in the paper [Attention is all you\nneed](https://arxiv.org/abs/1706.03762). While this function can be\nwritten in PyTorch using existing functions, a fused implementation can\nprovide large performance benefits over a naive implementation.\n\nFused implementations\n=====================\n\nFor CUDA tensor inputs, the function will dispatch into one of the\nfollowing implementations:\n\n- [FlashAttention: Fast and Memory-Efficient Exact Attention with\n IO-Awareness](https://arxiv.org/abs/2205.14135)\n- [Memory-Efficient\n Attention](https://github.com/facebookresearch/xformers)\n- A PyTorch implementation defined in C++\n\n```{=html}\n
<div class="alert alert-info"><h4>Note</h4><p>This tutorial requires PyTorch 2.0.0 or later.</p>
\n```\n```{=html}\n
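As a point of reference, below is a minimal sketch of a naive SDPA composed from existing PyTorch operations, compared against the fused `torch.nn.functional.scaled_dot_product_attention`. The helper name `naive_sdpa`, the tensor shapes, and the tolerance are illustrative assumptions; the sketch uses the default softmax scale of `1/sqrt(head_dim)`, and small numerical differences between the two outputs are expected.

```python
import math

import torch
import torch.nn.functional as F

def naive_sdpa(query, key, value):
    # Naive SDPA built from existing ops: softmax(Q @ K^T / sqrt(head_dim)) @ V
    scale = 1.0 / math.sqrt(query.size(-1))
    attn_weight = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
    return attn_weight @ value

device = "cuda" if torch.cuda.is_available() else "cpu"
# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
query, key, value = (torch.randn(2, 8, 128, 64, device=device) for _ in range(3))

out_naive = naive_sdpa(query, key, value)
out_fused = F.scaled_dot_product_attention(query, key, value)
print(torch.allclose(out_naive, out_fused, atol=1e-3))
```

The naive version materializes the full `seq_len x seq_len` attention matrix, which the fused kernels listed above avoid; that is the main source of their memory and speed advantage.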