{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# For tips on running notebooks in Google Colab, see\n# https://codelin.vip/beginner/colab\n%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "(prototype) Accelerating `torch.save` and `torch.load` with GPUDirect Storage\n=============================================================================\n\nGPUDirect Storage enables a direct data path for direct memory access\ntransfers between GPU memory and storage, avoiding a bounce buffer\nthrough the CPU.\n\nIn PyTorch **2.7**, we introduced new prototype APIs in `torch.cuda.gds`\nthat serve as thin wrappers around the [cuFile\nAPIs](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api)\nand can be used with `torch.Tensor` to achieve improved I/O\nperformance.\n\nIn this tutorial, we will demonstrate how to use the `torch.cuda.gds`\nAPIs in conjunction with checkpoints saved with `torch.save` and loaded\nwith `torch.load` on a local filesystem.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Using GPUDirect Storage with `torch.save` and `torch.load`\n==========================================================\n\nGPUDirect Storage requires each storage in the checkpoint to be aligned\nto 4KB. You can set this alignment with\n`torch.utils.serialization.config.save.storage_alignment`:\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torch.utils.serialization import config as serialization_config\n\n# Start the bytes of every storage at a 4096-byte boundary in saved checkpoints\nserialization_config.save.storage_alignment = 4096" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The steps involved in the process are as follows:\n\n- Write the checkpoint file without any actual data. This reserves\n  the space on disk.\n- Read the offsets for the storage associated with each tensor in\n  the checkpoint using `FakeTensor`.\n- Use `GdsFile` to write the appropriate data at these offsets.\n\nGiven a state dictionary of tensors that are on the GPU, one can use the\n`torch.serialization.skip_data` context manager to save a checkpoint\nthat contains all relevant metadata except the storage bytes. For each\n`torch.Storage` in the state dictionary, space will be reserved within\nthe checkpoint for the storage bytes.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch.nn as nn\n\nm = nn.Linear(5, 10, device='cuda')\nsd = m.state_dict()\n\n# Save all metadata but skip the storage bytes; space for them is reserved\nwith torch.serialization.skip_data():\n    torch.save(sd, \"checkpoint.pt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can get the offsets that each storage should be written to within the\ncheckpoint by loading under a `FakeTensorMode`. A `FakeTensor` is a tensor\nthat carries metadata about the tensor (such as sizes, strides, dtype, and\ndevice) but does not have any storage bytes.
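As a minimal, CPU-only sketch of what this means (assuming only that the private `torch._subclasses.fake_tensor` module imported later in this tutorial is available; the variable name `fake` is illustrative), a tensor created under `FakeTensorMode` still reports its size, strides, and dtype, although no data is ever allocated for it:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Factory calls made inside the mode return FakeTensors: they carry
# metadata (size, stride, dtype, device) but no real data.
with FakeTensorMode():
    fake = torch.empty(10, 5)

# Metadata queries still work after leaving the mode.
print(type(fake).__name__, tuple(fake.shape), fake.stride(), fake.dtype)
```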
The following\nsnippet will not materialize any data but will tag each `FakeTensor`\nwith the offset within the checkpoint that corresponds to the tensor.\n\nIf you are continuously saving the same state dictionary during\ntraining, you only need to obtain the offsets once, and the same\noffsets can be reused. Similarly, if a tensor is going to be saved or\nloaded repeatedly, you can use\n`torch.cuda.gds.gds_register_buffer`, which wraps `cuFileBufRegister`, to\nregister its storage as a GDS buffer.\n\nNote that `torch.cuda.gds.GdsFile.save_storage` binds to the synchronous\n`cuFileWrite` API, so no synchronization is needed afterwards.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import os\nfrom torch._subclasses.fake_tensor import FakeTensorMode\n\n# Loading under FakeTensorMode materializes no data, but tags each\n# storage with its byte offset within the checkpoint file\nwith FakeTensorMode():\n    fake_sd = torch.load(\"checkpoint.pt\")\n\nfor k, v in fake_sd.items():\n    print(f\"key={k}, offset={v.untyped_storage()._checkpoint_offset}\")\n\nf = torch.cuda.gds.GdsFile(\"checkpoint.pt\", os.O_RDWR)\n\nfor k, v in sd.items():\n    offset = fake_sd[k].untyped_storage()._checkpoint_offset\n    # save_storage is a wrapper around `cuFileWrite`\n    f.save_storage(v.untyped_storage(), offset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We verify the correctness of the saved checkpoint by loading it with\n`torch.load` and comparing each tensor with the original.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sd_loaded = torch.load(\"checkpoint.pt\")\nfor k, v in sd_loaded.items():\n    assert torch.equal(v, sd[k])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The loading flow is the inverse: you can use `torch.load` with the\n`torch.serialization.skip_data` context manager to load everything\nexcept the storage bytes.
This means that any tensors in the checkpoint\nwill be created but their storages will be uninitialized (as if the tensors\nwere created via `torch.empty`).\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "with torch.serialization.skip_data():\n    sd_loaded = torch.load(\"checkpoint.pt\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We reuse the checkpoint offsets obtained earlier under `FakeTensorMode`\nto read the storage bytes into the uninitialized tensors, and then verify\nthat the loaded checkpoint matches the saved checkpoint.\n\nSimilar to `torch.cuda.gds.GdsFile.save_storage`,\n`torch.cuda.gds.GdsFile.load_storage` binds to the synchronous\n`cuFileRead` API, so no synchronization is needed afterwards.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "for k, v in sd_loaded.items():\n    # The storages have not been populated yet, so the values should differ\n    assert not torch.equal(v, sd[k])\n    offset = fake_sd[k].untyped_storage()._checkpoint_offset\n    # load_storage is a wrapper around `cuFileRead`\n    f.load_storage(v.untyped_storage(), offset)\n\nfor k, v in sd_loaded.items():\n    assert torch.equal(v, sd[k])\n\ndel f" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Conclusion\n==========\n\nIn this tutorial, we demonstrated how to use the prototype\n`torch.cuda.gds` APIs in conjunction with `torch.save` and `torch.load`\non a local filesystem. Please file an issue in the PyTorch GitHub repo if\nyou have any feedback.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 0 }