Shortcuts

Source code for torchtune.utils._logging

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import warnings
from functools import lru_cache, wraps
from typing import Callable, Optional, TypeVar

from torch import distributed as dist

T = TypeVar("T", bound=type)


def get_logger(level: Optional[str] = None) -> logging.Logger:
    """
    Return this module's logger, attaching a stream handler when none is reachable.

    The logger is looked up by this module's ``__name__``, so repeated calls return
    the same :class:`logging.Logger` object. A :class:`logging.StreamHandler` is
    added only when ``hasHandlers()`` reports no handler on this logger or any of
    its ancestors, so calling the function repeatedly does not stack handlers.

    Args:
        level (Optional[str]): Name of the logging level (case-insensitive), resolved
            as an attribute of the :mod:`logging` module. See
            https://docs.python.org/3/library/logging.html#levels for list of levels.
            When ``None``, the logger's current level is left untouched.

    Example:
        >>> logger = get_logger("INFO")
        >>> logger.info("Hello world!")

    Returns:
        logging.Logger: The logger.

    Raises:
        AttributeError: If ``level.upper()`` is not an attribute of :mod:`logging`.
    """
    module_logger = logging.getLogger(__name__)

    # hasHandlers() also inspects ancestor loggers, so an application-level
    # logging configuration takes precedence over our fallback handler.
    if not module_logger.hasHandlers():
        module_logger.addHandler(logging.StreamHandler())

    if level is None:
        return module_logger

    # Translate the textual name (e.g. "info") into the logging constant.
    module_logger.setLevel(getattr(logging, level.upper()))
    return module_logger
@lru_cache(None)
def log_once(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
    """
    Logs a message only once. LRU cache is used to ensure a specific message is
    logged only once, similar to how :func:`~warnings.warn` works when the ``once``
    rule is set via command-line or environment variable.

    Note:
        Deduplication is keyed on the call arguments via
        :func:`functools.lru_cache`, so all arguments must be hashable and the
        cache is unbounded. Calls that pass the same values in a different style
        (positional vs. keyword, or an explicit vs. omitted ``level``) may be
        cached as distinct entries; use one call style consistently.

    Args:
        logger (logging.Logger): The logger.
        msg (str): The message to log.
        level (int): The logging level. See
            https://docs.python.org/3/library/logging.html#levels for values.
            Defaults to ``logging.INFO``.
    """
    log_rank_zero(logger=logger, msg=msg, level=level)


def deprecated(msg: str = "") -> Callable[[T], T]:
    """
    Decorator to mark an object as deprecated and print additional message.

    The first successful call of each decorated object emits a
    :class:`FutureWarning`; later calls of that same object stay silent. The
    warn-once state is tracked separately for every decorated object, so a single
    ``deprecated(...)`` decorator can safely be applied to several objects.

    Args:
        msg (str): additional information to print after warning.

    Returns:
        Callable[[T], T]: the decorated object.
    """

    def decorator(obj):
        # One flag per decorated object. A cache shared across objects would
        # re-emit the warning whenever calls to different objects alternate.
        already_warned = False

        @wraps(obj)
        def wrapper(*args, **kwargs):
            nonlocal already_warned
            if not already_warned:
                warnings.warn(
                    f"{obj.__name__} is deprecated and will be removed in future versions. "
                    + msg,
                    category=FutureWarning,
                    # Attribute the warning to the code calling the deprecated
                    # object rather than to this wrapper.
                    stacklevel=2,
                )
                # Set only after warn() returns: if warnings are configured to
                # raise, every call keeps raising instead of only the first.
                already_warned = True
            return obj(*args, **kwargs)

        return wrapper

    return decorator


def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
    """
    Logs a message only on rank zero.

    When ``torch.distributed`` is unavailable or no process group has been
    initialized, the current process is treated as rank zero and the message
    is logged.

    Args:
        logger (logging.Logger): The logger.
        msg (str): The message to log.
        level (int): The logging level. See
            https://docs.python.org/3/library/logging.html#levels for values.
            Defaults to ``logging.INFO``.
    """
    # Only query the rank once a process group exists; otherwise fall back to 0.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank != 0:
        return
    logger.log(level, msg)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources