
Debugging Torch-TensorRT Compilation

Torch-TensorRT compilation can apply many graph transformations and backend-specific optimizations, and these are sometimes hard to inspect from the final compiled module. The `torch_tensorrt.dynamo.Debugger` context manager helps you visualize FX graphs before and after lowering passes, monitor the TensorRT engine-building process, and capture profiling information or TensorRT API traces.
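
Part of what the Debugger does is write log output at a chosen verbosity into a logging directory. For log messages alone, a similar effect can be approximated with Python's standard `logging` module. The sketch below is an illustration under stated assumptions, not how the Debugger is implemented and not a replacement for it (it captures no FX graphs or engine traces). It assumes Torch-TensorRT emits its messages through standard `logging` loggers named under the `torch_tensorrt` package, and `scoped_file_logging` is a hypothetical helper.

```python
import logging
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def scoped_file_logging(logger_name, level, logging_dir, filename="debug.log"):
    """Temporarily set `logger_name` to `level` and copy its records to a file.

    Hypothetical helper for illustration; it is not part of Torch-TensorRT.
    """
    os.makedirs(logging_dir, exist_ok=True)
    log_path = os.path.join(logging_dir, filename)
    logger = logging.getLogger(logger_name)
    old_level = logger.level
    handler = logging.FileHandler(log_path, mode="w")  # overwrite on each run
    handler.setFormatter(logging.Formatter("%(levelname)s %(name)s: %(message)s"))
    logger.addHandler(handler)
    logger.setLevel(level)
    try:
        yield log_path
    finally:
        # Restore the previous level and detach the handler, even on error.
        logger.removeHandler(handler)
        handler.close()
        logger.setLevel(old_level)


# Assumption: library modules log via child loggers such as "torch_tensorrt.dynamo",
# whose records propagate up to the "torch_tensorrt" logger configured here.
log_demo_dir = os.path.join(tempfile.gettempdir(), "scoped_logging_demo")
with scoped_file_logging("torch_tensorrt", logging.DEBUG, log_demo_dir) as log_path:
    logging.getLogger("torch_tensorrt.dynamo").debug("lowering started")

with open(log_path) as f:
    print(f.read().strip())  # DEBUG torch_tensorrt.dynamo: lowering started
```

Configuring the parent logger rather than each module's logger relies on standard `logging` propagation, so any child logger's records reach the file without further setup.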

In this example, we demonstrate how to:

  1. Enable the Torch-TensorRT Debugger context

  2. Capture and visualize FX graphs before and/or after specific lowering passes

  3. Configure the logging directory and verbosity

import os
import tempfile

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models

temp_dir = os.path.join(tempfile.gettempdir(), "torch_tensorrt_debugger_example")

np.random.seed(0)
torch.manual_seed(0)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]


model = models.resnet18(weights=None).to("cuda").eval()  # randomly initialized weights
exp_program = torch.export.export(model, tuple(inputs))
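enabled_precisions = {torch.float}  # build FP32 TensorRT engines
workspace_size = 20 << 30  # 20 GiB (20 * 2**30 bytes) of TensorRT workspace
min_block_size = 0
use_python_runtime = False  # use the C++ runtime instead of the Python runtime
torch_executed_ops = set()  # no operators are forced to run in PyTorch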

with torch_trt.dynamo.Debugger(
    log_level="debug",
    logging_dir=temp_dir,
    # Whether to monitor the TensorRT engine-building process
    engine_builder_monitor=False,
    # Visualize the FX graph after these lowering passes
    capture_fx_graph_after=["complex_graph_detection"],
    # Visualize the FX graph before these lowering passes
    capture_fx_graph_before=["remove_detach"],
):

    trt_gm = torch_trt.dynamo.compile(
        exp_program,
        tuple(inputs),
        use_python_runtime=use_python_runtime,
        enabled_precisions=enabled_precisions,
        workspace_size=workspace_size,
        min_block_size=min_block_size,
        torch_executed_ops=torch_executed_ops,
        immutable_weights=False,
        reuse_cached_engines=False,
    )

    trt_output = trt_gm(*inputs)


"""
The logging directory will contain the following files:
- /tmp/torch_tensorrt_debugger_example/
    torch_tensorrt_logging.log
    - /lowering_passes_visualization/
        after_complex_graph_detection.svg
        before_remove_detach.svg
"""
