Shortcuts

Source code for torchtune.config._utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from argparse import Namespace
from importlib import import_module
from types import ModuleType
from typing import Any, Dict, List, Union

from omegaconf import DictConfig, OmegaConf

from torchtune.config._errors import InstantiationError
from torchtune.utils._logging import get_logger, log_rank_zero


def log_config(recipe_name: str, cfg: DictConfig) -> None:
    """
    Emit the fully resolved config (YAML file merged with CLI overrides) through
    the rank-zero logging helper.

    The config is rendered to YAML with interpolations resolved and keys sorted,
    then logged via a logger obtained with ``get_logger("DEBUG")``.

    Args:
        recipe_name (str): name of the recipe to display
        cfg (DictConfig): parsed config object
    """
    logger = get_logger("DEBUG")
    # resolve=True expands interpolations; sort_keys=True gives a stable ordering
    resolved_yaml = OmegaConf.to_yaml(cfg, resolve=True, sort_keys=True)
    message = f"Running {recipe_name} with resolved config:\n\n{resolved_yaml}"
    log_rank_zero(logger=logger, msg=message)
def _has_component(node: Union[Dict[str, Any], DictConfig]) -> bool:
    """
    Check whether a config node declares a component.

    Args:
        node (Union[Dict[str, Any], DictConfig]): candidate config node

    Returns:
        bool: True if ``node`` is a plain dict or an OmegaConf dict config and
            contains the ``_component_`` key
    """
    return (OmegaConf.is_dict(node) or isinstance(node, dict)) and "_component_" in node


def _get_component_from_path(path: str) -> Any:
    """
    Return an object by name or dotted path, importing as necessary.
    The base functionality relies on ``getattr()`` and handles all
    possible exceptions accordingly.

    Based on Hydra's `_locate` from Facebook Research:
    https://github.com/facebookresearch/hydra/blob/main/hydra/_internal/utils.py#L614

    Args:
        path (str): Dotted path of the object

    Returns:
        Any: The object

    Raises:
        InstantiationError: If there is an exception loading the
            object from the provided path
        ValueError: If a relative or invalid dotpath is passed in
    """
    if path == "":
        raise ValueError("Empty path")

    parts = path.split(".")
    # An empty segment comes from a leading, trailing or doubled dot. If a
    # relative path is passed in, the first part will be empty.
    if not all(parts):
        raise ValueError(
            f"Error loading '{path}': invalid dotstring."
            + "\nRelative imports are not supported."
        )

    # First module requires trying to import to validate
    part0 = parts[0]
    try:
        obj = import_module(part0)
    except ImportError as exc_import:
        raise InstantiationError(
            f"Error loading '{path}':\n{repr(exc_import)}"
            + f"\nAre you sure that module '{part0}' is installed?"
        ) from exc_import

    # Subsequent components can be checked via getattr() on first module
    # It can either be an attribute that we can return or a submodule that we
    # can import and continue searching
    for m, part in enumerate(parts[1:], start=1):
        try:
            obj = getattr(obj, part)
        # If getattr fails, check to see if it's a module we can import and
        # continue down the path
        except AttributeError as exc_attr:
            parent_dotpath = ".".join(parts[:m])
            if isinstance(obj, ModuleType):
                mod = ".".join(parts[: m + 1])
                try:
                    obj = import_module(mod)
                    continue
                except ModuleNotFoundError as exc_import:
                    raise InstantiationError(
                        f"Error loading '{path}':\n{repr(exc_import)}"
                        + f"\nAre you sure that '{part}' is importable from module '{parent_dotpath}'?"
                    ) from exc_import
                # Any other error trying to import module can be raised as
                # InstantiationError
                except Exception as exc_import:
                    raise InstantiationError(
                        f"Error loading '{path}':\n{repr(exc_import)}"
                    ) from exc_import
            # If the component is not an attribute nor a module, it doesn't exist
            raise InstantiationError(
                f"Error loading '{path}':\n{repr(exc_attr)}"
                + f"\nAre you sure that '{part}' is an attribute of '{parent_dotpath}'?"
            ) from exc_attr
    return obj


def _merge_yaml_and_cli_args(yaml_args: Namespace, cli_args: List[str]) -> DictConfig:
    """
    Takes the direct output of argparse's parse_known_args which returns known
    args as a Namespace and unknown args as a dotlist (in our case, yaml args and
    cli args, respectively) and merges them into a single OmegaConf DictConfig.

    If a cli arg overrides a yaml arg with a _component_ field, the cli arg can
    be specified with the parent field directly, e.g., model=torchtune.models.lora_llama2_7b
    instead of model._component_=torchtune.models.lora_llama2_7b. Nested fields within the
    component should be specified with dot notation, e.g., model.lora_rank=16.

    An override is split at its first ``=`` only, so values may themselves
    contain ``=`` characters. A leading ``~`` removes the named key from the
    yaml config instead of overriding it.

    Example:
        >>> config.yaml:
        >>>     a: 1
        >>>     b:
        >>>       _component_: torchtune.models.my_model
        >>>       c: 3

        >>> tune full_finetune --config config.yaml b=torchtune.models.other_model b.c=4
        >>> yaml_args, cli_args = parser.parse_known_args()
        >>> conf = _merge_yaml_and_cli_args(yaml_args, cli_args)
        >>> print(conf)
        >>> {"a": 1, "b": {"_component_": "torchtune.models.other_model", "c": 4}}

    Args:
        yaml_args (Namespace): Namespace containing args from yaml file, components
            should have _component_ fields
        cli_args (List[str]): List of key=value strings

    Returns:
        DictConfig: OmegaConf DictConfig containing merged args

    Raises:
        ValueError: If a cli override is not in the form of key=value, or if a
            ``~key`` removal targets a component or a key that cannot be found
    """
    # Convert Namespace to simple dict.
    # NOTE: vars() returns the Namespace's own attribute dict, so top-level
    # removals below are also reflected on ``yaml_args``.
    yaml_kwargs = vars(yaml_args)
    cli_dotlist = []
    for arg in cli_args:
        # If CLI override uses the remove flag (~), remove the key from the yaml config
        if arg.startswith("~"):
            dotpath = arg[1:].split("=")[0]
            if "_component_" in dotpath:
                raise ValueError(
                    f"Removing components from CLI is not supported: ~{dotpath}"
                )
            try:
                _remove_key_by_dotpath(yaml_kwargs, dotpath)
            # TypeError is raised when the dotpath descends through a value that
            # is not a mapping (e.g. ~a.b where a is a scalar); report it the
            # same way as a missing key rather than crashing.
            except (KeyError, ValueError, TypeError):
                raise ValueError(
                    f"Could not find key {dotpath} in yaml config to remove"
                ) from None
            continue
        # Get other overrides that should be specified as key=value.
        # Split on the first "=" only so that values containing "=" are kept intact.
        try:
            k, v = arg.split("=", 1)
        except ValueError:
            raise ValueError(
                f"Command-line overrides must be in the form of key=value, got {arg}"
            ) from None
        # If a cli arg overrides a yaml arg with a _component_ field, update the
        # key string to reflect this
        if k in yaml_kwargs and _has_component(yaml_kwargs[k]):
            k += "._component_"

        # TODO: this is a hack but otherwise we can't pass strings with leading zeroes
        # to define the checkpoint file format. We manually override OmegaConf behavior
        # by prepending the value with !!str to force a string type
        if "max_filename" in k:
            v = "!!str " + v
        cli_dotlist.append(f"{k}={v}")

    # Merge the args
    cli_conf = OmegaConf.from_dotlist(cli_dotlist)
    yaml_conf = OmegaConf.create(yaml_kwargs)

    # CLI takes precedence over yaml args
    return OmegaConf.merge(yaml_conf, cli_conf)


def _remove_key_by_dotpath(nested_dict: Dict[str, Any], dotpath: str) -> None:
    """
    Removes a key specified by dotpath from a nested dict. Errors should be handled by
    the calling function.

    After the target key is deleted, any parent dict left empty by the deletion
    is removed as well.

    Args:
        nested_dict (Dict[str, Any]): Dict to remove key from
        dotpath (str): dotpath of key to remove, e.g., "a.b.c"

    Raises:
        KeyError: If a segment of the dotpath is not present
        ValueError: If the key to delete holds a ``_component_`` entry
        TypeError: If the dotpath descends through a value that cannot be indexed by key
    """
    path = dotpath.split(".")

    def delete_non_component(d: Dict[str, Any], key: str) -> None:
        # Refuse to delete whole components; only plain values/dicts may go
        if _has_component(d[key]):
            raise ValueError(
                f"Removing components from CLI is not supported: ~{dotpath}"
            )
        del d[key]

    def recurse_and_delete(d: Dict[str, Any], path: List[str]) -> None:
        if len(path) == 1:
            delete_non_component(d, path[0])
        else:
            recurse_and_delete(d[path[0]], path[1:])
            # Prune the parent if removing the child left it empty
            if not d[path[0]]:
                delete_non_component(d, path[0])

    recurse_and_delete(nested_dict, path)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources