DebugMode: Recording Dispatched Operations and Numerical Debugging#

Authors: Pian Pawakapan, Shangdi Yu

What you will learn
  • How to capture dispatched ops for eager and torch.compile runs

  • How to use tensor hashes and stack traces in DebugMode to pinpoint numerical divergence

Prerequisites
  • PyTorch 2.10 or later

Overview#

DebugMode (torch.utils._debug_mode.DebugMode) is a TorchDispatchMode that intercepts PyTorch runtime calls and emits a hierarchical log of operations. It is particularly useful when you need to understand what actually runs, in both eager mode and under torch.compile, or when you need to pinpoint numerical divergence between two runs.

Key capabilities:

  • Runtime logging – Records dispatched operations and TorchInductor compiled Triton kernels.

  • Tensor hashing – Attaches deterministic hashes to inputs/outputs so that two runs can be diffed to locate numerical divergence.

  • Dispatch hooks – Allows registering custom hooks to annotate each call.

Note

This recipe describes a prototype feature. Prototype features are typically at an early stage for feedback and testing and are subject to change.

Quick start#

The snippet below captures a small eager workload and prints the debug string:

import torch
from torch.utils._debug_mode import DebugMode

def run_once():
    x = torch.randn(8, 8)
    y = torch.randn(8, 8)
    return torch.mm(torch.relu(x), y)

with DebugMode() as debug_mode:
    out = run_once()

print("DebugMode output:")
print(debug_mode.debug_string())
DebugMode output:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]

Getting more metadata#

For most investigations, you’ll want to enable stack traces, tensor IDs, and tensor hashing. These features provide metadata to correlate operations back to model code.

DebugMode.log_tensor_hashes decorates the log with hashes for every call. Two hash functions are available: hash_tensor, which uses torch.hash_tensor (and therefore returns 0 for any tensor whose elements are all identical), and norm, which uses the L1 norm (p=1). With both functions, and especially with norm, tensors that are numerically close produce hashes that are close, which makes the hashes easy to interpret. The default hash_fn is norm.
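
As a rough illustration of why the norm hash is interpretable, it can be thought of as an L1 norm computed in float64 (visible as the aten::_to_copy and aten::linalg_vector_norm(..., 1) calls in the Triton log later in this recipe). The sketch below is an approximation of that idea, not DebugMode's exact implementation:

import torch

x = torch.randn(8, 8)

# Approximate "norm"-style hash: upcast to float64, then take the L1 norm.
norm_hash = torch.linalg.vector_norm(x.to(torch.float64), ord=1).item()

# A small perturbation yields a nearby hash, which is what makes hash
# diffs between two runs interpretable.
perturbed_hash = torch.linalg.vector_norm((x + 1e-6).to(torch.float64), ord=1).item()
print(norm_hash, perturbed_hash)

With that intuition, the following run enables stack traces, tensor IDs, and input hashing: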

with (
    DebugMode(
        # record_stack_trace is only supported in eager mode in PyTorch 2.10
        record_stack_trace=True,
        record_ids=True,
    ) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_fn=["norm"], # this is the default
        hash_inputs=True,
    ),
):
    result = run_once()

print("DebugMode output with more metadata:")
print(
    debug_mode.debug_string(show_stack_trace=True)
)
DebugMode output with more metadata:
    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:59 in run_once, code: x = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$0: f32[8, 8]  # {'hash': (53.19692662358284,)}

    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:60 in run_once, code: y = torch.randn(8, 8)
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t$1: f32[8, 8]  # {'hash': (49.16436389833689,)}

    # File: /var/lib/workspace/recipes_source/debug_mode_tutorial.py:61 in run_once, code: return torch.mm(torch.relu(x), y)
    aten::relu(t$0: f32[8, 8])  ->  t$2: f32[8, 8]  # {'input_hash': (((53.19692662358284,),), {}), 'hash': (28.665749847888947,)}
    aten::mm(t$2: f32[8, 8], t$1: f32[8, 8])  ->  t$3: f32[8, 8]  # {'input_hash': (((28.665749847888947,), (49.16436389833689,)), {}), 'hash': (80.61730536818504,)}

Each line follows the format op(args) -> outputs. When record_ids is enabled, tensor names are suffixed with $<id>, and DTensors are labeled dt.

Log Triton kernels#

Although Triton kernels are not dispatched through the PyTorch dispatcher, DebugMode has custom logic that logs their inputs and outputs.

Inductor-generated Triton kernels show up with a [triton] prefix. Pre/post hash annotations report buffer hashes around each kernel call, which is helpful when isolating incorrect kernels.

def f(x):
    return torch.mm(torch.relu(x), x.T)

x = torch.randn(3, 3, device="cuda")

with (
    DebugMode(record_output=True) as debug_mode,
    DebugMode.log_tensor_hashes(
        hash_inputs=True,
    )
):
    a = torch.compile(f)(x)

print("Triton in DebugMode logs:")
print(debug_mode.debug_string())
Triton in DebugMode logs:
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((6.635273367166519,), {'dtype': None}), 'hash': 6.635273367166519}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((6.635273367166519, None), {}), 'hash': 6.635273367166519}
    aten::_local_scalar_dense(t: f64[])  ->  6.635273367166519  # {'input_hash': ((6.635273367166519,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((7.733560990542173,), {'dtype': None}), 'hash': 7.733560990542173}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((7.733560990542173, None), {}), 'hash': 7.733560990542173}
    aten::_local_scalar_dense(t: f64[])  ->  7.733560990542173  # {'input_hash': ((7.733560990542173,), {}), 'hash': None}
    [triton] triton_poi_fused_relu_0(in_ptr0=t: f32[3, 3], out_ptr0=t: f32[3, 3], xnumel=9)
    # pre-kernel hashes: {in_ptr0: 6.635273367166519, out_ptr0: 7.733560990542173}
    # post-kernel hashes: {in_ptr0: 6.635273367166519, out_ptr0: 3.1039196252822876}

    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((6.635273367166519,), {'dtype': None}), 'hash': 6.635273367166519}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((6.635273367166519, None), {}), 'hash': 6.635273367166519}
    aten::_local_scalar_dense(t: f64[])  ->  6.635273367166519  # {'input_hash': ((6.635273367166519,), {}), 'hash': None}
    aten::_to_copy(t: f32[3, 3], dtype=torch.float64)  ->  t: f64[3, 3]  # {'input_hash': ((3.1039196252822876,), {'dtype': None}), 'hash': 3.1039196252822876}
    aten::linalg_vector_norm(t: f64[3, 3], 1)  ->  t: f64[]  # {'input_hash': ((3.1039196252822876, None), {}), 'hash': 3.1039196252822876}
    aten::_local_scalar_dense(t: f64[])  ->  3.1039196252822876  # {'input_hash': ((3.1039196252822876,), {}), 'hash': None}
    aten::mm.out(t: f32[3, 3], t: f32[3, 3], out=t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((3.1039196252822876, 6.635273367166519), {'out': 3.6893488147419103e+19}), 'hash': 6.004219740629196}

Numerical debugging with tensor hashes#

If two runs diverge numerically, you can use DebugMode to find where the divergence originates. In the example below, every tensor hash matches between eager mode and compiled mode. If any hash differed, that operation would be the source of the divergence.

def run_model(model, data, *, compile_with=None):
    if compile_with is not None:
        model = torch.compile(model, backend=compile_with)
    with DebugMode(record_output=True) as dm, DebugMode.log_tensor_hashes(
        hash_inputs=True,
    ):
        dm_out = model(*data)
    return dm, dm_out

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x).mm(x.T)

inputs = (torch.randn(4, 4),)
dm_eager, _ = run_model(Toy(), inputs)
dm_compiled, _ = run_model(Toy(), inputs, compile_with="aot_eager")

print("Eager mode:")
print(dm_eager.debug_string())
print("Compiled aot_eager mode:")
print(dm_compiled.debug_string())
Eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((14.7700479850173,), {}), 'hash': 11.188141971826553}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((14.7700479850173, [None, None]), {}), 'hash': 14.7700479850173}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((11.188141971826553, 14.7700479850173), {}), 'hash': 38.937172651290894}
Compiled aot_eager mode:
    aten::relu(t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((14.7700479850173,), {}), 'hash': 11.188141971826553}
    aten::permute(t: f32[4, 4], [1, 0])  ->  t: f32[4, 4]  # {'input_hash': ((14.7700479850173, [None, None]), {}), 'hash': 14.7700479850173}
    aten::mm(t: f32[4, 4], t: f32[4, 4])  ->  t: f32[4, 4]  # {'input_hash': ((11.188141971826553, 14.7700479850173), {}), 'hash': 38.937172651290894}

Now let’s look at an example where the tensor hashes differ. Below, we intentionally register a wrong decomposition that rewrites cosine as sine, which causes numerical divergence.

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.debugging import get_nop_func

def wrong_decomp(x):
    return torch.sin(x)

decomp_table = {}
decomp_table[torch.ops.aten.cos.default] = wrong_decomp

backend = aot_autograd(
    fw_compiler=get_nop_func(),
    bw_compiler=get_nop_func(),
    decompositions=decomp_table
)

def f(x):
    y = x.relu()
    z = torch.cos(x)
    return y + z

x = torch.randn(3, 3)
with DebugMode(record_output=True) as dm_eager, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    f(x)

with DebugMode(record_output=True) as dm_compiled, DebugMode.log_tensor_hashes(
    hash_inputs=True,
):
    torch.compile(f, backend=backend)(x)

print("Eager:")
print(dm_eager.debug_string(show_stack_trace=True))
print()
print("Compiled with wrong decomposition:")
print(dm_compiled.debug_string())
Eager:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.542718932032585,), {}), 'hash': 2.4600332379341125}
    aten::cos(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.542718932032585,), {}), 'hash': 6.761354386806488}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.4600332379341125, 6.761354386806488), {}), 'hash': 9.221387565135956}

Compiled with wrong decomposition:
    aten::relu(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.542718932032585,), {}), 'hash': 2.4600332379341125}
    aten::sin(t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((5.542718932032585,), {}), 'hash': 4.790997065603733}
    aten::add.Tensor(t: f32[3, 3], t: f32[3, 3])  ->  t: f32[3, 3]  # {'input_hash': ((2.4600332379341125, 4.790997065603733), {}), 'hash': 7.251030325889587}

In the eager log we have aten::cos, but in the compiled log we have aten::sin, and the output hashes differ from that point onward. Diffing the two logs shows that the first numerical divergence appears at the aten::cos call.
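
One way to locate the first divergence mechanically is to diff the two debug strings with Python's difflib; a minimal sketch, reusing dm_eager and dm_compiled from above:

import difflib

eager_lines = dm_eager.debug_string().splitlines()
compiled_lines = dm_compiled.debug_string().splitlines()

# Report the first differing log line; with hashing enabled, this is
# where the numerics start to diverge.
for line in difflib.unified_diff(eager_lines, compiled_lines, lineterm=""):
    if line.startswith(("-", "+")) and not line.startswith(("---", "+++")):
        print("first divergence:", line)
        break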

Custom dispatch hooks#

Hooks allow you to annotate each call with custom metadata such as GPU memory usage. The log_hook callback returns a mapping that is rendered inline in the debug string.

MB = 1024 * 1024.0

def memory_hook(func, types, args, kwargs, result):
    mem = torch.cuda.memory_allocated() / MB if torch.cuda.is_available() else 0.0
    peak = torch.cuda.max_memory_allocated() / MB if torch.cuda.is_available() else 0.0
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    return {"mem": f"{mem:.3f} MB", "peak": f"{peak:.3f} MB"}

with (
    DebugMode() as dm,
    DebugMode.dispatch_hooks(log_hook=memory_hook),
):
    run_once()

print("DebugMode output with memory usage:")
print(dm.debug_string())
DebugMode output with memory usage:
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '14.128 MB'}
    aten::randn([8, 8], device=cpu, pin_memory=False)  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::relu(t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}
    aten::mm(t: f32[8, 8], t: f32[8, 8])  ->  t: f32[8, 8]  # {'mem': '8.125 MB', 'peak': '8.125 MB'}

Module boundaries#

record_nn_module=True inserts [nn.Mod] markers that show which module executed each set of operations. As of PyTorch 2.10, this only works in eager mode; support for compiled modes is under development.

class Foo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(4, 4)
        self.l2 = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.l2(self.l1(x))

class Bar(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.abc = Foo()
        self.xyz = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.xyz(self.abc(x))

mod = Bar()
inp = torch.randn(4, 4)
with DebugMode(record_nn_module=True, record_output=False) as debug_mode:
    _ = mod(inp)

print("DebugMode output with stack traces and module boundaries:")
print(debug_mode.debug_string(show_stack_trace=True))
DebugMode output with stack traces and module boundaries:
  [nn.Mod] Bar
    [nn.Mod] Bar.abc
      [nn.Mod] Bar.abc.l1
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
      [nn.Mod] Bar.abc.l2
          aten::t(t: f32[4, 4])
          aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])
    [nn.Mod] Bar.xyz
        aten::t(t: f32[4, 4])
        aten::addmm(t: f32[4], t: f32[4, 4], t: f32[4, 4])

Conclusion#

In this tutorial, we saw how DebugMode gives you a lightweight, runtime-only view of what PyTorch actually executed, whether you are running eager code or compiled graphs. By layering tensor hashing, Triton kernel logging, and custom dispatch hooks, you can quickly track down numerical differences, which is especially helpful when checking runs for bit-wise equivalence.
