Note
Go to the end to download the full example code.
Distributed training at scale with PyTorch and Ray Train#
Author: Ricardo Decal
This tutorial shows how to distribute PyTorch training across multiple GPUs using Ray Train and Ray Data for scalable, production-ready model training. In this tutorial, you:
Pre-train a GPT-2 (~124M-parameter) language model using PyTorch and Hugging Face Transformers.
Distribute training across multiple GPUs with Ray Train with minimal code changes.
Stream training data from Hugging Face datasets with Ray Data’s distributed workers.
Save and load distributed checkpoints.
Scale from a single node to a multinode cluster with minimal code changes.
Optimize cost and performance with heterogeneous clusters.
Monitor training with the Ray dashboard.
This tutorial requires:
PyTorch v2.9+.
Ray Train (ray[train]) v2.52.1+.
tiktoken, datasets, and transformers (Hugging Face).
One or more GPUs are recommended but not required. This tutorial is tested on a g4dn.12xlarge instance, which has 4 NVIDIA T4 GPUs (16GB of memory per GPU).
Ray Train is a scalable framework for distributed deep learning. Ray Train builds on top of Ray, a unified framework for scaling AI and Python applications that simplifies the complexities of distributed computing. Ray is also open source and part of the PyTorch Foundation.
Ray Train enables you to scale from a single GPU to hundreds of GPUs without rewriting your training loop. Combined with Ray Data for streaming data ingestion, you get an end-to-end distributed training pipeline that handles data loading, sharding, gradient synchronization, checkpointing, and fault tolerance.
Setup#
To install the dependencies, run pip install "ray[train]" torch tiktoken datasets transformers.
Then, import the required libraries:
import os
import tempfile
import time
import numpy as np
import ray
import ray.train
import tiktoken
import torch
from datasets import load_dataset
from ray.train import CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from transformers import GPT2Config, GPT2LMHeadModel
# Enable smoke test to run this tutorial quickly.
SMOKE_TEST = True
# Reduce Ray Data verbosity
ray.data.DataContext.get_current().enable_progress_bars = False
ray.data.DataContext.get_current().print_on_execution_start = False
Load the dataset with Ray Data#
This tutorial uses the Wikitext-103 dataset, a collection of over 100 million tokens from verified Good and Featured articles on Wikipedia.
The ray.data.from_huggingface() function converts a Hugging Face
dataset into a Ray Dataset, enabling distributed streaming and
preprocessing across all available nodes.
hf_ds = load_dataset("Salesforce/wikitext", "wikitext-103-raw-v1")
train_ds = ray.data.from_huggingface(hf_ds["train"])
val_ds = ray.data.from_huggingface(hf_ds["validation"])
# Limit dataset size for fast iteration during smoke tests.
if SMOKE_TEST:
train_ds = train_ds.limit(2500)
val_ds = val_ds.limit(2500)
print(f"Dataset schema:\n{train_ds.schema()}")
Dataset schema:
Column Type
------ ----
text string
This output shows that the dataset has a single column, text, containing strings.
Inspect raw data#
Use take(n) to fetch a small number of rows for inspection.
Each row is a dictionary with the column names as keys.
print("--- Raw data sample ---")
sample = train_ds.take(2)
for i, row in enumerate(sample):
text_preview = (row["text"][:120] + "...") if len(row["text"]) > 120 else row["text"]
print(f" Row {i}: {text_preview!r}")
--- Raw data sample ---
Row 0: ''
Row 1: ' = Valkyria Chronicles III = \n'
Each row in Wikitext-103 is a single line from a Wikipedia article.
Consecutive rows belong to the same article, with empty rows separating
paragraphs. New articles begin with a title line like
= Article Title =. The tokenization step below inserts an
<|endoftext|> separator token before each title line so the model
learns to reset context at article boundaries.
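The boundary rule can be checked on a few sample lines. The helper below mirrors the title-detection heuristic implemented in the next section: a line counts as an article title when it looks like `= Title =`, while subsection headings such as `= = Gameplay = =` are excluded.

```python
def is_article_title(text: str) -> bool:
    """Heuristic: top-level Wikitext titles look like ' = Some Title = '."""
    stripped = text.strip()
    # Subsection headings start with '= =', so exclude them.
    return stripped.startswith("= ") and stripped.endswith(" =") and not stripped.startswith("= =")

print(is_article_title(" = Valkyria Chronicles III = \n"))  # True: new article
print(is_article_title(" = = Gameplay = = \n"))             # False: subsection heading
print(is_article_title("Plain paragraph text"))             # False: body text
```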
Tokenize and chunk the data#
Language models consume fixed-length sequences of token IDs. The preprocessing step converts raw text into token ID sequences for next-token prediction.
This tutorial uses tiktoken with the GPT-2 encoding (vocabulary size
50,257). tiktoken is a fast, standalone tokenizer that has no
dependency on the Hugging Face transformers library.
The tokenize_and_chunk function does the following:
Tokenizes each batch of text, concatenating the results into a single token stream. Article title lines (for example, = Article Title =) trigger an <|endoftext|> separator so the model resets context at article boundaries.
Splits the stream into fixed-length blocks of block_size tokens.
Returns input_ids for each block. During training, the same tensor serves as both input and label because GPT2LMHeadModel shifts the labels internally when computing the cross-entropy loss.
BLOCK_SIZE = 256
VOCAB_SIZE = 50257
encoding = tiktoken.get_encoding("gpt2")
EOT_TOKEN = encoding.eot_token # <|endoftext|> token ID (50256)
def _is_article_title(text: str) -> bool:
"""Detect Wikitext article title lines like ' = Some Title = '."""
stripped = text.strip()
return stripped.startswith("= ") and stripped.endswith(" =") and not stripped.startswith("= =")
def tokenize_and_chunk(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
"""Tokenize text and split into fixed-length chunks for language modeling."""
# Reconstruct the original text stream by joining rows with newlines.
# Article title lines signal new articles, so we insert an
# <|endoftext|> separator before them.
all_tokens: list[int] = []
for text in batch["text"]:
if _is_article_title(text):
all_tokens.append(EOT_TOKEN)
all_tokens.extend(encoding.encode_ordinary(text + "\n"))
# Split into fixed-length chunks of block_size tokens.
num_chunks = len(all_tokens) // BLOCK_SIZE
all_tokens = all_tokens[: num_chunks * BLOCK_SIZE]
if num_chunks == 0:
return {"input_ids": []}
tokens_array = np.array(all_tokens, dtype=np.int64).reshape(num_chunks, BLOCK_SIZE)
return {"input_ids": tokens_array}
Apply the tokenization with map_batches(). This operation is lazy,
meaning that Ray Data defers execution until a downstream consumer requests the
results. Lazy execution lets Ray optimize the entire pipeline before any
work begins.
# These do not trigger execution.
train_ds = train_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
val_ds = val_ds.map_batches(tokenize_and_chunk, batch_format="numpy")
Inspect the tokenized output with take(2):
print("--- After tokenization ---")
tokenized_sample = train_ds.take(2)
for i, row in enumerate(tokenized_sample):
ids = row["input_ids"]
print(f" Row {i}: input_ids shape={ids.shape}, first 10 tokens={ids[:10].tolist()}")
print(f" Decoded: {encoding.decode(ids[:30].tolist())!r}...")
--- After tokenization ---
Each row now contains a fixed-length input_ids array of 256 tokens.
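The chunking arithmetic can be verified in isolation: the token stream is truncated to a multiple of the block size, then reshaped into rows of block_size tokens. The sketch below uses a small block size for readability.

```python
import numpy as np

block_size = 4  # small stand-in for the tutorial's 256
all_tokens = list(range(10))  # a 10-token stream

num_chunks = len(all_tokens) // block_size          # 2 full blocks
all_tokens = all_tokens[: num_chunks * block_size]  # drop the 2-token remainder
chunks = np.array(all_tokens, dtype=np.int64).reshape(num_chunks, block_size)

print(chunks.shape)     # (2, 4)
print(chunks.tolist())  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```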
Streaming execution#
Internally, Ray divides the data into blocks and dispatches them to
workers. This block-based architecture enables streaming execution: as
soon as a stage outputs a block, the next stage can begin processing it
immediately without waiting for previous stages to finish the entire
dataset. This means the map_batches tokenization above runs in a
streaming pipeline with the training loop, so the full dataset never needs
to fit in memory at once.
When training starts, Ray Data logs the execution plan. For this tutorial, one possible plan is:
Execution plan: InputDataBuffer[Input]
-> TaskPoolMapOperator[MapBatches(tokenize_and_chunk)]
-> OutputSplitter[split(4, equal=True)]
This tells you exactly how Ray Data will stream through tokenization and split the data across 4 trainer workers.
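This streaming behavior can be illustrated, very loosely, with plain Python generators: each stage yields a block as soon as it's ready, so downstream stages start before upstream stages finish. This is a conceptual sketch only; real Ray Data blocks are distributed objects processed by parallel workers.

```python
def read_blocks(n_blocks):
    """Stage 1: produce blocks of raw rows."""
    for i in range(n_blocks):
        yield [f"row-{i}-{j}" for j in range(3)]

def tokenize(blocks):
    """Stage 2: transform each block as it arrives, without waiting for all blocks."""
    for block in blocks:
        yield [len(row) for row in block]  # stand-in for real tokenization

processed = []
for block in tokenize(read_blocks(3)):
    processed.append(block)  # the "training" stage consumes blocks one at a time

print(processed)
```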
Define the transformer model#
The model is a decoder-only transformer language model using Hugging Face’s
GPT2LMHeadModel. The hyperparameters below are for the standard GPT-2 “small” architecture.
def create_model():
"""Create a GPT-2 small model with random weights."""
model = GPT2LMHeadModel(GPT2Config(
vocab_size=VOCAB_SIZE,
n_positions=BLOCK_SIZE,
n_embd=768,
n_layer=12,
n_head=12,
))
model.loss_type = "ForCausalLM"
return model
Verify the model size:
model = create_model()
num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params / 1e6:.1f}M")
del model # Free memory before training
Model parameters: 123.8M
The output confirms approximately 123.8M parameters, matching the GPT-2 small architecture.
Define the distributed training function#
The training function runs on each worker process. Ray Train
manages the distributed setup: it wraps the model in
DistributedDataParallel, shards the data across workers, and
synchronizes gradients automatically.
The key Ray Train integration points are:
ray.train.get_dataset_shard("train") retrieves the worker's portion of the dataset; Ray Data automatically splits the dataset across all workers.
ray.train.torch.prepare_model(model) wraps the model in DistributedDataParallel and moves it to the correct GPU.
shard.iter_torch_batches(batch_size=...) returns an iterator of dict[str, torch.Tensor] batches, with tensors automatically placed on the worker's GPU. Setting prefetch_batches=2 fetches two batches ahead of the current batch.
ray.train.report(metrics, checkpoint=...) reports metrics to the driver and saves a checkpoint.
def train_func_per_worker(config: dict):
"""Training function executed by each distributed worker."""
lr = config["lr"]
weight_decay = config["weight_decay"]
max_grad_norm = config["max_grad_norm"]
epochs = config["epochs"]
batch_size = config["batch_size_per_worker"]
max_steps_per_epoch = config.get("max_steps_per_epoch")
# --- Data -----------------------------------------------------------
# Each worker gets an automatic shard of the dataset.
train_data_shard = ray.train.get_dataset_shard("train")
val_data_shard = ray.train.get_dataset_shard("validation")
# --- Model ----------------------------------------------------------
model = create_model()
# prepare_model wraps the model in DistributedDataParallel and places
# it on the correct device.
model = ray.train.torch.prepare_model(model)
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
# --- Training loop --------------------------------------------------
for epoch in range(epochs):
model.train()
train_loss_sum = 0.0
train_batches = 0
train_tokens = 0
epoch_start = time.perf_counter()
# iter_torch_batches returns dicts of tensors already on the GPU.
for batch in train_data_shard.iter_torch_batches(
batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
):
input_ids = batch["input_ids"]
# GPT2LMHeadModel shifts labels internally to align each
# position with the next token, so we can use input_ids as
# both the input and the labels.
out = model(input_ids=input_ids, labels=input_ids)
loss = out.loss
optimizer.zero_grad()
loss.backward()
# Gradient clipping for training stability
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
optimizer.step()
train_loss_sum += loss.item()
train_batches += 1
train_tokens += input_ids.numel()
if max_steps_per_epoch and train_batches >= max_steps_per_epoch:
break
train_elapsed = time.perf_counter() - epoch_start
avg_train_loss = train_loss_sum / max(train_batches, 1)
# --- Validation -----------------------------------------------------
model.eval()
val_loss_sum = 0.0
val_batches = 0
with torch.no_grad():
for batch in val_data_shard.iter_torch_batches(
batch_size=batch_size, dtypes=torch.long, prefetch_batches=2
):
input_ids = batch["input_ids"]
out = model(input_ids=input_ids, labels=input_ids)
loss = out.loss
val_loss_sum += loss.item()
val_batches += 1
if max_steps_per_epoch and val_batches >= max_steps_per_epoch:
break
avg_val_loss = val_loss_sum / max(val_batches, 1)
epoch_elapsed = time.perf_counter() - epoch_start
# --- Report metrics and save checkpoint ------------------------------
metrics = {
"train_loss": round(avg_train_loss, 4),
"val_loss": round(avg_val_loss, 4),
"epoch": epoch,
"epoch_time_sec": round(epoch_elapsed, 2),
"epoch_tokens": train_tokens,
"tokens_per_sec": round(train_tokens / max(train_elapsed, 1e-6), 2),
}
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
torch.save(
{
"epoch": epoch,
"model_state_dict": model.module.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
},
os.path.join(temp_checkpoint_dir, "checkpoint.pt"),
)
checkpoint = ray.train.Checkpoint.from_directory(temp_checkpoint_dir)
ray.train.report(metrics=metrics, checkpoint=checkpoint)
Configure and launch distributed training#
The TorchTrainer brings everything together. Running trainer.fit() finally
triggers the execution of the full data pipeline and training loop. The Trainer accepts:
train_func_per_worker: the function each worker executes.
train_loop_config: a dictionary of hyperparameters forwarded to the training function.
datasets: a dictionary of Ray Datasets. Ray Train automatically splits each dataset across workers.
scaling_config: specifies the number of workers and whether to use GPUs.
Setting num_workers=4 launches 4 parallel workers, one per GPU. Ray
Train handles torch.distributed initialization, NCCL backend setup,
and DistributedDataParallel wrapping behind the scenes. In the logs,
you see each worker assigned a rank and device:
Started training worker group of size 4:
* (ip=10.0.176.183, pid=25636) world_rank=0, local_rank=0, node_rank=0
* (ip=10.0.176.183, pid=25637) world_rank=1, local_rank=1, node_rank=0
...
Moving model to device: cuda:0
Wrapping provided model in DistributedDataParallel.
batch_size_per_worker is the number of sequences each worker
processes per gradient step. With 4 workers and a per-worker batch size
of 16, the effective global batch size is 4 × 16 = 64 sequences,
or 64 × 256 = 16,384 tokens per optimizer step.
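Effective batch sizes are worth double-checking whenever you change worker counts or sequence lengths; for this configuration the arithmetic is:

```python
num_workers = 4
batch_size_per_worker = 16
block_size = 256

global_batch = num_workers * batch_size_per_worker  # sequences per optimizer step
tokens_per_step = global_batch * block_size         # tokens per optimizer step

print(global_batch)     # 64
print(tokens_per_step)  # 16384
```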
USE_GPU = torch.cuda.is_available()
NUM_WORKERS = max(torch.cuda.device_count(), 1) # One worker per available GPU
NUM_EPOCHS = 5
BATCH_SIZE_PER_WORKER = 16
LR = 3e-4
WEIGHT_DECAY = 0.1
MAX_GRAD_NORM = 1.0
trainer = TorchTrainer(
train_loop_per_worker=train_func_per_worker,
train_loop_config={
"lr": LR,
"weight_decay": WEIGHT_DECAY,
"max_grad_norm": MAX_GRAD_NORM,
"epochs": NUM_EPOCHS,
"batch_size_per_worker": BATCH_SIZE_PER_WORKER,
"max_steps_per_epoch": 5 if SMOKE_TEST else None,
},
# Register the datasets.
datasets={"train": train_ds, "validation": val_ds},
scaling_config=ScalingConfig(
num_workers=NUM_WORKERS,
use_gpu=USE_GPU,
),
run_config=RunConfig(
name="gpt2-small-pretraining",
storage_path="/tmp/ray-train-checkpoints",
),
)
result = trainer.fit()
(TrainController pid=8095) Attempting to start training worker group of size 4 with the following resources: [{'GPU': 1}] * 4
(TrainController pid=8095) Started training worker group of size 4:
(TrainController pid=8095) - (ip=172.17.0.2, pid=8247) world_rank=0, local_rank=0, node_rank=0
(TrainController pid=8095) - (ip=172.17.0.2, pid=8249) world_rank=1, local_rank=1, node_rank=0
(TrainController pid=8095) - (ip=172.17.0.2, pid=8250) world_rank=2, local_rank=2, node_rank=0
(TrainController pid=8095) - (ip=172.17.0.2, pid=8248) world_rank=3, local_rank=3, node_rank=0
(RayTrainWorker pid=8247) Moving model to device: cuda:0
(RayTrainWorker pid=8247) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=8247) Reporting training result 1: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-02-25_18-58-49.963961), metrics={'train_loss': 9.886, 'val_loss': 8.9904, 'epoch': 0, 'epoch_time_sec': 4.04, 'epoch_tokens': 20480, 'tokens_per_sec': 6076.47}, validation_spec=None)
...
(RayTrainWorker pid=8247) Reporting training result 5: TrainingReport(checkpoint=Checkpoint(filesystem=local, path=/tmp/ray-train-checkpoints/gpt2-small-pretraining/checkpoint_2026-02-25_18-59-18.082391), metrics={'train_loss': 7.3434, 'val_loss': 7.7539, 'epoch': 4, 'epoch_time_sec': 3.11, 'epoch_tokens': 20480, 'tokens_per_sec': 8212.82}, validation_spec=None)
Inspect results#
After training, the Result object contains the final metrics and
checkpoint, both taken from the last ray.train.report() call.
print("\nTraining finished!")
Training finished!
For example, result.metrics looks like:
{'train_loss': 7.0646, 'val_loss': 7.6051, 'epoch': 4,
'epoch_time_sec': 12.34, 'epoch_tokens': 20480, 'tokens_per_sec': 1759.8}
The per-worker logs show training loss, validation loss, and throughput metrics for each epoch. With randomly initialized weights, expect the loss to start near ln(50257) ≈ 10.8 and drop to roughly 7-8 after only a few steps.
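These values have a simple sanity check: a randomly initialized model assigns roughly uniform probability over the 50,257-token vocabulary, so its expected cross-entropy is ln(50257). Exponentiating a loss converts it to perplexity, a more interpretable number:

```python
import math

VOCAB_SIZE = 50257

# Cross-entropy of a uniform distribution over the vocabulary.
initial_loss = math.log(VOCAB_SIZE)
print(round(initial_loss, 2))  # 10.82

# Perplexity after a few training steps (val_loss ~7.75 in the smoke test above).
val_loss = 7.75
print(round(math.exp(val_loss), 1))
```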
Checkpointing#
In production training, enable checkpointing to make your training jobs robust to unexpected failures. Checkpointing also lets you take advantage of Ray Train's fault tolerance mechanisms described in the Fault tolerance section.
Ray Train offers several checkpointing optimizations. Asynchronous uploading enables you to continue training while checkpoints stream to remote storage in the background. Distributed checkpointing uploads shards from each worker in parallel, avoiding a gather step into a single worker’s memory that risks OOM errors for large models.
For a full guide on checkpointing with Ray Train, see the Ray Train checkpointing guide.
Scaling to a multi-node cluster#
The code above runs on a single 4-GPU machine. Scaling to a multi-node cluster requires only two changes:
Increase num_workers to match the total number of GPUs in the cluster.
Set a shared storage path so that all nodes can access checkpoints.
For example, to train on a cluster of 4 nodes with 4 GPUs each (16 GPUs total):
trainer = TorchTrainer(
train_loop_per_worker=train_func_per_worker,
train_loop_config={...},
datasets={"train": train_ds, "validation": val_ds},
scaling_config=ScalingConfig(
num_workers=16, # 4 nodes x 4 GPUs
use_gpu=True,
),
run_config=RunConfig(
# Shared storage accessible from all nodes
storage_path="s3://my-bucket/ray-checkpoints",
checkpoint_config=CheckpointConfig(num_to_keep=2),
),
)
Ray Train automatically:
Launches workers across all available nodes, bringing up new nodes if needed in an autoscaling Ray cluster.
Shards data across all workers.
No changes to the training function are needed. The same
train_func_per_worker runs identically whether on 1 GPU or 256 GPUs.
This tutorial uses DistributedDataParallel (DDP), which replicates
the full model on every GPU. For larger models that don’t fit on a
single GPU, you can switch to
FullyShardedDataParallel
(FSDP) to shard parameters, gradients, and optimizer states across
workers by setting prepare_model(parallel_strategy="fsdp").
Heterogeneous clusters: separate data and training resources#
Because Ray Data and Ray Train are separate systems, they don’t have to share the same machines. By default, Ray Data preprocessing tasks and training workers run on the same nodes, but if you add CPU-only nodes to your cluster, Ray Data automatically schedules preprocessing tasks on them, keeping your expensive GPU nodes free for training.
This is useful when data preprocessing is a bottleneck: if GPU utilization is low because workers are waiting on data, add cheaper CPU-only nodes to the cluster and Ray Data scales preprocessing out to them.
For more information, see Configuring data ingest.
Fault tolerance#
Long-running distributed training jobs are vulnerable to failures: hardware faults, network outages, or instance preemption. Without fault tolerance, any of these events forces you to restart training from scratch, wasting time and compute.
Ray Train has features that handle these failures automatically. When a worker process crashes, Ray Train restarts it in place and resumes training. If an entire node goes down, Ray Train provisions a replacement and recovers from the most recent checkpoint so that only a small amount of work is lost. This makes it practical to interrupt training jobs and resume them later.
To enable automatic failure recovery, configure FailureConfig in
your RunConfig. The max_failures parameter controls how many
consecutive failures Ray Train tolerates before giving up:
from ray.train import FailureConfig
run_config = RunConfig(
storage_path="s3://my-bucket/ray-checkpoints",
failure_config=FailureConfig(max_failures=3),
checkpoint_config=CheckpointConfig(num_to_keep=2),
)
For more details, see the Ray Train fault tolerance guide.
Monitor your training jobs#
Monitoring is critical when running distributed training. The Ray dashboard displays real-time metrics including:
Training loss and validation metrics per epoch
GPU utilization and memory usage per worker
Data loading throughput
Worker status and error logs
To view the dashboard, open the link printed in the logs after Ray
initializes. Typically, this link is http://localhost:8265.
The dashboard lets you:
Monitor training progress across all workers
Inspect logs from individual workers
Identify data loading or communication bottlenecks
View resource use for CPU, GPU, and memory per worker
Debug failures with detailed error messages and stack traces
For more information, see the Ray Train monitoring documentation.
Conclusion#
In this tutorial, you:
Pre-trained a GPT-2 (~124M-parameter) language model using Hugging Face Transformers and PyTorch.
Loaded and preprocessed the Wikitext-103 dataset using Ray Data with distributed streaming.
Ran distributed training across 4 GPUs using Ray Train’s TorchTrainer with only minimal changes to a standard PyTorch training loop.
Learned how to save and load distributed checkpoints for model recovery.
Learned how to scale to multi-node clusters by changing ScalingConfig and RunConfig.
Learned how heterogeneous clusters let you run data preprocessing on CPU nodes and training on GPU nodes for cost and performance optimization.
Learned about Ray Train’s fault tolerance mechanisms for production training jobs.
Monitored training with the Ray dashboard.
Further reading#
Total running time of the script: (1 minutes 14.010 seconds)