• Tutorials >
  • (原型)使用GPUDirect Storage加速``torch.save``和``torch.load``
Shortcuts

(原型)使用GPUDirect Storage加速``torch.save``和``torch.load``

GPUDirect Storage为GPU内存和存储之间的数据直接路径提供直接内存访问传输,避免通过CPU的中间缓冲区。

在版本**2.7**中,我们针对``torch.cuda.gds``引入了新的原型API,这些API作为`cuFile API <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_的轻量封装,可与``torch.Tensor``一起使用以实现更高的I/O性能。

在本教程中,我们将演示如何结合``torch.save``和``torch.load``在本地文件系统上使用``torch.cuda.gds``API。

What you will learn
  • 了解如何结合``torch.save``和``torch.load``在本地文件系统上使用``torch.cuda.gds``API

Prerequisites

结合``torch.save``和``torch.load``使用GPUDirect Storage

GPUDirect Storage需要4KB的存储对齐。您可以通过使用``torch.utils.serialization.config.save.storage_alignment``进行切换:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096
具体步骤如下:
  • 写入检查点文件但不包含任何实际数据。这会在磁盘上保留空间。

  • 使用``FakeTensor``读取检查点中与每个张量相关的存储的偏移量。

  • 使用``GDSFile``将适当的数据写入这些偏移量。

在GPU上的张量状态字典情况下,可以使用``torch.serialization.skip_data``上下文管理器保存仅包含相关元数据但不包括存储字节的检查点。对于状态字典中的每个``torch.Storage``,空间会在检查点中为存储字节保留。

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

我们可以通过在``FakeTensorMode``下加载来获得每个存储应写入检查点的偏移量。FakeTensor 是一种具有元数据(例如大小、步幅、数据类型、设备)信息的张量,但没有存储字节。以下代码段不会实际化任何数据,但会将每个``FakeTensor``标记为检查点中张量对应的偏移量。

如果您在训练期间持续保存相同的状态字典,您只需获取一次偏移量即可,并且这些偏移量可以重复使用。同样,如果张量将被重复保存或加载,您可以使用``torch.cuda.gds.gds_register_buffer``包装``cuFileBufRegister``以将存储注册为GDS缓冲区。

注意,``torch.cuda.gds.GdsFile.save_storage``绑定到同步``cuFileWrite``API,因此之后不需要任何同步。

import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around `cuFileWrite`
    f.save_storage(v.untyped_storage(), offset)

我们通过``torch.load``及对比验证保存的检查点的正确性。

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

加载流程是相反的:您可以使用``torch.load``与``torch.serialization.skip_data``上下文管理器加载所有数据(不包括存储字节)。这意味着检查点中的任何张量会被创建,但其存储为空(就像张量通过``torch.empty``创建一样)。

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

我们再次使用``FakeTensorMode``获取检查点偏移量并确认加载的检查点与保存的检查点相同。

与``torch.cuda.gds.GdsFile.save_storage``类似,``torch.cuda.gds.GdsFile.load_storage``绑定到同步``cuFileRead``API,因此之后不需要任何同步。

for k, v in sd_loaded.items():
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around `cuFileRead`
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

del f

总结

在本教程中,我们演示了如何结合``torch.save``和``torch.load``在本地文件系统上使用原型``torch.cuda.gds``API。如果您有任何反馈,请在PyTorch GitHub仓库中提交问题。

脚本总运行时间: (0分钟 0.000秒)

画廊由Sphinx-Gallery生成

文档

访问 PyTorch 的详细开发者文档

查看文档

教程

获取针对初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源