# Observer / Callback System
Torch-TensorRT exposes a lightweight instrumentation hook, the `Observer` / `@observable`
system, that lets you attach callbacks to internal compiler functions without modifying
source code. This is useful for:
- Logging intermediate graph states during development.
- Collecting timing or shape statistics across compilation stages.
- Writing integration tests that assert on internal state.
- Capturing conversion context snapshots for debugging.
The system lives in `py/torch_tensorrt/dynamo/observer.py` and is thread- and
context-safe via Python `contextvars`.
## Core Concepts
- **`Observer`**: A named, context-variable-backed event source. Any function can hold an
  `Observer` instance and call `observer.observe(*args)` at a point of interest. Registered
  callbacks receive those arguments.
- **`@observable()`**: A decorator that wraps a function and automatically fires
  `pre` and `post` observers around each invocation, passing an `ObserveContext` to each callback.
- **`ObserveContext`**: A dataclass passed to every `@observable` callback, with fields:
  - `callable`: the observed function object.
  - `args`: positional arguments passed to it.
  - `kwargs`: keyword arguments passed to it.
  - `return_value`: the return value (only populated in `post` callbacks).
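To make the moving parts concrete, here is a minimal, illustrative stand-in for the `Observer` half of the system. It is not the real implementation in `observer.py` (names and details are simplified), but it shows the core mechanics: a `ContextVar`-backed callback registry and a context-manager `add()`:

```python
import contextlib
import contextvars

class MiniObserver:
    """Illustrative stand-in for Observer (not the real implementation)."""

    def __init__(self, name: str):
        self.name = name
        # Per-execution-context tuple of registered callbacks
        self._callbacks = contextvars.ContextVar(name, default=())

    def observe(self, *args):
        # Fire the event: every registered callback receives *args
        for cb in self._callbacks.get():
            cb(*args)

    @contextlib.contextmanager
    def add(self, callback):
        # Register the callback only for the duration of the with-block
        token = self._callbacks.set(self._callbacks.get() + (callback,))
        try:
            yield
        finally:
            self._callbacks.reset(token)

seen = []
event = MiniObserver("demo")
with event.add(seen.append):
    event.observe("fired")    # callback active inside the block
event.observe("ignored")      # de-registered: nothing recorded
print(seen)                   # ['fired']
```

Storing the callbacks in a `ContextVar` rather than a plain list is what gives the real system its per-thread / per-task isolation.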
## Using Observer Directly
Any internal function can expose an `Observer` instance as a module-level constant:
```python
# inside a Torch-TensorRT module (illustrative)
from torch_tensorrt.dynamo.observer import Observer

on_subgraph_converted: Observer = Observer(name="on_subgraph_converted")

def convert_subgraph(gm, settings):
    result = _do_convert(gm, settings)
    on_subgraph_converted.observe(gm, result)  # fire the event
    return result
```
Consumers attach a callback using the context-manager form of `Observer.add()`:
```python
import torch_tensorrt

records = []

def capture(gm, result):
    records.append((gm.graph, result.serialized_engine))

# on_subgraph_converted is the module-level Observer instance shown above
with on_subgraph_converted.add(capture):
    trt_gm = torch_tensorrt.dynamo.compile(exported_program, arg_inputs=inputs)

print(f"Captured {len(records)} subgraph conversions")
```
The callback is automatically de-registered when the `with` block exits, so it is
only active for the compilation inside the block. Callbacks registered outside a
`with` statement are not automatically removed; call
`observer._get_callbacks().remove(callback)` manually.
## Using @observable
`@observable()` is more ergonomic when you own the function being observed. It adds
a `.observers` attribute of type `CallableObservers` with two sub-observers:
`.pre` (fires before the function) and `.post` (fires after, with the return value):
```python
from torch_tensorrt.dynamo.observer import observable, ObserveContext

@observable()
def my_lowering_pass(gm, settings):
    # ... apply transformations ...
    return gm

# Attach a pre-callback
def log_before(ctx: ObserveContext):
    print(f"About to run {ctx.callable.__name__} on graph with "
          f"{len(list(ctx.args[0].graph.nodes))} nodes")

# Attach a post-callback
def log_after(ctx: ObserveContext):
    print(f"After pass: graph has "
          f"{len(list(ctx.return_value.graph.nodes))} nodes")

with my_lowering_pass.observers.pre.add(log_before), \
     my_lowering_pass.observers.post.add(log_after):
    my_lowering_pass(gm, settings)
```
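The decorator itself can be sketched in a few lines. The following is an illustrative simplification, not the real code: `Ctx`, `Event`, and `Observers` here stand in for `ObserveContext`, the individual pre/post observers, and `CallableObservers` respectively:

```python
import contextlib
import functools
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Ctx:
    # Simplified stand-in for ObserveContext
    callable: Callable
    args: tuple
    kwargs: dict
    return_value: Any = None

class Event:
    # Simplified stand-in for a single pre or post observer
    def __init__(self):
        self._callbacks = []

    @contextlib.contextmanager
    def add(self, cb):
        self._callbacks.append(cb)
        try:
            yield
        finally:
            self._callbacks.remove(cb)

    def fire(self, ctx):
        for cb in self._callbacks:
            cb(ctx)

class Observers:
    # Simplified stand-in for CallableObservers
    def __init__(self):
        self.pre = Event()
        self.post = Event()

def observable():
    def deco(fn):
        obs = Observers()
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            ctx = Ctx(fn, args, kwargs)
            obs.pre.fire(ctx)                        # pre: no return value yet
            ctx.return_value = fn(*args, **kwargs)
            obs.post.fire(ctx)                       # post: return value populated
            return ctx.return_value
        wrapper.observers = obs
        return wrapper
    return deco

@observable()
def double(x):
    return 2 * x

log = []
with double.observers.pre.add(lambda c: log.append(("pre", c.args[0]))), \
     double.observers.post.add(lambda c: log.append(("post", c.return_value))):
    double(21)
print(log)  # [('pre', 21), ('post', 42)]
```

Note how the same `Ctx` object is passed to both callbacks; the `post` callback simply sees it after `return_value` has been filled in, which matches the field description above.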
## Error Handling
By default, exceptions raised inside a callback are caught and logged at INFO
level; they do not propagate to the caller. This ensures that instrumentation bugs
never crash a production compilation.
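Conceptually, the dispatch loop wraps each callback invocation in a try/except; this rough sketch (illustrative, not the real code) uses a stand-in for the module-level rethrow flag:

```python
import logging

logger = logging.getLogger(__name__)
RETHROW_CALLBACK_EXCEPTION = False  # stand-in for the module-level toggle

def fire_callbacks(callbacks, *args):
    # Instrumentation bugs must not break compilation: catch, log, continue
    for cb in callbacks:
        try:
            cb(*args)
        except Exception:
            if RETHROW_CALLBACK_EXCEPTION:
                raise
            logger.info("callback %r raised", cb, exc_info=True)

def bad_callback(*_):
    raise RuntimeError("instrumentation bug")

fire_callbacks([bad_callback], "payload")  # logged at INFO, not raised
print("compilation continues")
```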
During unit tests you can opt in to exception re-raising:
```python
import torch_tensorrt.dynamo.observer as obs_module

obs_module.RETHROW_CALLBACK_EXCEPTION = True
# ... run test ...
obs_module.RETHROW_CALLBACK_EXCEPTION = False  # restore
```
## Concurrency and Context Isolation
Callbacks are stored in a `contextvars.ContextVar`, which means each Python
execution context (thread, async task, `concurrent.futures` worker) has its own
callback registry. A callback registered in one thread is not visible to observers
firing in another thread. This makes the system safe to use in multi-threaded
compilation pipelines.
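This isolation can be demonstrated with plain `contextvars`, independent of Torch-TensorRT (the registry and callback names here are illustrative):

```python
import contextvars
import threading

# Each thread has its own context: a ContextVar value set in the main
# thread is not visible inside a freshly started thread.
registry = contextvars.ContextVar("callback_registry", default=())
registry.set(("main-thread-callback",))

seen_in_thread = []

def worker():
    # New thread -> fresh context -> the default value
    seen_in_thread.append(registry.get())

t = threading.Thread(target=worker)
t.start()
t.join()

print(registry.get())     # ('main-thread-callback',)
print(seen_in_thread[0])  # () : the worker never saw the registration
```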
## Observing Existing Compiler Functions
To observe a function that you do not own (i.e., you cannot decorate it), wrap it with
`_make_observable` and monkey-patch:
```python
import torch_tensorrt
from torch_tensorrt.dynamo import observer as _obs
import torch_tensorrt.dynamo._compiler as _compiler

# Wrap the target function
original = _compiler.compile_module
_compiler.compile_module = _obs._make_observable(original)

records = []
with _compiler.compile_module.observers.post.add(
    lambda ctx: records.append(ctx.return_value)
):
    result = torch_tensorrt.dynamo.compile(exported_program, arg_inputs=inputs)

# Restore the original function
_compiler.compile_module = original
```
> **Warning:** Monkey-patching internal functions is fragile across Torch-TensorRT versions.
> Prefer using the official `Observer` instances exposed by each module when available.