---
myst:
  html_meta:
    description: Checkpointing in PyTorch C++ — saving and resuming training state.
    keywords: PyTorch, C++, checkpoint, save, resume, training state
---

# Checkpoints

Checkpoints save the complete training state so that you can resume training after an interruption. A checkpoint typically includes:

- Model parameters
- Optimizer state (momentum buffers, learning rates)
- Current epoch number
- Best validation loss/accuracy

## Creating Checkpoints

The function below writes the model parameters, the epoch counter, and the optimizer state into a single archive. Because all three share one archive, their keys must not collide (for example, the model should not contain a submodule or parameter named `epoch`). `Net` is the module type defined in the complete example further down.

```cpp
void save_checkpoint(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    int epoch,
    const std::string& path) {
  torch::serialize::OutputArchive archive;

  // Model parameters and buffers
  model->save(archive);

  // Number of completed epochs, stored as a scalar tensor
  archive.write("epoch", torch::tensor(epoch));

  // Optimizer state (e.g. Adam moment estimates)
  optimizer.save(archive);

  archive.save_to(path);
}
```

## Loading Checkpoints

Loading mirrors saving: read the archive from disk, restore the model and the optimizer, and return the stored epoch so that the training loop knows where to continue.

```cpp
int load_checkpoint(
    std::shared_ptr<Net> model,
    torch::optim::Adam& optimizer,
    const std::string& path) {
  torch::serialize::InputArchive archive;
  archive.load_from(path);

  // Restore model parameters and buffers
  model->load(archive);

  // Restore the epoch counter written by save_checkpoint()
  torch::Tensor epoch_tensor;
  archive.read("epoch", epoch_tensor);

  // Restore optimizer state
  optimizer.load(archive);

  return epoch_tensor.item<int>();
}
```

## Complete Checkpoint Example

```cpp
#include <torch/torch.h>
#include <filesystem>
#include <iostream>

struct Net : torch::nn::Module {
  Net() {
    fc1 = register_module("fc1", torch::nn::Linear(784, 256));
    fc2 = register_module("fc2", torch::nn::Linear(256, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(fc1->forward(x.view({-1, 784})));
    return fc2->forward(x);
  }

  torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};

// save_checkpoint() and load_checkpoint() from the sections above go here.

int main() {
  auto model = std::make_shared<Net>();
  torch::optim::Adam optimizer(
      model->parameters(), torch::optim::AdamOptions(1e-3));

  int start_epoch = 0;
  const std::string checkpoint_path = "checkpoint.pt";

  // Resume from checkpoint if it exists
  if (std::filesystem::exists(checkpoint_path)) {
    std::cout << "Loading checkpoint..." << std::endl;
    start_epoch = load_checkpoint(model, optimizer, checkpoint_path);
    std::cout << "Resuming from epoch " << start_epoch << std::endl;
  }

  // Training loop
  for (int epoch = start_epoch; epoch < 100; ++epoch) {
    // ... training code ...

    // Save checkpoint every 10 epochs; the stored value is the number
    // of completed epochs, i.e. the epoch to resume from
    if ((epoch + 1) % 10 == 0) {
      save_checkpoint(model, optimizer, epoch + 1, checkpoint_path);
      std::cout << "Saved checkpoint at epoch " << epoch + 1 << std::endl;
    }
  }

  return 0;
}
```

## Best Practices

1. **Save periodically**: Save checkpoints at regular intervals (e.g., every epoch or every N batches) to minimize lost work.
2. **Keep multiple checkpoints**: Maintain the last few checkpoints in case the most recent one is corrupted or represents a poor model state.
3. **Include all state**: Save everything needed to resume training, including learning rate scheduler state if you use one.
4. **Verify checkpoints**: Occasionally verify that checkpoints can be loaded correctly.
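
The first two practices can be sketched with the C++17 standard library alone. The code below is a minimal illustration, not part of the LibTorch API: the file naming scheme `checkpoint_epoch_<N>.pt` and the helpers `checkpoint_path_for`, `save_atomically`, `epoch_from_path`, and `prune_checkpoints` are hypothetical names chosen for this example. It writes each checkpoint to a temporary file and renames it into place, so that an interrupted write is less likely to leave a truncated file under the final name, and it deletes all but the newest few checkpoints.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <cstddef>
#include <filesystem>
#include <fstream>
#include <functional>
#include <string>
#include <utility>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical naming scheme: <dir>/checkpoint_epoch_<N>.pt
fs::path checkpoint_path_for(const fs::path& dir, int epoch) {
  return dir / ("checkpoint_epoch_" + std::to_string(epoch) + ".pt");
}

// Calls write_fn on a temporary path, then renames the result to final_path.
// write_fn stands in for whatever actually serializes the training state.
void save_atomically(
    const fs::path& final_path,
    const std::function<void(const std::string&)>& write_fn) {
  fs::path tmp = final_path;
  tmp += ".tmp";
  write_fn(tmp.string());
  fs::rename(tmp, final_path);  // replaces an existing file of the same name
}

// Extracts N from "checkpoint_epoch_<N>.pt"; returns -1 for any other name.
int epoch_from_path(const fs::path& p) {
  const std::string prefix = "checkpoint_epoch_";
  const std::string stem = p.stem().string();
  if (p.extension().string() != ".pt" || stem.size() <= prefix.size() ||
      stem.compare(0, prefix.size(), prefix) != 0) {
    return -1;
  }
  const std::string digits = stem.substr(prefix.size());
  const bool all_digits = std::all_of(
      digits.begin(), digits.end(),
      [](unsigned char c) { return std::isdigit(c) != 0; });
  if (!all_digits || digits.size() > 9) {
    return -1;  // not a number, or too long to fit safely in an int
  }
  return std::stoi(digits);
}

// Deletes all but the `keep` highest-numbered checkpoints in dir.
// Returns the number of files removed.
std::size_t prune_checkpoints(const fs::path& dir, std::size_t keep) {
  assert(fs::is_directory(dir));
  std::vector<std::pair<int, fs::path>> found;
  for (const auto& entry : fs::directory_iterator(dir)) {
    const int epoch = epoch_from_path(entry.path());
    if (epoch >= 0) {
      found.emplace_back(epoch, entry.path());
    }
  }
  // Newest (highest epoch) first
  std::sort(found.begin(), found.end(),
            [](const auto& a, const auto& b) { return a.first > b.first; });
  std::size_t removed = 0;
  for (std::size_t i = keep; i < found.size(); ++i) {
    if (fs::remove(found[i].second)) {
      ++removed;
    }
  }
  return removed;
}
```

In the training loop, a call such as `save_atomically(checkpoint_path_for(dir, epoch + 1), [&](const std::string& p) { save_checkpoint(model, optimizer, epoch + 1, p); });` followed by `prune_checkpoints(dir, 3);` would keep the three most recent checkpoints. When resuming, the newest file can be found by taking the largest value of `epoch_from_path` over the directory.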