
torch.testing

Warning

This module is in a PROTOTYPE state. New functions are still being added, and the available functions may change in future PyTorch releases. We are actively looking for feedback on UI/UX improvements or missing functionality.

torch.testing.assert_close(actual, expected, *, rtol=None, atol=None, equal_nan=False, check_device=True, check_dtype=True, check_stride=True, msg=None)

Asserts that actual and expected are close.

If actual and expected are real-valued and finite, they are considered close if

$$\lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert$$

and they have the same device (if check_device is True), the same dtype (if check_dtype is True), and the same stride (if check_stride is True). Non-finite values (-inf and inf) are only considered close if they are equal. NaNs are only considered equal to each other if equal_nan is True.
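To make the rule above concrete, here is a minimal pure-Python sketch of how a single pair of real scalars could be judged close. The helper name `is_close_real` is hypothetical and not part of torch.testing; it applies the NaN rule first, then the rule for -inf/inf, and finally the tolerance inequality. It is an illustration of the stated definition, not PyTorch's implementation.

```python
import math


def is_close_real(actual: float, expected: float, *, rtol: float, atol: float,
                  equal_nan: bool = False) -> bool:
    """Hypothetical helper applying the closeness rule to two real scalars."""
    # NaN: close only to another NaN, and only if equal_nan is True.
    if math.isnan(actual) or math.isnan(expected):
        return equal_nan and math.isnan(actual) and math.isnan(expected)
    # Non-finite values (-inf, inf): close only if exactly equal.
    if math.isinf(actual) or math.isinf(expected):
        return actual == expected
    # Finite values: |actual - expected| <= atol + rtol * |expected|
    return abs(actual - expected) <= atol + rtol * abs(expected)


# The tolerance scales with |expected| only, so argument order matters.
print(is_close_real(78.0, 100.0, rtol=0.25, atol=0.0))  # True  (22 <= 25)
print(is_close_real(100.0, 78.0, rtol=0.25, atol=0.0))  # False (22 > 19.5)
```

Because the right-hand side depends on |expected| alone, the check is not symmetric: it matters which value is passed as expected.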

If actual and expected are complex-valued, they are considered close if both their real and imaginary components are considered close according to the definition above.

actual and expected can be Tensors or any array-or-scalar-like of the same type from which torch.Tensors can be constructed with torch.as_tensor(). In addition, actual and expected can be Sequences or Mappings, in which case they are considered close if their structure matches and all their elements are considered close according to the above definition.
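The structural rule can be sketched as a recursive walk. The plain-Python sketch below makes several simplifying assumptions: the helper `structurally_close` is hypothetical, leaves are restricted to finite real numbers (no tensors or arrays), strings are not treated as sequences, a mismatch returns False instead of raising AssertionError, and the float64 defaults from the dtype table are used as default tolerances.

```python
from collections import OrderedDict
from collections.abc import Mapping, Sequence


def _is_sequence(obj) -> bool:
    # Strings and bytes are Sequences too, but recursing into them makes no sense here.
    return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))


def structurally_close(actual, expected, *, rtol: float = 1e-7,
                       atol: float = 1e-7) -> bool:
    """Hypothetical helper: compare nested sequences/mappings of real numbers."""
    if isinstance(actual, Mapping) or isinstance(expected, Mapping):
        if not (isinstance(actual, Mapping) and isinstance(expected, Mapping)):
            return False  # a mapping compared with a non-mapping
        # Only the key sets must match; mapping type and key order are irrelevant.
        if set(actual) != set(expected):
            return False
        return all(structurally_close(actual[key], expected[key], rtol=rtol, atol=atol)
                   for key in expected)
    if _is_sequence(actual) or _is_sequence(expected):
        if not (_is_sequence(actual) and _is_sequence(expected)):
            return False  # a sequence compared with a non-sequence
        # Only the lengths must match; e.g. a list may be compared with a tuple.
        if len(actual) != len(expected):
            return False
        return all(structurally_close(a, e, rtol=rtol, atol=atol)
                   for a, e in zip(actual, expected))
    # Leaves: finite real numbers, compared with the same inequality as above.
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Mirrors the mapping example further down, with plain floats as leaves.
expected = OrderedDict([("foo", 1.0), ("bar", [2.0, 3.0])])
actual = {"bar": (2.0, 3.0), "foo": 1.0}
print(structurally_close(actual, expected))  # True
```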

Parameters
  • actual (Any) – Actual input.

  • expected (Any) – Expected input.

  • rtol (Optional[float]) – Relative tolerance. If specified, atol must also be specified. If omitted, default values based on the dtype are selected according to the table below.

  • atol (Optional[float]) – Absolute tolerance. If specified, rtol must also be specified. If omitted, default values based on the dtype are selected according to the table below.

  • equal_nan (Union[bool, str]) – If True, two NaN values are considered equal. If "relaxed", complex values are considered NaN if either the real or the imaginary component is NaN.

  • check_device (bool) – If True (default), asserts that corresponding tensors are on the same device. If this check is disabled, tensors on different devices are moved to the CPU before being compared.

  • check_dtype (bool) – If True (default), asserts that corresponding tensors have the same dtype. If this check is disabled, tensors with different dtypes are promoted to a common dtype (according to torch.promote_types()) before being compared.

  • check_stride (bool) – If True (default), asserts that corresponding tensors have the same stride.

  • msg (Optional[Union[str, Callable[[Tensor, Tensor, DiagnosticInfo], str]]]) – Optional error message to use if the values of corresponding tensors mismatch. Can be passed as a callable, in which case it is called with the mismatching tensors and a namespace of diagnostic info about the mismatches. See below for details.
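The three equal_nan modes differ only in how NaNs are matched. The plain-Python sketch below models them for a single pair of complex numbers; the helper names `_component_close` and `complex_close` are hypothetical, and the componentwise treatment follows the description of complex closeness given earlier rather than PyTorch's actual implementation.

```python
import math
from typing import Union


def _component_close(a: float, e: float, rtol: float, atol: float,
                     equal_nan: bool) -> bool:
    # Real-valued rule: NaN first, then -inf/inf, then the tolerance inequality.
    if math.isnan(a) or math.isnan(e):
        return equal_nan and math.isnan(a) and math.isnan(e)
    if math.isinf(a) or math.isinf(e):
        return a == e
    return abs(a - e) <= atol + rtol * abs(e)


def complex_close(actual: complex, expected: complex, *, rtol: float = 0.0,
                  atol: float = 0.0, equal_nan: Union[bool, str] = False) -> bool:
    """Hypothetical helper modelling the equal_nan modes for one complex pair."""
    if equal_nan == "relaxed":
        # A complex value counts as NaN if *either* component is NaN ...
        actual_nan = math.isnan(actual.real) or math.isnan(actual.imag)
        expected_nan = math.isnan(expected.real) or math.isnan(expected.imag)
        if actual_nan or expected_nan:
            # ... and two such values are considered equal.
            return actual_nan and expected_nan
        equal_nan = True  # no NaNs left, so this choice does not matter
    # equal_nan=False/True: components are compared separately, so with True
    # the NaNs must sit in the same component to match.
    nan_ok = bool(equal_nan)
    return (_component_close(actual.real, expected.real, rtol, atol, nan_ok)
            and _component_close(actual.imag, expected.imag, rtol, atol, nan_ok))


# Same inputs as the complex NaN example further down.
actual, expected = complex(0.0, math.nan), complex(math.nan, 0.0)
print(complex_close(actual, expected, equal_nan=True))       # False
print(complex_close(actual, expected, equal_nan="relaxed"))  # True
```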

Raises
  • UsageError – If a torch.Tensor can’t be constructed from an array-or-scalar-like.

  • UsageError – If any tensor is quantized or sparse. This is a temporary restriction and will be relaxed in the future.

  • UsageError – If only rtol or atol is specified.

  • AssertionError – If corresponding array-likes have different types.

  • AssertionError – If the inputs are Sequence’s, but their length does not match.

  • AssertionError – If the inputs are Mappings, but their sets of keys do not match.

  • AssertionError – If corresponding tensors do not have the same shape.

  • AssertionError – If check_device is True, but corresponding tensors are not on the same device.

  • AssertionError – If check_dtype is True, but corresponding tensors do not have the same dtype.

  • AssertionError – If check_stride is True, but corresponding tensors do not have the same stride.

  • AssertionError – If the values of corresponding tensors are not close.

The following table displays the default rtol and atol for different dtypes. If actual and expected do not have the same dtype, the dtype refers to their promoted type.

dtype        rtol     atol
float16      1e-3     1e-5
bfloat16     1.6e-2   1e-5
float32      1.3e-6   1e-5
float64      1e-7     1e-7
complex32    1e-3     1e-5
complex64    1.3e-6   1e-5
complex128   1e-7     1e-7
other        0.0      0.0
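The tolerance selection can be sketched as a table lookup combined with a both-or-neither check. In this hypothetical `resolve_tolerances` helper, dtypes are identified by name strings (an assumption; the real function works with torch.dtype objects) and ValueError stands in for the UsageError that assert_close raises. The numbers are copied from the table above.

```python
from typing import Optional, Tuple

# Default (rtol, atol) pairs copied from the table, keyed by dtype name.
DEFAULT_TOLERANCES = {
    "float16": (1e-3, 1e-5),
    "bfloat16": (1.6e-2, 1e-5),
    "float32": (1.3e-6, 1e-5),
    "float64": (1e-7, 1e-7),
    "complex32": (1e-3, 1e-5),
    "complex64": (1.3e-6, 1e-5),
    "complex128": (1e-7, 1e-7),
}


def resolve_tolerances(dtype_name: str, rtol: Optional[float] = None,
                       atol: Optional[float] = None) -> Tuple[float, float]:
    """Hypothetical helper: pick (rtol, atol) the way the table describes."""
    # Specifying only one of the two is a usage error; ValueError is a stand-in.
    if (rtol is None) != (atol is None):
        raise ValueError("rtol and atol must both be specified or both omitted")
    if rtol is not None and atol is not None:
        return rtol, atol
    # Any dtype not listed (the "other" row, e.g. integers) gets exact comparison.
    return DEFAULT_TOLERANCES.get(dtype_name, (0.0, 0.0))


rtol, atol = resolve_tolerances("float32")
# Allowed deviation around expected = 1.0 with the float32 defaults:
print(atol + rtol * abs(1.0))  # roughly 1.13e-05
```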

The namespace of diagnostic information that is passed to msg if it is a callable has the following attributes:

  • number_of_elements (int): Number of elements in each tensor being compared.

  • total_mismatches (int): Total number of mismatches.

  • mismatch_ratio (float): Total mismatches divided by number of elements.

  • max_abs_diff (Union[int, float]): Greatest absolute difference of the inputs.

  • max_abs_diff_idx (Union[int, Tuple[int, …]]): Index of greatest absolute difference.

  • max_rel_diff (Union[int, float]): Greatest relative difference of the inputs.

  • max_rel_diff_idx (Union[int, Tuple[int, …]]): Index of greatest relative difference.

For max_abs_diff and max_rel_diff the type depends on the dtype of the inputs.
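To show how these attributes relate to one another, the sketch below computes them for two flat lists of finite real numbers. The function `diagnose` is hypothetical; the relative difference is taken as |actual - expected| / |expected|, the handling of expected == 0 is an assumption, and ties in a maximum resolve to the first index. For flat inputs the indices are plain ints rather than tuples.

```python
import math
from types import SimpleNamespace
from typing import List


def diagnose(actual: List[float], expected: List[float], *, rtol: float,
             atol: float) -> SimpleNamespace:
    """Hypothetical helper computing the documented diagnostic attributes."""
    abs_diffs = [abs(a - e) for a, e in zip(actual, expected)]
    # Relative difference |a - e| / |e|; assumed inf when e == 0 and the
    # values differ, and 0.0 when they are identical.
    rel_diffs = [
        d / abs(e) if e != 0 else (math.inf if d > 0 else 0.0)
        for d, e in zip(abs_diffs, expected)
    ]
    # An element mismatches if it violates |a - e| <= atol + rtol * |e|.
    mismatches = sum(d > atol + rtol * abs(e) for d, e in zip(abs_diffs, expected))
    # max() returns the first maximal index on ties.
    abs_idx = max(range(len(abs_diffs)), key=abs_diffs.__getitem__)
    rel_idx = max(range(len(rel_diffs)), key=rel_diffs.__getitem__)
    return SimpleNamespace(
        number_of_elements=len(expected),
        total_mismatches=mismatches,
        mismatch_ratio=mismatches / len(expected),
        max_abs_diff=abs_diffs[abs_idx],
        max_abs_diff_idx=abs_idx,
        max_rel_diff=rel_diffs[rel_idx],
        max_rel_diff_idx=rel_idx,
    )


# Same values as the custom_msg example below, with the float32 defaults.
info = diagnose([1.0, 4.0, 5.0], [1.0, 2.0, 3.0], rtol=1.3e-6, atol=1e-5)
print(f"{info.total_mismatches} mismatches, {info.mismatch_ratio:.1%}")  # 2 mismatches, 66.7%
```

These values agree with the message produced by custom_msg in the last example below.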

Examples

>>> # tensor to tensor comparison
>>> expected = torch.tensor([1e0, 1e-1, 1e-2])
>>> actual = torch.acos(torch.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # scalar to scalar comparison
>>> import math
>>> expected = math.sqrt(2.0)
>>> actual = 2.0 / math.sqrt(2.0)
>>> torch.testing.assert_close(actual, expected)
>>> # numpy array to numpy array comparison
>>> import numpy as np
>>> expected = np.array([1e0, 1e-1, 1e-2])
>>> actual = np.arccos(np.cos(expected))
>>> torch.testing.assert_close(actual, expected)
>>> # sequence to sequence comparison
>>> import numpy as np
>>> # The types of the sequences do not have to match. They only have to have the same
>>> # length and their elements have to match.
>>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
>>> actual = tuple(expected)
>>> torch.testing.assert_close(actual, expected)
>>> # mapping to mapping comparison
>>> from collections import OrderedDict
>>> import numpy as np
>>> foo = torch.tensor(1.0)
>>> bar = 2.0
>>> baz = np.array(3.0)
>>> # The types and a possible ordering of mappings do not have to match. They only
>>> # have to have the same set of keys and their elements have to match.
>>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
>>> actual = {"baz": baz, "bar": bar, "foo": foo}
>>> torch.testing.assert_close(actual, expected)
>>> # Different input types are never considered close.
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = expected.numpy()
>>> torch.testing.assert_close(actual, expected)
AssertionError: Except for scalars, type equality is required, but got
<class 'numpy.ndarray'> and <class 'torch.Tensor'> instead.
>>> # Scalars of different types are an exception and can be compared with
>>> # check_dtype=False.
>>> torch.testing.assert_close(1.0, 1, check_dtype=False)
>>> # NaN != NaN by default.
>>> expected = torch.tensor(float("NaN"))
>>> actual = expected.clone()
>>> torch.testing.assert_close(actual, expected)
AssertionError: Tensors are not close!
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
>>> # If equal_nan=True, the real and imaginary NaNs of complex inputs have to match.
>>> expected = torch.tensor(complex(float("NaN"), 0))
>>> actual = torch.tensor(complex(0, float("NaN")))
>>> torch.testing.assert_close(actual, expected, equal_nan=True)
AssertionError: Tensors are not close!
>>> # If equal_nan="relaxed", however, then complex numbers are treated as NaN if any
>>> # of the real or imaginary component is NaN.
>>> torch.testing.assert_close(actual, expected, equal_nan="relaxed")
>>> expected = torch.tensor([1.0, 2.0, 3.0])
>>> actual = torch.tensor([1.0, 4.0, 5.0])
>>> # The default mismatch message can be overridden.
>>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
AssertionError: Argh, the tensors are not close!
>>> # The error message can also be created at runtime by passing a callable.
>>> def custom_msg(actual, expected, diagnostic_info):
...     return (
...         f"Argh, we found {diagnostic_info.total_mismatches} mismatches! "
...         f"That is {diagnostic_info.mismatch_ratio:.1%}!"
...     )
>>> torch.testing.assert_close(actual, expected, msg=custom_msg)
AssertionError: Argh, we found 2 mismatches! That is 66.7%!
