Shortcuts

Source code for torchtune.config._validate

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import inspect

from omegaconf import DictConfig
from torchtune.config._errors import ConfigError
from torchtune.config._utils import _get_component_from_path, _has_component


def validate(cfg: DictConfig) -> None:
    """
    Ensure that all components in the config can be instantiated correctly.

    For every top-level node that declares a ``_component_``, the component is
    resolved from its path and the node's remaining keys are checked against
    the component's call signature. Required arguments that are absent from
    the config are tolerated: some objects (optimizers, lr_schedulers,
    datasets, etc.) take other objects as arguments, which are supplied when
    the component is instantiated rather than in the config. Any other binding
    problem, such as an unexpected keyword argument, is collected. A
    non-callable component is also collected.

    Only the top-level nodes of ``cfg`` are inspected.

    Args:
        cfg (DictConfig): The config to validate

    Raises:
        ConfigError: If any component cannot be instantiated. It is constructed
            with a list holding one ``TypeError`` per offending component; each
            message is prefixed with the component's name.
    """
    errors = []
    for _name, nodedict in cfg.items():
        if not _has_component(nodedict):
            continue

        # Resolve the component outside the signature check so that a failure
        # here is never misattributed to another node's component.
        component = _get_component_from_path(nodedict.get("_component_"))
        kwargs = {k: v for k, v in nodedict.items() if k != "_component_"}

        try:
            # ``bind_partial`` performs every check ``bind`` does except for
            # missing required arguments (which we deliberately tolerate, see
            # docstring). This avoids matching on exception message text.
            # That text differs for keyword-only parameters.
            inspect.signature(component).bind_partial(**kwargs)
        except TypeError as e:
            # inspect.signature does not retain the function name in the
            # exception, so we manually add it back in. Not every callable
            # (e.g. functools.partial, callable instances) has __name__.
            name = getattr(component, "__name__", repr(component))
            errors.append(TypeError(f"{name} {str(e)}"))

    if errors:
        raise ConfigError(errors)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources