{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "::: {.meta description=\"An end-to-end example of how to use AOTInductor for Python runtime.\" keywords=\"torch.export, AOTInductor, torch._inductor.aoti_compile_and_package, aot_compile, torch._export.aoti_load_package\"}\n:::\n\n`torch.export` AOTInductor Tutorial for Python runtime (Beta)\n=============================================================\n\n**Author:** Ankith Gunapal, Bin Bao, Angela Yi\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "```{=html}\n
<div class=\"alert alert-danger\"><h4>WARNING:</h4><p><code>torch._inductor.aoti_compile_and_package</code> and <code>torch._inductor.aoti_load_package</code> are in Beta status and are subject to backwards-compatibility breaking changes. This tutorial provides an example of how to use these APIs for model deployment using Python runtime.</p></div>
\n```\nIt has been shown\n[previously](https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#)\nhow AOTInductor can be used to do Ahead-of-Time compilation of PyTorch\nexported models by creating an artifact that can be run in a non-Python\nenvironment. In this tutorial, you will work through an end-to-end example of\nhow to use AOTInductor for Python runtime.\n\n**Contents**\n\n::: {.contents local=\"\"}\n:::\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Prerequisites\n=============\n\n- PyTorch 2.6 or later\n- Basic understanding of `torch.export` and AOTInductor\n- Complete the [AOTInductor: Ahead-Of-Time Compilation for\n Torch.Export-ed\n Models](https://pytorch.org/docs/stable/torch.compiler_aot_inductor.html#)\n tutorial\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "What you will learn\n===================\n\n- How to use AOTInductor for Python runtime.\n- How to use\n `torch._inductor.aoti_compile_and_package`{.interpreted-text\n role=\"func\"} along with `torch.export.export`{.interpreted-text\n role=\"func\"} to generate a compiled artifact\n- How to load and run the artifact in a Python runtime using\n `torch._inductor.aoti_load_package`{.interpreted-text role=\"func\"}.\n- When to use AOTInductor with a Python runtime\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Model Compilation\n=================\n\nWe will use the TorchVision pretrained `ResNet18` model as an example.\n\nThe first step is to export the model to a graph representation using\n`torch.export.export`{.interpreted-text role=\"func\"}. To learn more\nabout using this function, you can check out the\n[docs](https://pytorch.org/docs/main/export.html) or the\n[tutorial](https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html).\n\nOnce we have exported the PyTorch model and obtained an\n`ExportedProgram`, we can use\n`torch._inductor.aoti_compile_and_package`{.interpreted-text\nrole=\"func\"} to compile the program with AOTInductor for a specified\ndevice, and save the generated contents into a \\\".pt2\\\" artifact.\n\n```{=html}\n
<div class=\"alert alert-info\"><h4>NOTE:</h4><p>This API supports the same available options that <code>torch.compile</code> has, such as <code>mode</code> and <code>max_autotune</code> (for those who want to enable CUDA graphs and leverage Triton-based matrix multiplications and convolutions).</p></div>
\n```\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\nimport torch\nimport torch._inductor\nfrom torchvision.models import ResNet18_Weights, resnet18\n\nmodel = resnet18(weights=ResNet18_Weights.DEFAULT)\nmodel.eval()\n\nwith torch.inference_mode():\n inductor_configs = {}\n\n if torch.cuda.is_available():\n device = \"cuda\"\n inductor_configs[\"max_autotune\"] = True\n else:\n device = \"cpu\"\n\n model = model.to(device=device)\n example_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n\n exported_program = torch.export.export(\n model,\n example_inputs,\n )\n path = torch._inductor.aoti_compile_and_package(\n exported_program,\n package_path=os.path.join(os.getcwd(), \"resnet18.pt2\"),\n inductor_configs=inductor_configs\n )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The result of `aoti_compile_and_package`{.interpreted-text role=\"func\"}\nis an artifact \\\"resnet18.pt2\\\" which can be loaded and executed in\nPython and C++.\n\nThe artifact itself contains a bunch of AOTInductor generated code, such\nas a generated C++ runner file, a shared library compiled from the C++\nfile, and CUDA binary files, aka cubin files, if optimizing for CUDA.\n\nStructure-wise, the artifact is a structured `.zip` file, with the\nfollowing specification:\n\n``` {..\n\u251c\u2500\u2500 archive_format\n\u251c\u2500\u2500 version\n\u251c\u2500\u2500 data\n\u2502 \u251c\u2500\u2500 aotinductor\n\u2502 \u2502 \u2514\u2500\u2500 model\n\u2502 \u2502 \u251c\u2500\u2500 xxx.cpp # AOTInductor generated cpp file\n\u2502 \u2502 \u251c\u2500\u2500 xxx.so # AOTInductor generated shared library\n\u2502 \u2502 \u251c\u2500\u2500 xxx.cubin # Cubin files (if running on CUDA)\n\u2502 \u2502 \u2514\u2500\u2500 xxx_metadata.json # Additional metadata to save\n\u2502 \u251c\u2500\u2500 weights\n\u2502 \u2502 \u2514\u2500\u2500 TBD\n\u2502 \u2514\u2500\u2500 constants\n\u2502 \u2514\u2500\u2500 TBD\n\u2514\u2500\u2500 extra\n\u2514\u2500\u2500 metadata.json}\n```\n\nWe can use the following command to inspect the artifact contents:\n\n``` {.bash}\n$ unzip -l resnet18.pt2\n```\n\n``` {.}\nArchive: resnet18.pt2\n Length Date Time Name\n--------- ---------- ----- ----\n 1 01-08-2025 16:40 version\n 3 01-08-2025 16:40 archive_format\n 10088 01-08-2025 16:40 data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin\n 17160 01-08-2025 16:40 data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin\n 16616 01-08-2025 16:40 data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin\n 17776 01-08-2025 16:40 data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin\n 10856 01-08-2025 16:40 data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin\n 14608 01-08-2025 16:40 data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin\n 11376 01-08-2025 16:40 data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin\n 10984 01-08-2025 16:40 data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin\n 14736 01-08-2025 16:40 data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin\n 11376 01-08-2025 16:40 data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin\n 11624 01-08-2025 16:40 data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin\n 15632 01-08-2025 16:40 
{ "cell_type": "markdown", "metadata": {}, "source": [ "Model Inference in Python\n=========================\n\nTo load and run the artifact in Python, we can use\n`torch._inductor.aoti_load_package`{.interpreted-text role=\"func\"}.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\nimport torch\nimport torch._inductor\n\nmodel_path = os.path.join(os.getcwd(), \"resnet18.pt2\")\n\ncompiled_model = torch._inductor.aoti_load_package(model_path)\nexample_inputs = (torch.randn(2, 3, 224, 224, device=device),)\n\nwith torch.inference_mode():\n    output = compiled_model(example_inputs)" ] },
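{ "cell_type": "markdown", "metadata": {}, "source": [ "As an optional sanity check, we can compare the output of the compiled\nartifact against the eager model. This is a minimal sketch that assumes the\ncompilation and loading cells above have been run in order, so that `model`,\n`example_inputs`, and `compiled_model` are all defined; small numerical\ndifferences between the two outputs are expected.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Optional sanity check: compare the AOTInductor output with eager PyTorch.\nwith torch.inference_mode():\n    eager_output = model(*example_inputs)\n    compiled_output = compiled_model(example_inputs)\n\n# The two outputs should agree up to small numerical differences.\nprint(\"max abs difference:\", (eager_output - compiled_output).abs().max().item())\nprint(\"allclose (atol=1e-2):\", torch.allclose(eager_output, compiled_output, atol=1e-2))" ] },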
{ "cell_type": "markdown", "metadata": {}, "source": [ "When to use AOTInductor with a Python Runtime\n=============================================\n\nThere are mainly two reasons why one would use AOTInductor with a Python\nruntime:\n\n- `torch._inductor.aoti_compile_and_package` generates a single\n serialized artifact. This is useful for model versioning for\n deployments and tracking model performance over time.\n- With `torch.compile`{.interpreted-text role=\"func\"} being a JIT\n compiler, there is a warmup cost associated with the first\n compilation. Your deployment needs to account for the compilation\n time taken for the first inference. With AOTInductor, the\n compilation is done ahead of time using `torch.export.export` and\n `torch._inductor.aoti_compile_and_package`. At deployment time,\n after loading the model, running inference does not incur any\n additional compilation cost.\n\nThe section below shows the speedup achieved with AOTInductor for the first\ninference.\n\nWe define a utility function `timed` to measure the time taken for\ninference.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import time\ndef timed(fn):\n    # Returns the result of running `fn()` and the time it took for `fn()` to run,\n    # in milliseconds. We use CUDA events and synchronization for accurate\n    # measurement on CUDA enabled devices.\n    if torch.cuda.is_available():\n        start = torch.cuda.Event(enable_timing=True)\n        end = torch.cuda.Event(enable_timing=True)\n        start.record()\n    else:\n        start = time.time()\n\n    result = fn()\n    if torch.cuda.is_available():\n        end.record()\n        torch.cuda.synchronize()\n    else:\n        end = time.time()\n\n    # Measure time taken to execute the function in milliseconds\n    if torch.cuda.is_available():\n        duration = start.elapsed_time(end)\n    else:\n        duration = (end - start) * 1000\n\n    return result, duration" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's measure the time for the first inference using AOTInductor.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch._dynamo.reset()\n\nmodel = torch._inductor.aoti_load_package(model_path)\nexample_inputs = (torch.randn(1, 3, 224, 224, device=device),)\n\nwith torch.inference_mode():\n    _, time_taken = timed(lambda: model(example_inputs))\n    print(f\"Time taken for first inference for AOTInductor is {time_taken:.2f} ms\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's measure the time for the first inference using `torch.compile`.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "torch._dynamo.reset()\n\nmodel = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)\nmodel.eval()\n\nmodel = torch.compile(model)\nexample_inputs = torch.randn(1, 3, 224, 224, device=device)\n\nwith torch.inference_mode():\n    _, time_taken = timed(lambda: model(example_inputs))\n    print(f\"Time taken for first inference for torch.compile is {time_taken:.2f} ms\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We see that there is a drastic speedup in first inference time using\nAOTInductor compared to `torch.compile`.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this recipe, we have learned how to effectively use AOTInductor for\nPython runtime by compiling and loading a pretrained `ResNet18` model.\nThis process demonstrates the practical application of generating a\ncompiled artifact and running it within a Python environment. We also\nlooked at the advantage of using AOTInductor in model deployments, with\nregard to the speedup in first inference time.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }