torch.utils.checkpoint
Note
Checkpointing is implemented by rerunning a forward-pass segment for
each checkpointed segment during backward propagation.  This can cause persistent
states like the RNG state to be more advanced than they would be without
checkpointing.  By default, checkpointing includes logic to juggle
the RNG state such that checkpointed passes making use of RNG
(through dropout for example) have deterministic output as
compared to non-checkpointed passes.  The logic to stash and restore
RNG states can incur a moderate performance hit depending on the runtime
of checkpointed operations.  If deterministic output compared to
non-checkpointed passes is not required, supply preserve_rng_state=False
to checkpoint or checkpoint_sequential to omit stashing and
restoring the RNG state during each checkpoint.
The stashing logic saves and restores the RNG state for CPU and for one other
device type (inferred by _infer_device_type from the Tensor arguments to
run_fn, excluding CPU tensors). If there are multiple
devices, device state will only be saved for devices of a single device type,
and the remaining devices will be ignored. Consequently, if any checkpointed
functions involve randomness, this may result in incorrect gradients. (Note
that if CUDA devices are among the devices detected, CUDA will be prioritized;
otherwise, the first device encountered will be selected.) If there are no
non-CPU tensors, the state of the default device type (cuda by default; it
can be changed through DefaultDeviceType) will be saved and restored.
However, the logic has no way to anticipate if the user will move
Tensors to a new device within the run_fn itself.  Therefore, if you move
Tensors to a new device (“new” meaning not belonging to the set of
[current device + devices of Tensor arguments]) within run_fn, deterministic
output compared to non-checkpointed passes is never guaranteed.
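As a minimal sketch of the default RNG handling described in this note (not taken from the PyTorch sources), the following runs a dropout block once directly and once through checkpoint on CPU. The helper name block and the tensor shape are illustrative assumptions, and reseeding with torch.manual_seed stands in for two otherwise identical runs.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


def block(x):
    # Hypothetical checkpointed region that consumes random numbers.
    return F.dropout(x, p=0.5, training=True) * 2.0


x = torch.randn(8, 16, requires_grad=True)

# Reference pass without checkpointing.
torch.manual_seed(0)
ref = block(x)
ref.sum().backward()
ref_grad = x.grad.clone()
x.grad = None

# Same seed, but the block is recomputed during backward.
# preserve_rng_state=True is the default; it is spelled out for clarity.
torch.manual_seed(0)
out = checkpoint(block, x, use_reentrant=False, preserve_rng_state=True)
out.sum().backward()
ckpt_grad = x.grad.clone()

print(torch.equal(ref, out), torch.equal(ref_grad, ckpt_grad))
```

With preserve_rng_state=False the recomputation would draw a fresh dropout mask, so the gradients would generally no longer match the non-checkpointed pass.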
- torch.utils.checkpoint.checkpoint(function, *args, use_reentrant=None, context_fn=<function noop_context_fn>, determinism_check='default', debug=False, **kwargs)
- Checkpoint a model or part of the model.
Activation checkpointing is a technique that trades compute for memory. Instead of keeping tensors needed for backward alive until they are used in gradient computation during backward, forward computation in checkpointed regions omits saving tensors for backward and recomputes them during the backward pass. Activation checkpointing can be applied to any part of a model.
There are currently two checkpointing implementations available, determined by the use_reentrant parameter. It is recommended that you use use_reentrant=False. Please refer to the note below for a discussion of their differences.
Warning
If the function invocation during the backward pass differs from the forward pass, e.g., due to a global variable, the checkpointed version may not be equivalent, potentially causing an error to be raised or leading to silently incorrect gradients.
Warning
The use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please refer to the note below for important considerations and potential limitations.
Note
The reentrant variant of checkpoint (use_reentrant=True) and the non-reentrant variant of checkpoint (use_reentrant=False) differ in the following ways:
- Non-reentrant checkpoint stops recomputation as soon as all needed intermediate activations have been recomputed. This feature is enabled by default, but can be disabled with set_checkpoint_early_stop(). Reentrant checkpoint always recomputes function in its entirety during the backward pass.
- The reentrant variant does not record the autograd graph during the forward pass, as it runs the forward pass under torch.no_grad(). The non-reentrant version does record the autograd graph, allowing one to perform backward on the graph within checkpointed regions.
- The reentrant checkpoint only supports the torch.autograd.backward() API for the backward pass without its inputs argument, while the non-reentrant version supports all ways of performing the backward pass.
- At least one input and output must have requires_grad=True for the reentrant variant. If this condition is unmet, the checkpointed part of the model will not have gradients. The non-reentrant version does not have this requirement.
- The reentrant version does not consider tensors in nested structures (e.g., custom objects, lists, dicts, etc.) as participating in autograd, while the non-reentrant version does.
- The reentrant checkpoint does not support checkpointed regions with tensors detached from the computational graph, whereas the non-reentrant version does. For the reentrant variant, if the checkpointed segment contains tensors detached using detach() or with torch.no_grad(), the backward pass will raise an error. This is because checkpoint makes all the outputs require gradients and this causes issues when a tensor is defined to have no gradient in the model. To avoid this, detach the tensors outside of the checkpoint function.
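Two of the differences listed above can be illustrated with a short sketch: the non-reentrant variant works with torch.autograd.grad, and it does not require any input to have requires_grad=True. The layer sizes and variable names below are illustrative assumptions, not part of the API.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
params = list(layer.parameters())

# The input deliberately does not require grad; only the parameters do.
x = torch.randn(2, 4)

# Non-reentrant checkpoint combined with torch.autograd.grad.
out = checkpoint(layer, x, use_reentrant=False)
grads = torch.autograd.grad(out.sum(), params)

# Reference gradients from an ordinary, non-checkpointed forward pass.
expected = torch.autograd.grad(layer(x).sum(), params)

print(all(torch.allclose(g, e) for g, e in zip(grads, expected)))
```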
- Parameters
- function – describes what to run in the forward pass of the model or part of the model. It should also know how to handle the inputs passed as the tuple. For example, in LSTM, if the user passes (activation, hidden), function should correctly use the first input as activation and the second input as hidden
- preserve_rng_state (bool, optional) – Whether to stash and restore the RNG state during each checkpoint. Pass False to omit stashing and restoring. Note that under torch.compile, this flag doesn’t take effect and we always preserve RNG state. Default: True
- use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.
- context_fn (Callable, optional) – A callable returning a tuple of two context managers. The function and its recomputation will be run under the first and second context managers respectively. This argument is only supported if use_reentrant=False.
- determinism_check (str, optional) – A string specifying the determinism check to perform. By default it is set to "default", which compares the shapes, dtypes, and devices of the recomputed tensors against those of the saved tensors. To turn off this check, specify "none". Currently these are the only two supported values. Please open an issue if you would like to see more determinism checks. This argument is only supported if use_reentrant=False; if use_reentrant=True, the determinism check is always disabled.
- debug (bool, optional) – If True, error messages will also include a trace of the operators run during the original forward computation as well as the recomputation. This argument is only supported if use_reentrant=False.
- args – tuple containing inputs to the function
 
- Returns
- Output of running function on *args
 
- torch.utils.checkpoint.checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs)
- Checkpoint a sequential model to save memory.
Sequential models execute a list of modules/functions in order (sequentially). Therefore, we can divide such a model into various segments and checkpoint each segment. All segments except the last will not store the intermediate activations. The inputs of each checkpointed segment will be saved for re-running the segment in the backward pass.
Warning
The use_reentrant parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If you are using the use_reentrant=True variant, please see checkpoint() for the important considerations and limitations of this variant. It is recommended that you use use_reentrant=False.
- Parameters
- functions – A torch.nn.Sequential or the list of modules or functions (comprising the model) to run sequentially.
- segments – Number of chunks to create in the model
- input – A Tensor that is input to functions
- preserve_rng_state (bool, optional) – Whether to stash and restore the RNG state during each checkpoint. Pass False to omit stashing and restoring. Default: True
- use_reentrant (bool) – specify whether to use the activation checkpoint variant that requires reentrant autograd. This parameter should be passed explicitly. In version 2.4 we will raise an exception if use_reentrant is not passed. If use_reentrant=False, checkpoint will use an implementation that does not require reentrant autograd. This allows checkpoint to support additional functionality, such as working as expected with torch.autograd.grad and support for keyword arguments input into the checkpointed function.
 
- Returns
- Output of running functions sequentially on *inputs
- Example
>>> model = nn.Sequential(...)
>>> input_var = checkpoint_sequential(model, chunks, input_var)
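The example above elides the model definition. A fuller sketch, assuming a small stack of linear layers and two segments (both choices are illustrative), compares the checkpointed run against an ordinary forward and backward pass:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 8), nn.ReLU(),
    nn.Linear(8, 2),
)
x = torch.randn(3, 8, requires_grad=True)

# Reference: plain forward/backward that stores all activations.
ref_out = model(x)
ref_out.sum().backward()
ref_grads = [p.grad.clone() for p in model.parameters()]
model.zero_grad()
x.grad = None

# Checkpointed: the model is split into 2 segments; activations inside
# every segment except the last are recomputed during backward.
out = checkpoint_sequential(model, 2, x, use_reentrant=False)
out.sum().backward()
ckpt_grads = [p.grad.clone() for p in model.parameters()]

print(all(torch.allclose(a, b) for a, b in zip(ref_grads, ckpt_grads)))
```

The memory saving only becomes noticeable for models much larger than this one; the sketch is meant to show the call pattern and that the gradients agree.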
- torch.utils.checkpoint.set_checkpoint_debug_enabled(enabled)
- Context manager that sets whether checkpoint should print additional debug information when running. See the debug flag for checkpoint() for more information. Note that when set, this context manager overrides the value of debug passed to checkpoint. To defer to the local setting, pass None to this context.
- Parameters
- enabled (bool) – Whether checkpoint should print debug information. Default is ‘None’.