
Automatic Mixed Precision package - torch.cuda.amp

torch.cuda.amp provides convenience methods for running networks with mixed precision, where some operations use the torch.float32 (float) datatype and other operations use torch.float16 (half). Some operations, like linear layers and convolutions, are much faster in float16. Other operations, like reductions, often require the dynamic range of float32. Networks running in mixed precision try to match each operation to its appropriate datatype.
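To make the range argument concrete, the sketch below emulates float16 arithmetic in pure Python using the standard library's struct "e" (IEEE half-precision) format. It illustrates the numerics only and is not how PyTorch executes reductions; the helpers to_half and sum_in_half are hypothetical. A float16 accumulator overflows to inf on a sum that a wider accumulator handles easily, and it also stops making progress once the running total outgrows float16's precision.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest float16 value (hypothetical helper, emulated with struct)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # struct refuses finite values beyond float16's range; fp16 hardware would produce inf.
        return math.copysign(math.inf, x)


def sum_in_half(values):
    """Sum values with a float16 accumulator, rounding after every addition."""
    total = 0.0
    for v in values:
        total = to_half(total + to_half(v))
    return total


# Range: the true sum (100000) exceeds float16's largest finite value (65504).
big = [1000.0] * 100
half_mean = sum_in_half(big) / len(big)   # inf
full_mean = math.fsum(big) / len(big)     # 1000.0

# Precision: above 2048 the float16 spacing is 2, so adding 1.0 rounds back to the same value.
ones = [1.0] * 4096
half_total = sum_in_half(ones)            # 2048.0
full_total = math.fsum(ones)              # 4096.0
```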

Warning

torch.cuda.amp.GradScaler is not a complete implementation of automatic mixed precision. GradScaler is only useful if you manually run regions of your model in float16. If you aren’t sure how to choose op precision manually, the master branch and nightly pip/conda builds include a context manager that chooses op precision automatically wherever it’s enabled. See the master documentation for details.

Gradient Scaling

When training a network with mixed precision, if the forward pass for a particular op has torch.float16 inputs, the backward pass for that op will produce torch.float16 gradients. Gradient values with small magnitudes may not be representable in torch.float16. These values will flush to zero (“underflow”), so the update for the corresponding parameters will be lost.

To prevent underflow, “gradient scaling” multiplies the network’s loss(es) by a scale factor and invokes a backward pass on the scaled loss(es). Gradients flowing backward through the network are then scaled by the same factor. In other words, gradient values have a larger magnitude, so they don’t flush to zero.

The parameters’ gradients (.grad attributes) should be unscaled before the optimizer uses them to update the parameters, so the scale factor does not interfere with the learning rate.
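The effect of scaling on a single small gradient can be sketched in pure Python, again emulating float16 rounding with the standard library's struct "e" format. The to_half helper is hypothetical and stands in for a float16 cast; Python floats (double precision) stand in for the higher-precision gradients the optimizer sees. A gradient of 1e-8 flushes to zero in float16, but after multiplying by GradScaler's default initial scale of 2**16 it becomes representable, and dividing by the same factor recovers it to within float16 rounding error.

```python
import math
import struct


def to_half(x: float) -> float:
    """Round x to the nearest float16 value (hypothetical helper, emulated with struct)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        # Finite values beyond float16's range would become inf on real fp16 hardware.
        return math.copysign(math.inf, x)


SCALE = 2.0 ** 16  # same value as GradScaler's default init_scale

true_grad = 1e-8

# Without scaling: the float16 gradient underflows to zero and the update is lost.
unscaled_fp16 = to_half(true_grad)

# With scaling: the scaled gradient (about 6.55e-4) is a normal float16 number.
scaled_fp16 = to_half(true_grad * SCALE)

# Unscale in higher precision before the optimizer uses the gradient.
recovered = scaled_fp16 / SCALE
```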

class torch.cuda.amp.GradScaler(init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000, enabled=True)

An instance scaler of GradScaler helps perform the steps of gradient scaling conveniently.

  • scaler.scale(loss) multiplies a given loss by scaler’s current scale factor.

  • scaler.step(optimizer) safely unscales gradients and calls optimizer.step().

  • scaler.update() updates scaler’s scale factor.

Typical use:

from torch.cuda.amp import GradScaler

# Creates a GradScaler once at the beginning of training.
scaler = GradScaler()

for epoch in epochs:
    for input, target in data:
        optimizer.zero_grad()
        output = model(input)
        loss = loss_fn(output, target)

        # Scales the loss, and calls backward() on the scaled loss to create scaled gradients.
        scaler.scale(loss).backward()

        # scaler.step() first unscales the gradients of the optimizer's assigned params.
        # If these gradients do not contain infs or NaNs, optimizer.step() is then called.
        # Otherwise, optimizer.step() is skipped.
        scaler.step(optimizer)

        # Updates the scale for next iteration.
        scaler.update()

See the Gradient Scaling Examples for usage in more complex cases like gradient clipping, gradient penalty, and multiple losses/optimizers.

scaler dynamically estimates the scale factor each iteration. To minimize gradient underflow, a large scale factor should be used. However, torch.float16 values can “overflow” (become inf or NaN) if the scale factor is too large. Therefore, the optimal scale factor is the largest factor that can be used without incurring inf or NaN gradient values. scaler approximates the optimal scale factor over time by checking the gradients for infs and NaNs during every scaler.step(optimizer) (or during an optional, separate scaler.unscale_(optimizer) call; see unscale_()).

  • If infs/NaNs are found, scaler.step(optimizer) skips the underlying optimizer.step() (so the params themselves remain uncorrupted) and update() multiplies the scale by backoff_factor.

  • If no infs/NaNs are found, scaler.step(optimizer) runs the underlying optimizer.step() as usual. If growth_interval unskipped iterations occur consecutively, update() multiplies the scale by growth_factor.

The scale factor often causes infs/NaNs to appear in gradients for the first few iterations as its value calibrates. scaler.step will skip the underlying optimizer.step() for these iterations. After that, step skipping should occur rarely (once every few hundred or thousand iterations).

Parameters
  • init_scale (float, optional, default=2.**16) – Initial scale factor.

  • growth_factor (float, optional, default=2.0) – Factor by which the scale is multiplied during update() if no inf/NaN gradients occur for growth_interval consecutive iterations.

  • backoff_factor (float, optional, default=0.5) – Factor by which the scale is multiplied during update() if inf/NaN gradients occur in an iteration.

  • growth_interval (int, optional, default=2000) – Number of consecutive iterations without inf/NaN gradients that must occur for the scale to be multiplied by growth_factor.

  • enabled (bool, optional, default=True) – If False, disables gradient scaling. step() simply invokes the underlying optimizer.step(), and other methods become no-ops.
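The scale-update rules described above can be modeled with a small pure-Python sketch. ToyScaler and simulate are hypothetical names, not part of torch.cuda.amp, and the enabled flag is not modeled. The sketch assumes every gradient has the same magnitude, treats a scaled gradient above 65504 (the largest finite float16 value) as an overflow, and assumes the consecutive-iteration counter resets after each growth. Starting from the default scale of 2**16, a gradient of magnitude 10 overflows for four iterations while the scale backs off to 4096, matching the calibration behavior described above.

```python
FP16_MAX = 65504.0  # largest finite float16 value


class ToyScaler:
    """Hypothetical pure-Python model of the documented GradScaler update rules."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._growth_tracker = 0  # consecutive unskipped iterations so far

    def update(self, found_inf):
        if found_inf:
            # Overflow this iteration: back off and restart the streak.
            self.scale *= self.backoff_factor
            self._growth_tracker = 0
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                # growth_interval clean iterations in a row: grow (streak reset is assumed).
                self.scale *= self.growth_factor
                self._growth_tracker = 0


def simulate(grad_magnitude, steps, **scaler_kwargs):
    """Run `steps` iterations with a constant gradient size; return (skipped flags, final scale)."""
    scaler = ToyScaler(**scaler_kwargs)
    skipped = []
    for _ in range(steps):
        # Emulated float16 overflow: the scaled gradient exceeds the largest finite fp16 value.
        found_inf = grad_magnitude * scaler.scale > FP16_MAX
        skipped.append(found_inf)  # step() would skip optimizer.step() here
        scaler.update(found_inf)
    return skipped, scaler.scale
```

For example, simulate(10.0, 6) reports four skipped steps followed by two successful ones and ends at a scale of 4096.0. With growth_interval=2, the scale grows back to 8192 after two clean steps, overflows once, and backs off again, which is the occasional skipping pattern expected for the rest of training.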

get_backoff_factor()

Returns a Python float containing the scale backoff factor.

get_growth_factor()

Returns a Python float containing the scale growth factor.

get_growth_interval()

Returns a Python int containing the growth interval.

get_scale()

Returns a Python float containing the current scale, or 1.0 if scaling is disabled.

Warning

get_scale() incurs a CPU-GPU sync.

is_enabled()

Returns a bool indicating whether this instance is enabled.

load_state_dict(state_dict)

Loads the scaler state. If this instance is disabled, load_state_dict() is a no-op.

Parameters

state_dict (dict) – scaler state. Should be an object returned from a call to state_dict().

scale(outputs)

Multiplies (‘scales’) a tensor or list of tensors by the scale factor.

Returns scaled outputs. If this instance of GradScaler is not enabled, outputs are returned unmodified.

Parameters

outputs (Tensor or iterable of Tensors) – Outputs to scale.

set_backoff_factor(new_factor)

Parameters

new_factor (float) – Value to use as the new scale backoff factor.

set_growth_factor(new_factor)

Parameters

new_factor (float) – Value to use as the new scale growth factor.

set_growth_interval(new_interval)

Parameters

new_interval (int) – Value to use as the new growth interval.

state_dict()

Returns the state of the scaler as a dict. It contains five entries:

  • "scale" - a Python float containing the current scale

  • "growth_factor" - a Python float containing the current growth factor

  • "backoff_factor" - a Python float containing the current backoff factor

  • "growth_interval" - a Python int containing the current growth interval

  • "_growth_tracker" - a Python int containing the number of recent consecutive unskipped steps.

If this instance is not enabled, returns an empty dict.

Note

If you wish to checkpoint the scaler’s state after a particular iteration, state_dict() should be called after update().
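Because every entry in the documented state dict is a plain Python float or int, the scaler state can be stored alongside other checkpoint data. The sketch below uses hypothetical values and a hypothetical JSON file layout (save_checkpoint and load_checkpoint are illustrative names) purely to show the round trip; in PyTorch code one would more typically put scaler.state_dict() into the dict passed to torch.save and later restore it with scaler.load_state_dict().

```python
import json
import os
import tempfile

# Hypothetical values; the keys mirror the five documented state_dict() entries,
# captured after update() so "scale" reflects this iteration's adjustment.
scaler_state = {
    "scale": 32768.0,
    "growth_factor": 2.0,
    "backoff_factor": 0.5,
    "growth_interval": 2000,
    "_growth_tracker": 137,
}


def save_checkpoint(path, iteration, scaler_state):
    """Write the iteration counter and scaler state to a JSON file."""
    with open(path, "w") as f:
        json.dump({"iteration": iteration, "scaler": scaler_state}, f)


def load_checkpoint(path):
    """Read back (iteration, scaler_state) from a file written by save_checkpoint."""
    with open(path) as f:
        ckpt = json.load(f)
    return ckpt["iteration"], ckpt["scaler"]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "scaler_ckpt.json")
    save_checkpoint(ckpt_path, 5000, scaler_state)
    restored_iteration, restored_state = load_checkpoint(ckpt_path)
```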

step(optimizer, *args, **kwargs)

step() carries out the following two operations:

  1. Internally invokes unscale_(optimizer) (unless unscale_() was explicitly called for optimizer earlier in the iteration). As part of the unscale_(), gradients are checked for infs/NaNs.

  2. If no inf/NaN gradients are found, invokes optimizer.step() using the unscaled gradients. Otherwise, optimizer.step() is skipped to avoid corrupting the params.

*args and **kwargs are forwarded to optimizer.step().

Returns the return value of optimizer.step(*args, **kwargs).

Parameters
  • optimizer (torch.optim.Optimizer) – Optimizer that applies the gradients.

  • args – Any arguments.

  • kwargs – Any keyword arguments.

Warning

Closure use is not currently supported.
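The two operations step() performs can be modeled on a flat list of Python floats. toy_step and its apply_update argument are hypothetical names for illustration; the sketch mirrors only the documented control flow (unscale unless already unscaled, check for inf/NaN, then either run the update or skip it) and does not reproduce how PyTorch handles tensors, devices, or optimizers.

```python
import math


def toy_step(scaled_grads, scale, apply_update, already_unscaled=False):
    """Hypothetical model of GradScaler.step() on a list of float gradients.

    Returns (grads, found_inf, result); result is None when the update is skipped.
    """
    # 1. Unscale, unless an explicit unscale_() already did it this iteration.
    grads = list(scaled_grads) if already_unscaled else [g / scale for g in scaled_grads]
    # Check for inf/NaN (in the real API this check happens as part of unscale_()).
    found_inf = not all(math.isfinite(g) for g in grads)
    # 2. Skip the update if any gradient is inf/NaN, so the parameters stay uncorrupted.
    if found_inf:
        return grads, True, None
    # Otherwise run the update on the unscaled gradients and return its result.
    return grads, False, apply_update(grads)
```

Passing the builtin sum as apply_update stands in for optimizer.step(); any callable that accepts the gradient list works.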

unscale_(optimizer)

Divides (“unscales”) the optimizer’s gradient tensors by the scale factor.

unscale_() is optional, serving cases where you need to modify or inspect gradients between the backward pass(es) and step(). If unscale_() is not called explicitly, gradients will be unscaled automatically during step().

Simple example, using unscale_() to enable clipping of unscaled gradients:

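...
scaler.scale(loss).backward()
# Unscale this optimizer's gradients in place so clipping sees their true magnitudes.
scaler.unscale_(optimizer)
# Clip the now-unscaled gradients; max_norm is the clipping threshold you choose.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
# step() does not unscale again, but still skips optimizer.step() on inf/NaN gradients.
scaler.step(optimizer)
scaler.update()

Parameters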

optimizer (torch.optim.Optimizer) – Optimizer that owns the gradients to be unscaled.

Note

unscale_() does not incur a CPU-GPU sync.

Warning

unscale_() should only be called once per optimizer per step() call, and only after all gradients for that optimizer’s assigned parameters have been accumulated. Calling unscale_() twice for a given optimizer between consecutive step() calls triggers a RuntimeError.
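The reason to call unscale_() before clipping can be shown with plain numbers. clip_by_norm below is a hypothetical, simplified stand-in for torch.nn.utils.clip_grad_norm_: it rescales a list of floats so their L2 norm is at most max_norm, without the small epsilon PyTorch adds. Clipping the still-scaled gradients compares a norm inflated by the scale factor against max_norm, so gradients that needed no clipping end up shrunk by roughly the scale factor.

```python
import math

SCALE = 2.0 ** 16  # GradScaler's default init_scale


def clip_by_norm(grads, max_norm):
    """Scale grads down so their L2 norm is at most max_norm (simplified sketch)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        coef = max_norm / total_norm
        return [g * coef for g in grads]
    return list(grads)


true_grads = [0.3, 0.4]                        # L2 norm 0.5, below max_norm
scaled_grads = [g * SCALE for g in true_grads]
max_norm = 1.0

# Correct order: unscale first, then clip. Nothing needs clipping here.
right = clip_by_norm([g / SCALE for g in scaled_grads], max_norm)

# Wrong order: clip the scaled grads (norm about 32768), then unscale.
wrong = [g / SCALE for g in clip_by_norm(scaled_grads, max_norm)]
```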

update(new_scale=None)

Updates the scale factor.

If any optimizer steps were skipped, the scale is multiplied by backoff_factor to reduce it. If growth_interval unskipped iterations occurred consecutively, the scale is multiplied by growth_factor to increase it.

Passing new_scale sets the scale directly.

Parameters

new_scale (float or torch.cuda.FloatTensor, optional, default=None) – New scale factor.

Warning

update() should only be called at the end of the iteration, after scaler.step(optimizer) has been invoked for all optimizers used this iteration.
