Source code for torchrl.record.loggers.common

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import abc
from collections.abc import Sequence
from typing import Any

import torch
from tensordict import TensorDictBase
from torch import Tensor


__all__ = ["Logger"]


def _make_metrics_safe(
    metrics: dict[str, Any] | TensorDictBase,
    *,
    keys_sep: str = "/",
) -> dict[str, Any]:
    """Convert metric values to be safe for cross-process logging.

    This function converts torch tensors to CPU/Python types, which is
    necessary when logging metrics to external services (e.g., wandb, mlflow)
    that may run in separate processes without GPU access.

    For regular dicts, the implementation batches CUDA->CPU transfers using
    non_blocking=True and synchronizes once via a CUDA event, avoiding the
    overhead of the multiple implicit synchronizations that calling .item()
    on each CUDA tensor individually would incur.

    For TensorDict inputs, this leverages TensorDict's efficient batch `.to()`
    method which transfers all tensors in a single operation.

    Args:
        metrics: Dictionary or TensorDict of metric names to values. Values can
            be torch.Tensor (CUDA or CPU), Python scalars, or other types.
        keys_sep: Separator used to flatten nested TensorDict keys into strings.
            Defaults to "/". Only used for TensorDict inputs.

    Returns:
        Dictionary with the same keys but tensor values converted to
        Python scalars (for single-element tensors) or lists (for
        multi-element tensors). Non-tensor values are passed through unchanged.
    """
    if isinstance(metrics, TensorDictBase):
        return _make_metrics_safe_tensordict(metrics, keys_sep=keys_sep)

    out: dict[str, Any] = {}
    cpu_tensors: dict[str, Tensor] = {}
    has_cuda_tensors = False

    # First pass: identify tensors and start non-blocking CUDA->CPU transfers
    for key, value in metrics.items():
        if isinstance(value, Tensor):
            if value.is_cuda:
                # Non-blocking transfer - queues the copy without waiting
                value = value.detach().to("cpu", non_blocking=True)
                has_cuda_tensors = True
            else:
                value = value.detach()
            cpu_tensors[key] = value
        else:
            out[key] = value

    # Explicit sync: use a CUDA event instead of global synchronize() - this
    # only waits for work up to the point the event was recorded, not ALL
    # pending GPU work.
    if has_cuda_tensors:
        event = torch.cuda.Event()
        event.record()
        event.synchronize()

    # Second pass: convert CPU tensors to Python scalars/lists
    for key, value in cpu_tensors.items():
        if value.numel() == 1:
            out[key] = value.item()
        else:
            out[key] = value.tolist()

    return out
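

# A minimal usage sketch (illustrative only; the ``_example_*`` name and the
# metric values are hypothetical, not part of torchrl). A mixed dict of
# CUDA/CPU tensors and plain Python values round-trips to pure Python types.
def _example_make_metrics_safe() -> None:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    metrics = {
        "loss": torch.tensor(0.25, device=device),  # single-element -> float
        "reward": torch.tensor([1.0, 2.0]),  # multi-element -> list
        "epoch": 3,  # non-tensor -> passed through unchanged
    }
    safe = _make_metrics_safe(metrics)
    assert safe == {"epoch": 3, "loss": 0.25, "reward": [1.0, 2.0]}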


def _make_metrics_safe_tensordict(
    metrics: TensorDictBase,
    *,
    keys_sep: str = "/",
) -> dict[str, Any]:
    """Convert TensorDict metric values to be safe for cross-process logging.

    This leverages TensorDict's efficient batch `.to()` method which transfers
    all tensors in a single operation, then converts to Python scalars.

    Args:
        metrics: TensorDict of metric names to tensor values.
        keys_sep: Separator used to flatten nested keys into strings.

    Returns:
        Dictionary with flattened string keys and Python scalar/list values.
    """
    # TensorDict's .to() efficiently batches all tensor transfers
    metrics = metrics.to("cpu", non_blocking=True)

    # Sync if CUDA is in use - the event sync is cheap if no work is pending
    if torch.cuda.is_initialized():
        event = torch.cuda.Event()
        event.record()
        event.synchronize()

    # Flatten nested keys and convert to dict
    flat_dict = metrics.flatten_keys(keys_sep).to_dict()

    # Convert tensors to Python scalars/lists
    out: dict[str, Any] = {}
    for key, value in flat_dict.items():
        if isinstance(value, Tensor):
            value = value.detach()
            if value.numel() == 1:
                out[key] = value.item()
            else:
                out[key] = value.tolist()
        else:
            out[key] = value

    return out
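

# A sketch of the TensorDict path (hypothetical ``_example_*`` helper;
# assumes the ``tensordict`` package is installed): nested keys are
# flattened with ``keys_sep`` before conversion to Python scalars.
def _example_make_metrics_safe_tensordict() -> None:
    from tensordict import TensorDict

    td = TensorDict(
        {"train": {"loss": torch.tensor(0.5)}, "step_count": torch.tensor(10)},
        batch_size=[],
    )
    safe = _make_metrics_safe_tensordict(td)  # default keys_sep="/"
    assert safe == {"train/loss": 0.5, "step_count": 10}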


class Logger:
    """A template for loggers."""

    def __init__(self, exp_name: str, log_dir: str) -> None:
        self.exp_name = exp_name
        self.log_dir = log_dir
        self.experiment = self._create_experiment()

    @abc.abstractmethod
    def _create_experiment(self) -> Experiment:  # noqa: F821
        ...

    @abc.abstractmethod
    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        ...

    @abc.abstractmethod
    def log_video(
        self, name: str, video: Tensor, step: int | None = None, **kwargs
    ) -> None:
        ...

    @abc.abstractmethod
    def log_hparams(self, cfg: DictConfig | dict) -> None:  # noqa: F821
        ...

    @abc.abstractmethod
    def __repr__(self) -> str:
        ...

    @abc.abstractmethod
    def log_histogram(self, name: str, data: Sequence, **kwargs):
        ...

    def log_metrics(
        self,
        metrics: dict[str, Any] | TensorDictBase,
        step: int | None = None,
        *,
        keys_sep: str = "/",
    ) -> dict[str, Any]:
        """Log multiple scalar metrics at once.

        This method efficiently handles tensor values by batching CUDA->CPU
        transfers and performing a single synchronization, avoiding the
        overhead of multiple implicit syncs that would occur when logging
        tensors one at a time. This is particularly useful when logging to
        services running in separate processes (e.g., Ray actors) that may
        not have GPU access.

        Args:
            metrics: Dictionary or TensorDict mapping metric names to values.
                Tensor values are automatically converted to Python
                scalars/lists. For TensorDict inputs, nested keys are
                flattened using ``keys_sep``.
            step: Optional step value for all metrics.
            keys_sep: Separator used to flatten nested TensorDict keys into
                strings. Defaults to "/". Only used for TensorDict inputs.

        Returns:
            The converted metrics dictionary (with tensors converted to
            Python types).
        """
        safe_metrics = _make_metrics_safe(metrics, keys_sep=keys_sep)
        for name, value in safe_metrics.items():
            self.log_scalar(name, value, step=step)
        return safe_metrics
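

# A minimal concrete subclass sketch (hypothetical ``_PrintLogger``, not a
# torchrl class). Only the methods exercised by ``log_metrics`` are filled
# in; the remaining abstract methods are omitted for brevity, which works
# because ``Logger`` does not use ``ABCMeta`` and so does not enforce them
# at instantiation.
class _PrintLogger(Logger):
    def _create_experiment(self) -> None:
        return None

    def log_scalar(self, name: str, value: float, step: int | None = None) -> None:
        print(f"step={step} {name}={value}")


# Usage: one ``log_metrics`` call fans out to ``log_scalar`` per metric,
# with tensors already converted to Python scalars:
#
#     logger = _PrintLogger(exp_name="demo", log_dir="/tmp/demo")
#     logger.log_metrics({"loss": torch.tensor(0.5)}, step=0)
#     # prints: step=0 loss=0.5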
