Shortcuts

Source code for torchtune.utils._checkpointing._checkpointer_utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
from enum import Enum
from pathlib import Path
from typing import Any, Dict

import torch
from safetensors import safe_open


class ModelType(Enum):
    """ModelType is used by the checkpointer to distinguish between different model architectures.

    If you are adding a new model that follows a different format than those in the repo already,
    you can add a new ModelType to gate on weight conversion logic unique to that model.

    Example:
        >>> # Usage in a checkpointer class
        >>> def load_checkpoint(self, ...):
        >>>     ...
        >>>     if self._model_type == ModelType.MY_NEW_MODEL:
        >>>         state_dict = my_custom_state_dict_mapping(state_dict)
    """

    # Each member's value is the lowercase string identifier of the model family;
    # ``ModelType("<value>")`` looks a member up from that string.
    GEMMA = "gemma"
    """Gemma family of models. See :func:`~torchtune.models.gemma.gemma`"""

    LLAMA2 = "llama2"
    """Llama2 family of models. See :func:`~torchtune.models.llama2.llama2`"""

    LLAMA3 = "llama3"
    """Llama3 family of models. See :func:`~torchtune.models.llama3.llama3`"""

    MISTRAL = "mistral"
    """Mistral family of models. See :func:`~torchtune.models.mistral.mistral`"""

    PHI3_MINI = "phi3_mini"
    """Phi-3 family of models. See :func:`~torchtune.models.phi3.phi3`"""

    MISTRAL_REWARD = "mistral_reward"
    """Mistral model with a classification head. See :func:`~torchtune.models.mistral.mistral_classifier`"""
def get_path(input_dir: Path, filename: str, missing_ok: bool = False) -> Path:
    """
    Utility to recover and validate the path for a given file within a given directory.

    Args:
        input_dir (Path): Directory containing the file
        filename (str): Name of the file
        missing_ok (bool): If True, the existence check is skipped and the joined path
            is returned even when the file is missing. If False, a missing file raises
            an error. Default: False

    Returns:
        Path: Path to the file

    Raises:
        ValueError: If ``input_dir`` is not a directory, or if the file is missing
            and missing_ok is False.
    """
    if not input_dir.is_dir():
        raise ValueError(f"{input_dir} is not a valid directory.")

    file_path = input_dir / filename

    # The file only has to exist when the caller did not opt out via missing_ok
    if not missing_ok and not file_path.is_file():
        raise ValueError(f"No file with name: {filename} found in {input_dir}.")
    return file_path


def safe_torch_load(
    checkpoint_path: Path, weights_only: bool = True, mmap: bool = True
) -> Dict[str, Any]:
    """
    Utility to load a checkpoint file onto CPU in a safe manner. Provides separate handling for
    safetensors files.

    Files whose name ends in ``.safetensors`` are read tensor-by-tensor through ``safe_open``;
    ``weights_only`` and ``mmap`` are not used on that path. Every other file is passed to
    ``torch.load``.

    Args:
        checkpoint_path (Path): Path to the checkpoint file.
        weights_only (bool): Whether to load only tensors, primitive types, and dictionaries
            (passthrough to torch.load). Default: True
        mmap (bool): Whether to mmap from disk into CPU memory (passthrough to torch.load).
            Default: True

    Returns:
        Dict[str, Any]: State dict from the checkpoint file.

    Raises:
        ValueError: If the checkpoint file is not found or cannot be loaded. The underlying
            exception is chained as the cause.
    """
    try:
        # convert the path into a string since pathlib Path and mmap don't work
        # well together
        path_str = str(checkpoint_path)
        if path_str.endswith(".safetensors"):
            with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
                state_dict = {key: f.get_tensor(key) for key in f.keys()}
        else:
            state_dict = torch.load(
                path_str,
                map_location="cpu",
                mmap=mmap,
                weights_only=weights_only,
            )
    except Exception as e:
        # Any failure (missing file, corrupt data, unsupported format) is surfaced as a
        # single ValueError so callers have one exception type to handle.
        raise ValueError(f"Unable to load checkpoint from {checkpoint_path}. ") from e
    return state_dict


def save_config(path: Path, config: Dict[str, Any]) -> None:
    """
    Save a configuration dictionary to ``config.json`` inside the given directory.

    The directory is created if it does not exist, including any missing parent
    directories. If ``config.json`` already exists there, it is left untouched and
    nothing is written.

    Args:
        path (Path): Directory in which to save the configuration file.
        config (Dict[str, Any]): Configuration dictionary to save.
    """
    # parents=True so a nested, not-yet-created output directory does not fail;
    # exist_ok=True makes this a no-op when the directory is already there.
    path.mkdir(parents=True, exist_ok=True)

    file_path = path / "config.json"
    # Never overwrite a config that is already present
    if not file_path.exists():
        with open(file_path, "w") as f:
            json.dump(config, f)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources