# Source code for torch_tensorrt.runtime._multi_device_safe_mode
import logging
from typing import Any
import torch
import torch_tensorrt
# Python-side copy of the multi-device safe mode flag. When the C++ runtime is
# compiled in, seed it from the C++ runtime's current setting; otherwise default
# to ``False``.
_PY_RT_MULTI_DEVICE_SAFE_MODE = (
    torch.ops.tensorrt.get_multi_device_safe_mode()
    if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime
    else False
)
logger = logging.getLogger(__name__)
class _MultiDeviceSafeModeContextManager(object):
    """Helper class used in conjunction with `set_multi_device_safe_mode`
    Used to enable `set_multi_device_safe_mode` as a dual-purpose context manager
    """
    def __init__(self, old_mode: bool) -> None:
        self.old_mode = old_mode
    def __enter__(self) -> "_MultiDeviceSafeModeContextManager":
        return self
    def __exit__(self, *args: Any) -> None:
        # Set multi-device safe mode back to old mode in Python
        global _PY_RT_MULTI_DEVICE_SAFE_MODE
        _PY_RT_MULTI_DEVICE_SAFE_MODE = self.old_mode
        # Set multi-device safe mode back to old mode in C++
        if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
            torch.ops.tensorrt.set_multi_device_safe_mode(self.old_mode)
def set_multi_device_safe_mode(mode: bool) -> _MultiDeviceSafeModeContextManager:
    """Set the runtime (Python-only and default) into multi-device safe mode.

    In the case that multiple devices are available on the system, in order for the
    runtime to execute safely, additional device checks are necessary. These checks
    can have a performance impact so they are therefore opt-in. Used to suppress
    the warning about running unsafely in a multi-device context.

    The new mode takes effect as soon as this function is called, so it can be
    used either as a plain call or in a ``with`` statement. In the latter case the
    previous mode is restored when the ``with`` block exits.

    Arguments:
        mode (bool): Enable (``True``) or disable (``False``) multi-device checks

    Returns:
        _MultiDeviceSafeModeContextManager: Context manager holding the previous
        mode, which it restores on exit.

    Example:

        .. code-block:: py

            with torch_tensorrt.runtime.set_multi_device_safe_mode(True):
                results = trt_compiled_module(*inputs)
    """
    global _PY_RT_MULTI_DEVICE_SAFE_MODE
    old_mode = _PY_RT_MULTI_DEVICE_SAFE_MODE

    # Update the C++ runtime (when available) before touching the Python flag.
    # If the C++ call raises, the Python-side flag is left unchanged instead of
    # being half-updated.
    if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
        torch.ops.tensorrt.set_multi_device_safe_mode(mode)
    _PY_RT_MULTI_DEVICE_SAFE_MODE = mode

    # Lazy %-style args: the message is only formatted if INFO is enabled.
    logger.info("Set multi-device safe mode to %s", mode)

    # Returned so the function also works as a context manager in a `with` call.
    return _MultiDeviceSafeModeContextManager(old_mode)