
(prototype) Accelerating torch.save and torch.load with GPUDirect Storage

GPUDirect Storage enables a direct data path for direct memory access (DMA) transfers between GPU memory and storage, avoiding a bounce buffer through the CPU.

In PyTorch 2.7, we introduced new prototype APIs in torch.cuda.gds. They serve as thin wrappers around the cuFile APIs and can be used with torch.Tensor to improve I/O performance.

In this tutorial, we will demonstrate how to use the torch.cuda.gds APIs in conjunction with checkpoints generated by torch.save and torch.load on a local filesystem.

What you will learn
  • Understand how to use the torch.cuda.gds APIs in conjunction with checkpoints generated by torch.save and torch.load on a local filesystem

Prerequisites
  • PyTorch v2.7.0 or later

  • GPUDirect Storage must be installed per the NVIDIA GPUDirect Storage documentation

  • Ensure that the filesystem you are saving to and loading from supports GPUDirect Storage.

Using GPUDirect Storage with torch.save and torch.load

GPUDirect Storage requires a storage alignment of 4KB. You can set this with torch.utils.serialization.config.save.storage_alignment:

import torch
from torch.utils.serialization import config as serialization_config

serialization_config.save.storage_alignment = 4096
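A 4KB alignment means the bytes of every storage must begin at a file offset that is a multiple of 4096. As a minimal sketch of the arithmetic involved, the hypothetical helper `align_up` below (not a PyTorch API) rounds an offset up to the next boundary; setting storage_alignment as above asks torch.save to apply this kind of padding itself.

```python
# Minimal sketch of 4 KiB alignment arithmetic.
# `align_up` is a hypothetical helper for illustration, not part of PyTorch.
ALIGNMENT = 4096  # same value as storage_alignment above


def align_up(offset: int, alignment: int = ALIGNMENT) -> int:
    """Round `offset` up to the next multiple of `alignment`."""
    # The bitmask trick below is only valid for powers of two.
    if alignment <= 0 or alignment & (alignment - 1):
        raise ValueError("alignment must be a positive power of two")
    return (offset + alignment - 1) & ~(alignment - 1)


print([align_up(n) for n in (0, 1, 4095, 4096, 4097, 10_000)])
# [0, 4096, 4096, 4096, 8192, 12288]
```

Offsets that are already aligned are left unchanged, so padding is only inserted when a storage would otherwise start mid-block.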
The steps involved in the process are as follows:
  • Write the checkpoint file without any actual data. This reserves the space on disk.

  • Read the offsets for the storage associated with each tensor in the checkpoint using FakeTensor.

  • Use torch.cuda.gds.GdsFile to write the appropriate data at these offsets.
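To illustrate the shape of these three steps without a GPU, the sketch below mimics them with ordinary Python file I/O. It is an analogy only: it assumes a toy layout (a hypothetical 100-byte header followed by 4096-aligned payloads) rather than torch.save's actual checkpoint format, and it uses plain seek/write instead of cuFile.

```python
import os
import tempfile

ALIGNMENT = 4096   # same value as storage_alignment above
HEADER_SIZE = 100  # hypothetical size of the metadata at the start of the file


def align_up(offset, alignment=ALIGNMENT):
    # Round up to the next multiple of `alignment`.
    return (offset + alignment - 1) // alignment * alignment


# Hypothetical "state dict": names mapped to the raw bytes of each storage.
payloads = {"weight": b"\x01" * 200, "bias": b"\x02" * 40}

# Step 1 analogue: decide the layout and reserve the space, but write no data.
offsets = {}
cursor = HEADER_SIZE
for name, data in payloads.items():
    cursor = align_up(cursor)
    offsets[name] = cursor
    cursor += len(data)

path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.bin")
with open(path, "wb") as f:
    f.truncate(cursor)  # extend the file to its final size, zero-filled

# Step 2 analogue: the offsets are now known. In the tutorial they come from
# loading the checkpoint under FakeTensorMode; here we kept them from step 1.
print(offsets)  # {'weight': 4096, 'bias': 8192}

# Step 3 analogue: write each payload in place at its offset.
with open(path, "r+b") as f:
    for name, data in payloads.items():
        f.seek(offsets[name])
        f.write(data)
```

Because the space is reserved up front, step 3 never changes the file size; it only overwrites bytes at known positions, which is what makes the data writes independent of the metadata write.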

Given a state dictionary of tensors that are on the GPU, one can use the torch.serialization.skip_data context manager to save a checkpoint that contains all relevant metadata except the storage bytes. For each torch.Storage in the state dictionary, space will be reserved within the checkpoint for the storage bytes.

import torch.nn as nn

m = nn.Linear(5, 10, device='cuda')
sd = m.state_dict()

with torch.serialization.skip_data():
    torch.save(sd, "checkpoint.pt")

We can get the offset at which each storage should be written within the checkpoint by loading the checkpoint under FakeTensorMode. A FakeTensor is a tensor that carries metadata (such as sizes, strides, dtype, and device) but has no storage bytes. The following snippet does not materialize any data; instead, it tags the storage of each FakeTensor with its offset within the checkpoint.

If you are repeatedly saving the same state dictionary during training, you only need to obtain the offsets once and can reuse them for every save. Similarly, if a tensor is going to be saved or loaded repeatedly, you can use torch.cuda.gds.gds_register_buffer, which wraps cuFileBufRegister, to register its storage as a GDS buffer.

Note that torch.cuda.gds.GdsFile.save_storage binds to the synchronous cuFileWrite API, so no synchronization is needed afterwards.

import os
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as mode:
    fake_sd = torch.load("checkpoint.pt")

for k, v in fake_sd.items():
    print(f"key={k}, offset={v.untyped_storage()._checkpoint_offset}")

f = torch.cuda.gds.GdsFile("checkpoint.pt", os.O_RDWR)

for k, v in sd.items():
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # save_storage is a wrapper around `cuFileWrite`
    f.save_storage(v.untyped_storage(), offset)
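Since the offsets depend only on the checkpoint's structure (keys and storage sizes), not on tensor values, computing them once and reusing them across saves is a natural fit for memoization. The sketch below assumes a hypothetical `compute_offsets` function standing in for the FakeTensorMode pass, with a toy aligned layout; it uses functools.lru_cache keyed on a metadata signature.

```python
import functools

ALIGNMENT = 4096
calls = {"count": 0}  # counts how often the (expensive) layout pass runs


@functools.lru_cache(maxsize=None)
def compute_offsets(signature):
    """Hypothetical stand-in for the FakeTensorMode pass.

    `signature` is a tuple of (key, nbytes) pairs: hashable, and it changes
    only when the checkpoint's structure changes. Callers should treat the
    returned (cached) dict as read-only.
    """
    calls["count"] += 1
    offsets, cursor = {}, 0
    for key, nbytes in signature:
        cursor = (cursor + ALIGNMENT - 1) // ALIGNMENT * ALIGNMENT
        offsets[key] = cursor
        cursor += nbytes
    return offsets


def signature_of(state):
    # Only metadata goes into the signature; the values change every step.
    return tuple((key, len(buf)) for key, buf in state.items())


for step in range(3):
    state = {"weight": bytes([step]) * 200, "bias": bytes([step]) * 40}
    offsets = compute_offsets(signature_of(state))
    # Each buffer would be written at offsets[key] here.

print(calls["count"], offsets)  # 1 {'weight': 0, 'bias': 4096}
```

Keying the cache on metadata rather than on the data means a structural change, such as adding a layer, automatically triggers a fresh layout pass, while ordinary training steps reuse the cached offsets.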

We verify the correctness of the saved checkpoint by loading it with torch.load and comparing each tensor against the original.

sd_loaded = torch.load("checkpoint.pt")
for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

The loading flow is the inverse: you can use torch.load with the torch.serialization.skip_data context manager to load everything except the storage bytes. This means that any tensors in the checkpoint will be created but their storages will be empty (as if the tensors were created via torch.empty).

with torch.serialization.skip_data():
    sd_loaded = torch.load("checkpoint.pt")

We once again use the offsets obtained under FakeTensorMode, this time to read the storage bytes into the loaded tensors, and then check that the loaded checkpoint matches the saved one.

Similar to torch.cuda.gds.GdsFile.save_storage, torch.cuda.gds.GdsFile.load_storage binds to the synchronous cuFileRead API, so no synchronization is needed afterwards.

for k, v in sd_loaded.items():
    assert not torch.equal(v, sd[k])
    offset = fake_sd[k].untyped_storage()._checkpoint_offset
    # load_storage is a wrapper around `cuFileRead`
    f.load_storage(v.untyped_storage(), offset)

for k, v in sd_loaded.items():
    assert torch.equal(v, sd[k])

del f
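The key point of the load path is that load_storage fills storages that already exist (created empty by skip_data) instead of allocating new ones. A rough stdlib analogy, assuming a toy file with hypothetical 4096-aligned offsets, is to preallocate buffers of the right size and fill them in place with readinto. Note that bytearray zero-fills its memory, whereas torch.empty leaves it uninitialized.

```python
import os
import tempfile

# Build a toy file with two payloads at 4 KiB-aligned offsets (hypothetical layout).
offsets = {"weight": 4096, "bias": 8192}
expected = {"weight": b"\x01" * 200, "bias": b"\x02" * 40}
path = os.path.join(tempfile.mkdtemp(), "toy_checkpoint.bin")
with open(path, "wb") as f:
    for name, data in expected.items():
        f.seek(offsets[name])  # seeking past the end leaves a zero-filled gap
        f.write(data)

# Analogue of skip_data: buffers exist with the right sizes but hold no real data yet.
loaded = {name: bytearray(len(data)) for name, data in expected.items()}

# Analogue of load_storage: fill each existing buffer in place from its offset.
with open(path, "rb") as f:
    for name, buf in loaded.items():
        f.seek(offsets[name])
        n = f.readinto(buf)
        assert n == len(buf), f"short read for {name}"

print(all(loaded[k] == expected[k] for k in expected))  # True
```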

Conclusion

In this tutorial, we have demonstrated how to use the prototype torch.cuda.gds APIs in conjunction with torch.save and torch.load on a local filesystem. Please file an issue in the PyTorch GitHub repo if you have any feedback.
