.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "prototype/gpu_direct_storage.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here ` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_prototype_gpu_direct_storage.py:


(prototype) Accelerating ``torch.save`` and ``torch.load`` with GPUDirect Storage
=================================================================================

GPUDirect Storage enables a direct data path for direct memory access (DMA) transfers between GPU memory and storage, avoiding a bounce buffer through the CPU.

In version **2.7**, we introduced new prototype APIs in ``torch.cuda.gds`` that serve as thin wrappers around the `cuFile APIs `_.
These APIs can be used with ``torch.Tensor`` to achieve improved I/O performance.

In this tutorial, we will demonstrate how to use the ``torch.cuda.gds`` APIs in conjunction with checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
       :class-card: card-prerequisites

       * Understand how to use the ``torch.cuda.gds`` APIs in conjunction with checkpoints generated by ``torch.save`` and ``torch.load`` on a local filesystem

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
       :class-card: card-prerequisites

       * PyTorch 2.7.0 or later
       * GPUDirect Storage installed per `the documentation `_
       * A filesystem that supports GPUDirect Storage for saving and loading checkpoints

.. GENERATED FROM PYTHON SOURCE LINES 33-37

Using GPUDirect Storage with ``torch.save`` and ``torch.load``
------------------------------------------------------------------------------------

GPUDirect Storage requires a storage alignment of 4 KB (4096 bytes).
You can set this alignment with ``torch.utils.serialization.config.save.storage_alignment``:

.. GENERATED FROM PYTHON SOURCE LINES 37-43

.. code-block:: default


    import torch
    from torch.utils.serialization import config as serialization_config

    # Align every storage in the checkpoint to a 4096-byte boundary, as GDS requires.
    serialization_config.save.storage_alignment = 4096

.. GENERATED FROM PYTHON SOURCE LINES 44-52

The steps involved in the process are as follows:

* Write the checkpoint file without any actual data. This reserves the space on disk.
* Read the offsets for the storage associated with each tensor in the checkpoint using ``FakeTensorMode``.
* Use ``torch.cuda.gds.GdsFile`` to write the appropriate data at these offsets.

Given a state dictionary of tensors that are on the GPU, you can use the ``torch.serialization.skip_data`` context manager to save a checkpoint that contains all relevant metadata except the storage bytes.
For each ``torch.Storage`` in the state dictionary, space will be reserved within the checkpoint for the storage bytes.

.. GENERATED FROM PYTHON SOURCE LINES 52-61

.. code-block:: default


    import torch.nn as nn

    m = nn.Linear(5, 10, device='cuda')
    sd = m.state_dict()

    # Write all metadata and reserve (but do not fill) space for each storage's bytes.
    with torch.serialization.skip_data():
        torch.save(sd, "checkpoint.pt")

.. GENERATED FROM PYTHON SOURCE LINES 62-75

We can get the offset at which each storage should be written within the checkpoint by loading it under a ``FakeTensorMode``.
A ``FakeTensor`` is a tensor that carries metadata (such as sizes, strides, dtype, and device) but does not have any storage bytes.
The following snippet will not materialize any data, but it will tag the storage of each ``FakeTensor`` with the offset within the checkpoint that corresponds to the tensor.

If you are continuously saving the same state dictionary during training, you only need to obtain the offsets once; the same offsets can then be reused.
Similarly, if a tensor is going to be saved or loaded repeatedly, you can use ``torch.cuda.gds.gds_register_buffer``, which wraps ``cuFileBufRegister``, to register its storage as a GDS buffer.
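
The offsets recorded under ``FakeTensorMode`` should reflect the 4096-byte ``storage_alignment`` configured earlier, since GPUDirect Storage requires it.
As a rough sanity check, here is a minimal sketch, not PyTorch's own implementation, built on two hypothetical helpers, ``align_up`` and ``check_offsets``.
``align_up`` shows how a record's data start is rounded up to the next 4 KB boundary, and ``check_offsets`` validates a mapping of storage offsets against that requirement.
The offsets passed to ``check_offsets`` in the last line are made up for illustration.

.. code-block:: python


    from typing import Mapping

    # Hypothetical helpers for illustration; they are not part of PyTorch.
    GDS_ALIGNMENT = 4096  # bytes, matching serialization_config.save.storage_alignment


    def align_up(offset: int, alignment: int = GDS_ALIGNMENT) -> int:
        """Round ``offset`` up to the next multiple of ``alignment``."""
        return (offset + alignment - 1) // alignment * alignment


    def check_offsets(offsets: Mapping[str, int], alignment: int = GDS_ALIGNMENT) -> None:
        """Raise ValueError if any storage offset is not a multiple of ``alignment``."""
        for key, offset in offsets.items():
            if offset % alignment != 0:
                raise ValueError(
                    f"storage for {key!r} starts at byte {offset}, "
                    f"which is not {alignment}-byte aligned"
                )


    # Padding needed to move data that would start at byte 130 onto a 4 KB boundary.
    print(align_up(130) - 130)  # 3966

    # Illustrative offsets (not from a real checkpoint): both are aligned,
    # so this call returns without raising.
    check_offsets({"weight": 4096, "bias": 8192})

Applied to a real checkpoint, ``check_offsets({k: v.untyped_storage()._checkpoint_offset for k, v in fake_sd.items()})`` should pass silently when the checkpoint was saved with the alignment set to 4096. A ``ValueError`` would suggest that the alignment setting was not in effect when the checkpoint was written.
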

Note that ``torch.cuda.gds.GdsFile.save_storage`` binds to the synchronous ``cuFileWrite`` API, so no synchronization is needed afterwards.

.. GENERATED FROM PYTHON SOURCE LINES 75-94

.. code-block:: default


    import os
    from torch._subclasses.fake_tensor import FakeTensorMode

    # Loading under FakeTensorMode materializes no data, but records on each
    # storage the byte offset at which its data lives in the checkpoint.
    with FakeTensorMode():
        fake_sd = torch.load("checkpoint.pt")

    for k, v in fake_sd.items():
        print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

    f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

    for k, v in sd.items():
        offset = fake_sd[k].untyped_storage()._checkpoint_offset
        # save_storage is a wrapper around `cuFileWrite`
        f.save_storage(v.untyped_storage(), offset)

.. GENERATED FROM PYTHON SOURCE LINES 95-96

We verify the correctness of the saved checkpoint by loading it with ``torch.load`` and comparing each tensor with the original state dictionary.

.. GENERATED FROM PYTHON SOURCE LINES 96-101

.. code-block:: default


    sd_loaded = torch.load("checkpoint.pt")
    for k, v in sd_loaded.items():
        assert torch.equal(v, sd[k])

.. GENERATED FROM PYTHON SOURCE LINES 102-105

The loading flow is the inverse: you can use ``torch.load`` with the ``torch.serialization.skip_data`` context manager to load everything except the storage bytes.
This means that any tensors in the checkpoint will be created, but their storages will be uninitialized (as if the tensors were created via ``torch.empty``).

.. GENERATED FROM PYTHON SOURCE LINES 105-109

.. code-block:: default


    with torch.serialization.skip_data():
        sd_loaded = torch.load("checkpoint.pt")

.. GENERATED FROM PYTHON SOURCE LINES 110-115

We once again use the offsets obtained under ``FakeTensorMode`` to fill in the storage bytes, and then ascertain that the loaded checkpoint matches the saved checkpoint.

Similar to ``torch.cuda.gds.GdsFile.save_storage``, ``torch.cuda.gds.GdsFile.load_storage`` binds to the synchronous ``cuFileRead`` API, so no synchronization is needed afterwards.

.. GENERATED FROM PYTHON SOURCE LINES 115-126

.. code-block:: default


    for k, v in sd_loaded.items():
        # skip_data left the storages uninitialized, so they should not match yet
        assert not torch.equal(v, sd[k])
        offset = fake_sd[k].untyped_storage()._checkpoint_offset
        # load_storage is a wrapper around `cuFileRead`
        f.load_storage(v.untyped_storage(), offset)

    for k, v in sd_loaded.items():
        assert torch.equal(v, sd[k])

    # Release the GdsFile handle once all reads and writes are done
    del f

.. GENERATED FROM PYTHON SOURCE LINES 127-133

Conclusion
==========

In this tutorial, we demonstrated how to use the prototype ``torch.cuda.gds`` APIs in conjunction with ``torch.save`` and ``torch.load`` on a local filesystem.
Please file an issue in the PyTorch GitHub repo if you have any feedback.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 0.000 seconds)


.. _sphx_glr_download_prototype_gpu_direct_storage.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: gpu_direct_storage.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: gpu_direct_storage.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_