torch.use_deterministic_algorithms#
- torch.use_deterministic_algorithms(mode, *, warn_only=False)[source]#
Sets whether PyTorch operations must use “deterministic” algorithms. That is, algorithms which, given the same input, and when run on the same software and hardware, always produce the same output. When enabled, operations will use deterministic algorithms when available, and if only nondeterministic algorithms are available they will throw a
RuntimeError when called.
Note
This setting alone is not always enough to make an application reproducible. Refer to Reproducibility for more information.
Note
torch.set_deterministic_debug_mode() offers an alternative interface for this feature.
The following normally-nondeterministic operations will act deterministically when mode=True (a brief usage sketch follows the list):
- torch.nn.Conv1d when called on CUDA tensor
- torch.nn.Conv2d when called on CUDA tensor
- torch.nn.Conv3d when called on CUDA tensor
- torch.nn.ConvTranspose1d when called on CUDA tensor
- torch.nn.ConvTranspose2d when called on CUDA tensor
- torch.nn.ConvTranspose3d when called on CUDA tensor
- torch.nn.ReplicationPad1d when attempting to differentiate a CUDA tensor
- torch.nn.ReplicationPad2d when attempting to differentiate a CUDA tensor
- torch.nn.ReplicationPad3d when attempting to differentiate a CUDA tensor
- torch.bmm() when called on sparse-dense CUDA tensors
- torch.Tensor.__getitem__() when attempting to differentiate a CPU tensor and the index is a list of tensors
- torch.Tensor.index_put() with accumulate=False
- torch.Tensor.index_put() with accumulate=True when called on a CPU tensor
- torch.Tensor.put_() with accumulate=True when called on a CPU tensor
- torch.Tensor.scatter_add_() when called on a CUDA tensor
- torch.gather() when called on a CUDA tensor that requires grad
- torch.index_add() when called on CUDA tensor
- torch.index_select() when attempting to differentiate a CUDA tensor
- torch.repeat_interleave() when attempting to differentiate a CUDA tensor
- torch.Tensor.index_copy() when called on a CPU or CUDA tensor
- torch.Tensor.scatter() when src type is Tensor and called on CUDA tensor
- torch.Tensor.scatter_reduce() when reduce='sum' or reduce='mean' and called on CUDA tensor
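For instance, a minimal sketch (assuming a CUDA device is available) using torch.index_add(), which falls back to a deterministic implementation when the flag is set, so repeated runs produce bitwise-identical results:

import torch

torch.use_deterministic_algorithms(True)

def run():
    # Scatter many values into a small tensor; on CUDA this normally relies on
    # atomic adds whose accumulation order can vary between runs.
    torch.manual_seed(0)
    out = torch.zeros(1000, device="cuda")
    index = torch.randint(0, 1000, (100000,), device="cuda")
    src = torch.randn(100000, device="cuda")
    return out.index_add(0, index, src)

# With deterministic algorithms enabled, repeated runs are expected to match exactly.
assert torch.equal(run(), run())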
The following normally-nondeterministic operations will throw a RuntimeError when mode=True (a short forward-mode example follows the list):
- torch.nn.AvgPool3d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveAvgPool2d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveAvgPool3d when attempting to differentiate a CUDA tensor
- torch.nn.MaxPool3d when attempting to differentiate a CUDA tensor
- torch.nn.AdaptiveMaxPool2d when attempting to differentiate a CUDA tensor
- torch.nn.FractionalMaxPool2d when attempting to differentiate a CUDA tensor
- torch.nn.FractionalMaxPool3d when attempting to differentiate a CUDA tensor
- torch.nn.functional.interpolate() when attempting to differentiate a CUDA tensor and one of the following modes is used: linear, bilinear, bicubic, trilinear
- torch.nn.ReflectionPad1d when attempting to differentiate a CUDA tensor
- torch.nn.ReflectionPad2d when attempting to differentiate a CUDA tensor
- torch.nn.ReflectionPad3d when attempting to differentiate a CUDA tensor
- torch.nn.NLLLoss when called on a CUDA tensor
- torch.nn.CTCLoss when attempting to differentiate a CUDA tensor
- torch.nn.EmbeddingBag when attempting to differentiate a CUDA tensor when mode='max'
- torch.Tensor.put_() when accumulate=False
- torch.Tensor.put_() when accumulate=True and called on a CUDA tensor
- torch.histc() when called on a CUDA tensor
- torch.bincount() when called on a CUDA tensor and weights tensor is given
- torch.median() with indices output when called on a CUDA tensor
- torch.nn.functional.grid_sample() when attempting to differentiate a CUDA tensor
- torch.cumsum() when called on a CUDA tensor when dtype is floating point or complex
- torch.Tensor.scatter_reduce() when reduce='prod' and called on CUDA tensor
- torch.Tensor.resize_() when called with a quantized tensor
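As a sketch of the forward-mode case (assuming a CUDA device; the exact error text may differ between releases), calling one of the operations above, such as torch.histc(), raises once the flag is set:

import torch

torch.use_deterministic_algorithms(True)

x = torch.randn(1000, device="cuda")
try:
    # histc on a CUDA tensor is listed above as having no deterministic
    # implementation, so it raises RuntimeError while the flag is enabled.
    torch.histc(x, bins=10)
except RuntimeError as err:
    print(err)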
In addition, several operations fill uninitialized memory when this setting is turned on and when torch.utils.deterministic.fill_uninitialized_memory is turned on. See the documentation for that attribute for more information.
Note that deterministic operations tend to have worse performance than nondeterministic operations.
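A minimal sketch of that interaction (fill_uninitialized_memory defaults to True, so usually only the deterministic flag needs to be set explicitly):

import torch

torch.use_deterministic_algorithms(True)
torch.utils.deterministic.fill_uninitialized_memory = True  # default is True

# With both switches on, ops such as torch.empty() fill new memory with a
# known value (NaN for floating point dtypes, the maximum value for integer
# dtypes) instead of leaving it uninitialized, trading speed for reproducibility.
print(torch.empty(3))                     # expected: all NaN
print(torch.empty(3, dtype=torch.int64))  # expected: all torch.iinfo(torch.int64).max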
When this setting is turned on, Inductor's deterministic mode is also turned on automatically. In deterministic mode, Inductor avoids on-device benchmarking that can affect numerics. This includes (a brief sketch follows the list):
- Do not pad matmul input shapes. Without deterministic mode, Inductor benchmarks to check whether padding matmul shapes is beneficial.
- Do not autotune templates. Inductor has templates for kernels such as matmul, conv, and attention. Without deterministic mode, Inductor autotunes to pick the best config for each template and adopts it if it is faster than the eager-mode kernel. In deterministic mode, the eager kernel is picked.
- Do not autotune Triton configs for reductions. Reduction numerics are very sensitive to Triton configs. In deterministic mode, Inductor uses heuristics to pick the most promising configs rather than autotuning.
- Skip autotuning for reductions in coordinate descent tuning.
- Do not benchmark for the computation/communication reordering pass.
- Disable the feature that dynamically scales down the RBLOCK Triton config for higher occupancy.
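An illustrative sketch (CPU tensors are used here for simplicity): the Inductor behavior above is internal and needs no configuration beyond the flag itself.

import torch

# Turning on deterministic algorithms also puts Inductor in deterministic
# mode, so compiled code skips the benchmarking/autotuning passes listed above.
torch.use_deterministic_algorithms(True)

@torch.compile
def f(x, w):
    return (x @ w).relu().sum(dim=-1)

x = torch.randn(64, 128)
w = torch.randn(128, 256)
# Repeated calls of the compiled function are expected to match exactly.
assert torch.equal(f(x, w), f(x, w))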
Note
This flag does not detect or prevent nondeterministic behavior caused by calling an inplace operation on a tensor with an internal memory overlap or by giving such a tensor as the
out argument for an operation. In these cases, multiple writes of different data may target a single memory location, and the order of writes is not guaranteed.
- Parameters
mode (bool) – If True, makes potentially nondeterministic operations switch to a deterministic algorithm or throw a runtime error. If False, allows nondeterministic operations.
- Keyword Arguments
warn_only (bool, optional) – If True, operations that do not have a deterministic implementation will throw a warning instead of an error. Default: False
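A minimal sketch of warn_only (assuming a CUDA device; the exact warning text may vary):

import torch
import warnings

# With warn_only=True, an op that lacks a deterministic implementation emits a
# warning instead of raising a RuntimeError.
torch.use_deterministic_algorithms(True, warn_only=True)

x = torch.randn(3, 4, 5, 6, requires_grad=True, device="cuda")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    torch.nn.AvgPool3d(1)(x).sum().backward()  # nondeterministic backward on CUDA
print(caught[0].message if caught else "no warning emitted")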
Example:
>>> torch.use_deterministic_algorithms(True)

# Backward mode nondeterministic error
>>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward()
...
RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation...