备注
点击 here 下载完整示例代码
(原型)使用GPUDirect Storage加速``torch.save``和``torch.load``¶
GPUDirect Storage为GPU内存和存储之间的数据直接路径提供直接内存访问传输,避免通过CPU的中间缓冲区。
在版本**2.7**中,我们针对``torch.cuda.gds``引入了新的原型API,这些API作为`cuFile API <https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api>`_的轻量封装,可与``torch.Tensor``一起使用以实现更高的I/O性能。
在本教程中,我们将演示如何结合``torch.save``和``torch.load``在本地文件系统上使用``torch.cuda.gds``API。
了解如何结合``torch.save``和``torch.load``在本地文件系统上使用``torch.cuda.gds``API
需要PyTorch v.2.7.0或更高版本
GPUDirect Storage必须安装,具体步骤参见`文档 <https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/contents.html>`_
确保您保存/加载到的文件系统支持GPUDirect Storage。
结合``torch.save``和``torch.load``使用GPUDirect Storage¶
GPUDirect Storage需要4KB的存储对齐。您可以通过使用``torch.utils.serialization.config.save.storage_alignment``进行切换:
import torch
from torch.utils.serialization import config as serialization_config
serialization_config.save.storage_alignment = 4096
- 具体步骤如下:
写入检查点文件但不包含任何实际数据。这会在磁盘上保留空间。
使用``FakeTensor``读取检查点中与每个张量相关的存储的偏移量。
使用``GDSFile``将适当的数据写入这些偏移量。
在GPU上的张量状态字典情况下,可以使用``torch.serialization.skip_data``上下文管理器保存仅包含相关元数据但不包括存储字节的检查点。对于状态字典中的每个``torch.Storage``,空间会在检查点中为存储字节保留。
import torch.nn as nn
m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()
with torch.serialization.skip_data():
torch.save(sd, "checkpoint.pt")
我们可以通过在``FakeTensorMode``下加载来获得每个存储应写入检查点的偏移量。FakeTensor 是一种具有元数据(例如大小、步幅、数据类型、设备)信息的张量,但没有存储字节。以下代码段不会实际化任何数据,但会将每个``FakeTensor``标记为检查点中张量对应的偏移量。
如果您在训练期间持续保存相同的状态字典,您只需获取一次偏移量即可,并且这些偏移量可以重复使用。同样,如果张量将被重复保存或加载,您可以使用``torch.cuda.gds.gds_register_buffer``包装``cuFileBufRegister``以将存储注册为GDS缓冲区。
注意,``torch.cuda.gds.GdsFile.save_storage``绑定到同步``cuFileWrite``API,因此之后不需要任何同步。
import os
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode() as mode:
fake_sd = torch.load("checkpoint.pt")
for k, v in fake_sd.items():
print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")
f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)
for k, v in sd.items():
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# save_storage is a wrapper around `cuFileWrite`
f.save_storage(v.untyped_storage(), offset)
我们通过``torch.load``及对比验证保存的检查点的正确性。
sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
加载流程是相反的:您可以使用``torch.load``与``torch.serialization.skip_data``上下文管理器加载所有数据(不包括存储字节)。这意味着检查点中的任何张量会被创建,但其存储为空(就像张量通过``torch.empty``创建一样)。
with torch.serialization.skip_data():
sd_loaded = torch.load("checkpoint.pt")
我们再次使用``FakeTensorMode``获取检查点偏移量并确认加载的检查点与保存的检查点相同。
与``torch.cuda.gds.GdsFile.save_storage``类似,``torch.cuda.gds.GdsFile.load_storage``绑定到同步``cuFileRead``API,因此之后不需要任何同步。
for k, v in sd_loaded.items():
assert not torch.equal(v, sd[k])
offset = fake_sd[k].untyped_storage()._checkpoint_offset
# load_storage is a wrapper around `cuFileRead`
f.load_storage(v.untyped_storage(), offset)
for k, v in sd_loaded.items():
assert torch.equal(v, sd[k])
del f
总结¶
在本教程中,我们演示了如何结合``torch.save``和``torch.load``在本地文件系统上使用原型``torch.cuda.gds``API。如果您有任何反馈,请在PyTorch GitHub仓库中提交问题。
脚本总运行时间: (0分钟 0.000秒)