Checkpoints#
Checkpoints save the complete training state so that you can resume training after an interruption. A checkpoint typically includes:
Model parameters
Optimizer state (momentum buffers, learning rates)
Current epoch number
Best validation loss/accuracy
Creating Checkpoints#
void save_checkpoint(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    int epoch,
    const std::string& path) {
  torch::serialize::OutputArchive archive;
  // Write the model's parameters and buffers into the archive.
  model->save(archive);
  // Store the epoch counter as a tensor so it serializes alongside the weights.
  archive.write("epoch", torch::tensor(epoch));
  // Append the optimizer state (step counts and moment estimates for Adam).
  optimizer.save(archive);
  archive.save_to(path);
}
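The checklist at the top also mentions the best validation loss, which the function above does not record. A minimal sketch of one way to extend it, assuming the training loop tracks a best_val_loss value (that name and the extra parameter are illustrative, not part of any fixed checkpoint format):

// Sketch: also record the best validation loss seen so far.
// "best_val_loss" is an illustrative key name, not a required one.
void save_checkpoint_with_metrics(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    int epoch,
    double best_val_loss,  // assumed to be tracked by the training loop
    const std::string& path) {
  torch::serialize::OutputArchive archive;
  model->save(archive);
  archive.write("epoch", torch::tensor(epoch));
  archive.write("best_val_loss", torch::tensor(best_val_loss));
  optimizer.save(archive);
  archive.save_to(path);
}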
Loading Checkpoints#
int load_checkpoint(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    const std::string& path) {
  torch::serialize::InputArchive archive;
  archive.load_from(path);
  // Restore the model's parameters and buffers in place.
  model->load(archive);
  // Read back the epoch counter that was written as a tensor.
  torch::Tensor epoch_tensor;
  archive.read("epoch", epoch_tensor);
  // Restore the optimizer state so training resumes with the same dynamics.
  optimizer.load(archive);
  return epoch_tensor.item<int>();
}
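Loading can fail if the file is unreadable or was written without the epoch entry. A sketch of a more defensive loader, using InputArchive::try_read so a missing key falls back to epoch 0 instead of throwing (the fallback value, the function name, and the use of std::cerr with <iostream> are assumptions for this example):

// Sketch: tolerate unreadable files and missing keys instead of crashing.
int load_checkpoint_safe(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    const std::string& path) {
  torch::serialize::InputArchive archive;
  try {
    archive.load_from(path);
  } catch (const std::exception& e) {
    std::cerr << "Could not read checkpoint: " << e.what() << std::endl;
    return 0;  // start training from scratch
  }
  model->load(archive);
  optimizer.load(archive);
  torch::Tensor epoch_tensor;
  // try_read returns false when the key is absent rather than throwing.
  if (!archive.try_read("epoch", epoch_tensor)) {
    return 0;
  }
  return epoch_tensor.item<int>();
}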
Complete Checkpoint Example#
#include <torch/torch.h>
#include <iostream>
#include <filesystem>

// save_checkpoint() and load_checkpoint() are the functions defined above.

struct Net : torch::nn::Module {
  Net() {
    fc1 = register_module("fc1", torch::nn::Linear(784, 256));
    fc2 = register_module("fc2", torch::nn::Linear(256, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x.view({-1, 784})));
    return fc2->forward(x);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};

int main() {
  auto model = std::make_shared<Net>();
  auto optimizer = torch::optim::Adam(model->parameters(), 1e-3);

  int start_epoch = 0;
  const std::string checkpoint_path = "checkpoint.pt";

  // Resume from the checkpoint if one exists.
  if (std::filesystem::exists(checkpoint_path)) {
    std::cout << "Loading checkpoint..." << std::endl;
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path);
    std::cout << "Resuming from epoch " << start_epoch << std::endl;
  }

  // Training loop
  for (int epoch = start_epoch; epoch < 100; ++epoch) {
    // ... training code ...

    // Save a checkpoint every 10 epochs.
    if ((epoch + 1) % 10 == 0) {
      save_checkpoint(model, optimizer, epoch + 1, checkpoint_path);
      std::cout << "Saved checkpoint at epoch " << epoch + 1 << std::endl;
    }
  }

  return 0;
}
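An alternative to packing everything into one archive is to write the model and optimizer to separate files with torch::save and torch::load. A minimal sketch of that approach (the file names are arbitrary, and the epoch counter would still need to be stored separately, e.g. as a tensor):

// Sketch: per-component files instead of a single combined archive.
void save_checkpoint_split(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer) {
  torch::save(model, "model.pt");          // module parameters and buffers
  torch::save(optimizer, "optimizer.pt");  // optimizer state
}

void load_checkpoint_split(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer) {
  torch::load(model, "model.pt");
  torch::load(optimizer, "optimizer.pt");
}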
Best Practices#
Save periodically: Save checkpoints at regular intervals (e.g., every epoch or every N batches) to minimize lost work.
Keep multiple checkpoints: Maintain the last few checkpoints in case the most recent one is corrupted or represents a poor model state (a rotation sketch follows this list).
Include all state: Save everything needed to resume training, including learning rate scheduler state if using one.
Verify checkpoints: Occasionally verify that checkpoints can be loaded correctly.
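For the "keep multiple checkpoints" practice, one simple scheme is to put the epoch in the file name and prune older files by modification time once a fixed count is exceeded. A sketch building on save_checkpoint and the Net type defined above (the checkpoint_<epoch>.pt naming pattern and the keep-count of 3 are arbitrary choices):

#include <torch/torch.h>
#include <algorithm>
#include <filesystem>
#include <string>
#include <vector>

// Sketch: write checkpoint_<epoch>.pt and keep only the newest `keep` files.
void save_and_rotate(std::shared_ptr<Net> model,
                     torch::optim::Adam& optimizer,
                     int epoch,
                     std::size_t keep = 3) {
  const std::string path = "checkpoint_" + std::to_string(epoch) + ".pt";
  save_checkpoint(model, optimizer, epoch, path);

  // Collect existing checkpoint files in the working directory.
  std::vector<std::filesystem::path> files;
  for (const auto& entry : std::filesystem::directory_iterator(".")) {
    const auto name = entry.path().filename().string();
    if (name.rfind("checkpoint_", 0) == 0 && entry.path().extension() == ".pt") {
      files.push_back(entry.path());
    }
  }

  // Sort oldest first by modification time, then delete until `keep` remain.
  std::sort(files.begin(), files.end(), [](const auto& a, const auto& b) {
    return std::filesystem::last_write_time(a) < std::filesystem::last_write_time(b);
  });
  while (files.size() > keep) {
    std::filesystem::remove(files.front());
    files.erase(files.begin());
  }
}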