Debugging Torch-TensorRT Compilation
TensorRT conversion can apply many graph transformations and backend-specific optimizations that are hard to inspect from the outside. Torch-TensorRT provides a Debugger utility that helps you visualize FX graphs before and after individual lowering passes, monitor engine building, and capture profiling data or TensorRT API traces.
In this example, we demonstrate how to:
- Enable the Torch-TensorRT Debugger context
- Capture and visualize FX graphs before and/or after specific lowering passes
- Configure the logging directory and verbosity
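import os
import tempfile

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

# Directory where the Debugger writes its log file and graph visualizations
temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")

# Seed the RNGs so the random input is reproducible across runs
np.random.seed(0)
torch.manual_seed(0)

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
# Randomly initialized ResNet-18; `weights=None` replaces the deprecated `pretrained=False`
model = models.resnet18(weights=None).to("cuda").eval()
exp_program = torch.export.export(model, tuple(inputs))

# Compilation settings
enabled_precisions = {torch.float}
workspace_size = 20 << 30  # 20 GiB
min_block_size = 0  # minimum number of ops a subgraph needs to be converted to TensorRT
use_python_runtime = False
torch_executed_ops = set()  # ops to keep in PyTorch; empty means none are forced out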
# Compilation inside this context runs with the Debugger's logging and capture settings
with torch_trt.dynamo.Debugger(
    log_level="debug",
    logging_dir=temp_dir,
    engine_builder_monitor=False,  # whether to monitor the engine building process
    capture_fx_graph_after=[
        "complex_graph_detection"
    ],  # FX graph visualization after this lowering pass
    capture_fx_graph_before=[
        "remove_detach"
    ],  # FX graph visualization before this lowering pass
):
    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        immutable_weights=False,
        reuse_cached_engines=False,
    )
    trt_output = trt_gm(*inputs)
After the run, the logging directory contains the following files (the path below assumes tempfile.gettempdir() returns /tmp, as is typical on Linux):
/tmp/torch_tensorrt_debugger_example/
    torch_tensorrt_logging.log
    lowering_passes_visualization/
        after_complex_graph_detection.svg
        before_remove_detach.svg
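To check that a run actually produced these artifacts, a small standard-library helper can walk the logging directory. This is a sketch, not part of Torch-TensorRT: the helper names (expected_artifacts, missing_artifacts, list_artifacts) are hypothetical, and the before_<pass>.svg / after_<pass>.svg naming is an assumption inferred from the two SVG files listed above rather than a documented contract.

```python
import os
import tempfile

# Mirrors `temp_dir` from the example script
LOGGING_DIR = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")
VIS_SUBDIR = "lowering_passes_visualization"


def expected_artifacts(capture_before, capture_after):
    """Relative paths we expect, ASSUMING the <before|after>_<pass>.svg naming
    inferred from the example listing (not a documented Torch-TensorRT API)."""
    paths = ["torch_tensorrt_logging.log"]
    for name in capture_before:
        paths.append(os.path.join(VIS_SUBDIR, f"before_{name}.svg"))
    for name in capture_after:
        paths.append(os.path.join(VIS_SUBDIR, f"after_{name}.svg"))
    return paths


def missing_artifacts(root, capture_before, capture_after):
    """Return the expected relative paths that are not regular files under `root`."""
    return [
        path
        for path in expected_artifacts(capture_before, capture_after)
        if not os.path.isfile(os.path.join(root, path))
    ]


def list_artifacts(root):
    """Return every file under `root` as a sorted list of paths relative to `root`."""
    found = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for fname in filenames:
            found.append(os.path.relpath(os.path.join(dirpath, fname), root))
    return sorted(found)


if __name__ == "__main__":
    # Same pass names as passed to the Debugger in the example
    missing = missing_artifacts(LOGGING_DIR, ["remove_detach"], ["complex_graph_detection"])
    if missing:
        print("Missing:", missing)
    for path in list_artifacts(LOGGING_DIR):
        print(path)
```

Run after the example script, this prints any expected files that are missing, followed by every file the Debugger wrote, which also surfaces artifacts whose names differ from the assumed pattern.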