torchtune.utils¶
Checkpointing¶
Checkpointer which reads and writes "full-model" checkpoints in HF's format.
Checkpointer which reads and writes "full-model" checkpoints in Meta's format.
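A minimal load/save round trip, assuming the HF checkpointer is exposed as torchtune.utils.FullModelHFCheckpointer, that load_checkpoint() returns a state dict keyed by "model", and that the constructor takes the arguments shown (paths, file names, and signatures here are illustrative and may differ across torchtune versions):

    from torchtune import utils

    checkpointer = utils.FullModelHFCheckpointer(
        checkpoint_dir="/tmp/llama2-7b-hf",  # directory holding HF-format weights
        checkpoint_files=[
            "pytorch_model-00001-of-00002.bin",
            "pytorch_model-00002-of-00002.bin",
        ],
        model_type="LLAMA2",
        output_dir="/tmp/finetune-output",
    )

    # Read the full-model checkpoint from HF format into a state dict ...
    state_dict = checkpointer.load_checkpoint()
    model.load_state_dict(state_dict["model"])  # `model` built elsewhere in the recipe

    # ... and write it back out in the same format after training.
    checkpointer.save_checkpoint(state_dict, epoch=0)

The Meta-format checkpointer follows the same load/save pattern against Meta's original weight layout.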
Distributed¶
Initialize torch.distributed.
Function that gets the current world size (i.e., the total number of ranks) and the rank of the current trainer.
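A sketch of typical use at the start of a distributed recipe, assuming init_distributed forwards its keyword arguments to torch.distributed.init_process_group and that get_world_size_and_rank returns a (world_size, rank) tuple:

    from torchtune import utils

    # Initialize torch.distributed; under a launcher such as torchrun this
    # joins the process group (the backend choice here is an assumption).
    utils.init_distributed(backend="nccl")

    # world_size is the total number of ranks; rank identifies this trainer.
    world_size, rank = utils.get_world_size_and_rank()
    if rank == 0:
        print(f"Training on {world_size} ranks")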
Mixed Precision¶
Intelligently determines, based on the dtype, whether mixed-precision training is supported, and returns the builtin torch.autocast if applicable.
Returns a gradient scaler for mixed-precision training.
Get the torch.dtype corresponding to the given precision string.
Return a list of dtypes supported for finetuning.
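A sketch of one training step wired through these helpers, assuming they are exposed as utils.get_dtype, utils.get_autocast, and utils.get_gradient_scaler (argument names and order are assumptions; the list-of-dtypes helper enumerates the strings get_dtype accepts):

    from torchtune import utils

    # Map a precision string to a torch.dtype.
    dtype = utils.get_dtype("fp16")

    # Autocast context for the forward pass; a null context may be returned
    # when the dtype needs no autocasting.
    autocast = utils.get_autocast(dtype, utils.get_device("cuda"))

    # Gradient scaler; loss scaling only matters for fp16 training.
    scaler = utils.get_gradient_scaler(fsdp=False)

    with autocast:
        loss = model(batch)  # `model` and `batch` come from the training recipe
    scaler.scale(loss).backward()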
Memory Management¶
Utility to set up activation checkpointing and wrap the model for checkpointing.
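A sketch of wrapping a model's transformer layers for activation checkpointing, assuming the utility is exposed as utils.set_activation_checkpointing with an auto_wrap_policy argument (the layer class named here is an assumption):

    from torchtune import modules, utils

    # Recompute each TransformerDecoderLayer's activations during the backward
    # pass instead of storing them, trading compute for memory.
    utils.set_activation_checkpointing(
        model, auto_wrap_policy={modules.TransformerDecoderLayer}
    )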
Metric Logging¶
Logger for use with Weights & Biases (https://wandb.ai/).
Logger for use with PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).
Logger to standard output.
Logger to disk.
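All four loggers share one interface in this sketch, which assumes log(name, data, step) and log_dict(payload, step) methods and the constructor arguments shown:

    from torchtune.utils.metric_logging import DiskLogger, StdoutLogger

    logger = StdoutLogger()  # or DiskLogger(log_dir="/tmp/logs"),
                             # WandBLogger(...), TensorBoardLogger(...)

    for step in range(3):
        # Log scalars one at a time ...
        logger.log("loss", 1.0 / (step + 1), step=step)
        # ... or several at once, keyed by metric name.
        logger.log_dict({"lr": 3e-4, "tokens_per_sec": 1200.0}, step=step)

    logger.close()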
Data¶
Pad a batch of sequences to the longest sequence length in the batch, and convert integer lists to tensors.
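A sketch of the collate utility on two unequal-length samples, assuming it is exposed as utils.padded_collate and that samples are dicts of token/label id lists (the "tokens"/"labels" key names and the padding_idx/ignore_idx keywords are assumptions):

    from torchtune.utils import padded_collate

    # Both sequences are padded to length 3 and converted to tensors; labels
    # are padded with the ignore index so padding does not contribute to loss.
    batch = [
        {"tokens": [1, 2, 3], "labels": [1, 2, 3]},
        {"tokens": [4, 5], "labels": [4, 5]},
    ]
    padded = padded_collate(batch, padding_idx=0, ignore_idx=-100)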
Miscellaneous¶
TuneRecipeArgParser is a utility subclass of argparse's ArgumentParser that adds a builtin "config" argument.
Get a logger with a stream handler.
Function that takes a device or device string, verifies that it is correct and available given the machine and distributed settings, and returns a torch.device.
Function that sets the seed for pseudo-random number generators across commonly used libraries.
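A sketch combining the seed, device, and logging helpers, assuming they are exposed as utils.set_seed, utils.get_device, and utils.get_logger (the level string accepted by get_logger is an assumption):

    from torchtune import utils

    # Seed the pseudo-random number generators used across common libraries.
    utils.set_seed(42)

    # Validate the device string against this machine and any distributed
    # settings, and get back a torch.device.
    device = utils.get_device("cuda")

    # Stream-handler logger for recipe output.
    log = utils.get_logger("INFO")
    log.info(f"Seeded run, training on {device}")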