Source code for torch_tensorrt.fx.fx2trt
import logging
import os
import warnings
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
import numpy
# @manual=//deeplearning/trt/python:py_tensorrt
import tensorrt as trt
import torch
import torch.fx
from torch._ops import OpOverload
from torch.fx.node import _get_qualified_name
from torch.fx.passes.shape_prop import TensorMetadata
from .converter_registry import CONVERTERS
from .input_tensor_spec import InputTensorSpec
from .observer import Observer
from .utils import Frameworks, LowerPrecision, get_dynamic_dims, unified_dtype_converter
_LOGGER: logging.Logger = logging.getLogger(__name__)
TRT_INTERPRETER_CALL_PRE_OBSERVER: Observer[Callable[[torch.fx.GraphModule], None]] = (
    Observer("TRT_INTERPRETER_CALL_PRE_OBSERVER")
)
[docs]class TRTInterpreterResult(NamedTuple):
    engine: Any
    input_names: Sequence[str]
    output_names: Sequence[str]
    serialized_cache: bytearray
[docs]class TRTInterpreter(torch.fx.Interpreter):
    def __init__(
        self,
        module: torch.fx.GraphModule,
        input_specs: List[InputTensorSpec],
        explicit_batch_dimension: bool = False,
        explicit_precision: bool = False,
        logger_level=None,
    ):
        super().__init__(module)
        self.logger = trt.Logger(logger_level or trt.Logger.WARNING)
        self.builder = trt.Builder(self.logger)
        flag = 0
        if explicit_batch_dimension:
            EXPLICIT_BATCH = 1 << (int)(
                trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH
            )
            flag |= EXPLICIT_BATCH
        if explicit_precision:
            EXPLICIT_PRECISION = 1 << (int)(
                trt.NetworkDefinitionCreationFlag.EXPLICIT_PRECISION
            )
            flag |= EXPLICIT_PRECISION
        self.network = self.builder.create_network(flag)
        missing_ops = self.validate_conversion()
        if missing_ops:
            warnings.warn(
                "Interpretation will fail due to missing operations \n"
                + "\n".join(f"{i}" for i in missing_ops)
            )
        self.optimization_profiles: Optional[List] = None
        self.input_specs = input_specs
        self.input_specs_iter = 0
        self.validate_input_specs()
        self._cur_node_name: Optional[str] = None
        self._input_names: List[str] = []
        self._output_names: List[str] = []
        self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
            dict()
        )
    def validate_input_specs(self):
        for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
            if not self.network.has_implicit_batch_dimension:
                assert (
                    has_batch_dim
                ), "It's required to specify batch dimension when it's explicit in TensorRT network."
            dynamic_dims = get_dynamic_dims(shape)
            if len(dynamic_dims):
                assert not self.network.has_implicit_batch_dimension, (
                    "Can't have dynamic dim when "
                    f"batch dim is implicit, got {shape}."
                )
                assert len(
                    shape_ranges
                ), "shape_ranges must be provided when shape has dynamic dim."
                if self.optimization_profiles:
                    assert len(shape_ranges) == len(self.optimization_profiles), (
                        "Number of optimization "
                        f"profiles {len(self.optimization_profiles)} doesn't match with the number of shape_range"
                        f" {len(shape_ranges)} provided."
                    )
                else:
                    self.optimization_profiles = [
                        self.builder.create_optimization_profile()
                        for _ in range(len(shape_ranges))
                    ]
                for shape_range in shape_ranges:
                    assert (
                        len(shape_range) == 3
                    ), f"Expect three elements in shape_range, got {len(shape_range)}"
                    assert all(len(s) == len(shape) for s in shape_range), (
                        "Expect elements in shape_range"
                        f" {shape_range} have the same number of dimension as the provided shape {len(shape)}"
                    )
                    for i in range(len(shape)):
                        if i in dynamic_dims:
                            assert all(
                                shape_range[j][i] <= shape_range[j + 1][i]
                                for j in range(2)
                            ), (
                                "Expect dynamic dim"
                                f" {i} to have incremental value for shapes in shape_range {shape_range}."
                            )
                        else:
                            assert all(s[i] == shape[i] for s in shape_range), (
                                f"Expect non dynamic dim {i} to be the same"
                                f" for all shapes in shape_range {shape_range}."
                            )
            else:
                assert (
                    len(shape_ranges) == 0
                ), "shape_ranges are provided for input that doesn't have dynamic dim."
    def validate_conversion(self):
        missing_converter = set()
        for node in self.module.graph.nodes:
            if node.op == "call_function" and not CONVERTERS.get(node.target):
                missing_converter.add(f"{node.op} {_get_qualified_name(node.target)}")
            elif node.op == "call_method" and not CONVERTERS.get(node.target):
                missing_converter.add(f"{node.op} torch.Tensor.{node.target}")
            elif node.op == "call_module":
                submod = self.fetch_attr(node.target)
                submod_type = getattr(submod, "_base_class_origin", type(submod))
                if not CONVERTERS.get(submod_type):
                    missing_converter.add(f"{node.op} {torch.typename(submod_type)}")
        return missing_converter
    def run(
        self,
        max_batch_size=64,
        max_workspace_size=1 << 25,
        lower_precision=LowerPrecision.FP16,
        sparse_weights=False,
        force_fp32_output=False,
        strict_type_constraints=False,
        algorithm_selector=None,
        timing_cache=None,
        profiling_verbosity=None,
        tactic_sources=None,
    ) -> TRTInterpreterResult:
        """
        Build TensorRT engine with some configs.
        Args:
            max_batch_size: set accordingly for maximum batch size you will use.
            max_workspace_size: set to the maximum size we can afford for temporary buffer
            lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
            sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
            force_fp32_output: force output to be fp32
            strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
            algorithm_selector: set up algorithm selection for certain layer
            timing_cache: enable timing cache for TensorRT
            profiling_verbosity: TensorRT logging level
        Return:
            TRTInterpreterResult
        """
        TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
        # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
        # force_fp32_output=False.
        self.output_fp16 = (
            not force_fp32_output and lower_precision == LowerPrecision.FP16
        )
        if (
            lower_precision == LowerPrecision.INT8
            and not self.builder.platform_has_fast_int8
        ):
            raise RuntimeError("Current platform doesn't support fast native int8!")
        if (
            lower_precision == LowerPrecision.FP16
            and not self.builder.platform_has_fast_fp16
        ):
            warnings.warn("Current platform doesn't support fast native fp16!")
        self.input_specs_iter = 0
        run_module_start_time = datetime.now()
        super().run()
        _LOGGER.info(
            f"TRT INetwork construction elapsed time: {datetime.now() - run_module_start_time}"
        )
        build_engine_start_time = datetime.now()
        self.builder.max_batch_size = max_batch_size
        builder_config = self.builder.create_builder_config()
        builder_config.max_workspace_size = max_workspace_size
        # Speed up TRT build time in the test environment
        if trt.__version__ >= "8.6" and os.environ.get("TRT_TEST_ENV", "0") == "1":
            _LOGGER.info("Set TRT optimization level to 0")
            builder_config.builder_optimization_level = 0
        cache = None
        if timing_cache:
            cache_file = numpy.array(timing_cache)
            cache = builder_config.create_timing_cache(cache_file.tobytes())
        else:
            cache = builder_config.create_timing_cache(b"")
        builder_config.set_timing_cache(cache, False)
        if trt.__version__ >= "8.2":
            builder_config.profiling_verbosity = (
                profiling_verbosity
                if profiling_verbosity
                else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
            )
        if lower_precision == LowerPrecision.FP16:
            builder_config.set_flag(trt.BuilderFlag.FP16)
        if lower_precision == LowerPrecision.INT8:
            builder_config.set_flag(trt.BuilderFlag.INT8)
        if sparse_weights:
            builder_config.set_flag(trt.BuilderFlag.SPARSE_WEIGHTS)
        if strict_type_constraints:
            builder_config.set_flag(trt.BuilderFlag.STRICT_TYPES)
        if self.optimization_profiles:
            for optimization_profile in self.optimization_profiles:
                builder_config.add_optimization_profile(optimization_profile)
        if algorithm_selector:
            builder_config.set_flag(trt.BuilderFlag.DISABLE_TIMING_CACHE)
            builder_config.algorithm_selector = algorithm_selector
        if tactic_sources is not None:
            builder_config.set_tactic_sources(tactic_sources=tactic_sources)
        engine = self.builder.build_engine(self.network, builder_config)
        assert engine
        serialized_cache = (
            bytearray(cache.serialize())
            if builder_config.get_timing_cache()
            else bytearray()
        )
        _LOGGER.info(
            f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
        )
        return TRTInterpreterResult(
            engine, self._input_names, self._output_names, serialized_cache
        )
    def run_node(self, n):
        self._cur_node_name = str(n)
        # add "_itensor_to_tensor_meta"
        kwargs = dict(n.kwargs)
        kwargs["_itensor_to_tensor_meta"] = self._itensor_to_tensor_meta
        n.kwargs = kwargs
        # run the node
        trt_node = super().run_node(n)
        # remove "_itensor_to_tensor_meta"
        kwargs = dict(n.kwargs)
        del kwargs["_itensor_to_tensor_meta"]
        n.kwargs = kwargs
        if isinstance(trt_node, trt.tensorrt.ITensor):
            self._itensor_to_tensor_meta[trt_node] = n.meta.get("tensor_meta")
        return trt_node
    def placeholder(self, target, args, kwargs):
        self._input_names.append(target)
        shape, dtype, _, shape_ranges, has_batch_dim = self.input_specs[
            self.input_specs_iter
        ]
        self.input_specs_iter += 1
        if self.network.has_implicit_batch_dimension:
            if has_batch_dim:
                shape = shape[1:]
        else:
            for i, shape_range in enumerate(shape_ranges):
                assert self.optimization_profiles
                self.optimization_profiles[i].set_shape(target, *shape_range)
        return self.network.add_input(
            name=target,
            shape=tuple(shape),
            dtype=unified_dtype_converter(dtype, Frameworks.TRT),
        )
    def call_module(self, target, args, kwargs):
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        submod_type = getattr(submod, "_base_class_origin", type(submod))
        converter = CONVERTERS.get(submod_type)
        if not converter:
            raise RuntimeError(
                f"Conversion of module of type {submod_type} not currently supported!"
            )
        assert self._cur_node_name is not None
        return converter(self.network, submod, args, kwargs, self._cur_node_name)
    def call_function(self, target, args, kwargs):
        converter = CONVERTERS.get(target)
        if not converter:
            raise RuntimeError(
                f"Conversion of function {torch.typename(target)} not currently supported!"
            )
        assert self._cur_node_name is not None
        return converter(self.network, target, args, kwargs, self._cur_node_name)
    def call_method(self, target, args, kwargs):
        assert isinstance(target, str)
        converter = CONVERTERS.get(target)
        if not converter:
            raise RuntimeError(
                f"Conversion of method {target} not currently supported!"
            )
        assert self._cur_node_name is not None
        return converter(self.network, target, args, kwargs, self._cur_node_name)
    def output(self, target, args, kwargs):
        assert len(args) == 1
        if isinstance(args[0], tuple):
            outputs = args[0]
        elif isinstance(args[0], list):
            outputs = tuple(args[0])
        else:
            outputs = (args[0],)
        if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
            raise RuntimeError("TensorRT requires all outputs to be Tensor!")
        for i, output in enumerate(outputs):
            if any(
                op_name in output.name.split("_")
                for op_name in (
                    "eq",
                    "gt",
                    "lt",
                    "or",
                    "xor",
                    "and",
                    "not",
                    "ne",
                    "isinf",
                    "any",
                )
            ):
                output_bool = True
            else:
                output_bool = False
            name = f"output{i}"
            output.name = name
            self.network.mark_output(output)
            if output_bool:
                output.dtype = trt.bool
            elif self.output_fp16 and output.dtype == trt.float32:
                output.dtype = trt.float16
            self._output_names.append(name)