Checkpointing

This module implements methods for checkpointing and resuming training from a checkpoint.

class torchft.checkpointing.CheckpointTransport

Bases: Generic[T], ABC

disallow_checkpoint() → None

Called after send_checkpoint to wait for the checkpoint to be sent.

Once this returns, the caller may mutate the state_dict, so the transport must not send any further data.

abstract metadata() → str

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

abstract recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) → T

Receives the checkpoint from the given rank.

Parameters
  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

abstract send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) → None

Sends the checkpoint. This is only called when there is a rank that is behind.

This may be async.

Parameters
  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) → None

Called to shut down the checkpoint transport.

Parameters

wait – whether to wait for the transport to shut down

class torchft.checkpointing.HTTPTransport(timeout: timedelta, num_chunks: int)

Bases: CheckpointTransport[T]

This is an HTTP server that can be used to transfer checkpoints between workers.

This allows for fast recovery of workers by fetching the current weights from an existing worker.

Parameters
  • timeout – the timeout for HTTP requests

  • num_chunks – the number of chunks to split the checkpoint into (0 for no chunking)
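The effect of `num_chunks` can be pictured with a small sketch. The round-robin split shown here is only an assumption for illustration: the text above does not say how HTTPTransport partitions a checkpoint, and `split_into_chunks` is a hypothetical helper, not a torchft function.

```python
from typing import List, Sequence, TypeVar

V = TypeVar("V")


def split_into_chunks(values: Sequence[V], num_chunks: int) -> List[List[V]]:
    """Split values into at most num_chunks groups; 0 means no chunking."""
    if num_chunks < 0:
        raise ValueError("num_chunks must be non-negative")
    if num_chunks == 0:
        # No chunking: everything travels as a single piece.
        return [list(values)]
    # Round-robin assignment; drop chunks left empty when values are scarce.
    chunks = [list(values[i::num_chunks]) for i in range(num_chunks)]
    return [chunk for chunk in chunks if chunk]


whole = split_into_chunks(["a", "b", "c", "d", "e"], 0)
parts = split_into_chunks(["a", "b", "c", "d", "e"], 2)
```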

address() → str

Returns the HTTP address to fetch a checkpoint from this server. The step number must be appended to the end of the address.

Format (with step 1234 appended): http://host:port/checkpoint/1234

Returns

an HTTP address
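Building the fetch URL from the address can be sketched as follows. This assumes the returned address ends at the `/checkpoint/` path, so that the step is the final path segment; `checkpoint_url` is a hypothetical helper, and the host and port are placeholders.

```python
from urllib.parse import urlparse


def checkpoint_url(address: str, step: int) -> str:
    """Append the step number to the address returned by the server."""
    if step < 0:
        raise ValueError("step must be non-negative")
    # Tolerate an address that is missing its trailing slash.
    base = address if address.endswith("/") else address + "/"
    return f"{base}{step}"


url = checkpoint_url("http://host:8080/checkpoint/", 1234)
parsed = urlparse(url)
```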

allow_checkpoint(step: int) → None

Allows serving the checkpoint with the specified step number.

Parameters

step – the step number to serve

disallow_checkpoint() → None

Disallows serving the checkpoint.

All requests will block until allow_checkpoint is called.
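The allow/disallow behavior described here, where requests block until a step is allowed, can be sketched with a condition variable. `CheckpointGate` and `wait_allowed` are hypothetical names for illustration; this is not how HTTPTransport is necessarily implemented.

```python
import threading
from typing import List, Optional


class CheckpointGate:
    """Hypothetical gate: readers block until a step has been allowed."""

    def __init__(self) -> None:
        self._cond = threading.Condition()
        self._allowed_step: Optional[int] = None

    def allow_checkpoint(self, step: int) -> None:
        with self._cond:
            self._allowed_step = step
            # Wake every reader that is waiting for a checkpoint.
            self._cond.notify_all()

    def disallow_checkpoint(self) -> None:
        with self._cond:
            self._allowed_step = None

    def wait_allowed(self, timeout: float) -> Optional[int]:
        """Return the allowed step, or None if the timeout expires first."""
        with self._cond:
            self._cond.wait_for(
                lambda: self._allowed_step is not None, timeout=timeout
            )
            return self._allowed_step


gate = CheckpointGate()
blocked = gate.wait_allowed(timeout=0.05)  # nothing has been allowed yet

results: List[Optional[int]] = []
reader = threading.Thread(
    target=lambda: results.append(gate.wait_allowed(timeout=5.0))
)
reader.start()
gate.allow_checkpoint(7)
reader.join()

gate.disallow_checkpoint()
blocked_again = gate.wait_allowed(timeout=0.05)
```

A timeout on the waiting side keeps a reader from hanging forever if the serving worker never allows a checkpoint.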

metadata() → str

Returns a string that will be used by the remote CheckpointTransport to fetch the checkpoint.

recv_checkpoint(src_rank: int, metadata: str, step: int, timeout: timedelta) → T

Receives the checkpoint from the given rank.

Parameters
  • src_rank – the rank to receive the checkpoint from

  • metadata – the metadata returned by the remote CheckpointTransport

  • step – the step number to receive

  • timeout – the timeout to wait for the checkpoint

send_checkpoint(dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta) → None

Sends the checkpoint. This is only called when there is a rank that is behind.

This may be async.

Parameters
  • dst_ranks – the ranks to send to

  • step – the step number to send

  • state_dict – the state dict to send

  • timeout – the timeout to wait for the checkpoint to be sent

shutdown(wait: bool = True) → None

Shut down the server.