Saving and Loading#
The primary interface for serialization uses the torch::save and
torch::load functions, which can save and load tensors, modules,
and optimizers.
Save Functions#
-
template<typename Value, typename ...SaveToArgs>
void torch::save(const Value &value, SaveToArgs&&... args)# Serializes the given
value.There must be an overload of
operator<<betweenserialize::OutputArchiveandValuefor this method to be well-formed. Currently, such an overload is provided for (subclasses of):torch::Tensor
To perform the serialization, a
serialize::OutputArchiveis constructed, and all arguments after thevalueare forwarded to itssave_tomethod. For example, you can pass a filename, or anostream... code-block:: cpp
torch::nn::Linear model(3, 4);
torch::save(model, “model.pt”);
torch::optim::SGD sgd(model->parameters(), 0.9); // 0.9 is learning rate
std::ostringstream stream;
// Note that the same stream cannot be used in multiple torch::save(…)
// invocations, otherwise the header will be corrupted.
torch::save(sgd, stream);
auto tensor = torch::ones({3, 4});
torch::save(tensor, “my_tensor.pt”);
-
template<typename ...SaveToArgs>
void torch::save(const std::vector<torch::Tensor> &tensor_vec, SaveToArgs&&... args)# Serializes the given
tensor_vecof typestd::vector<torch::Tensor>.To perform the serialization, a
serialize::OutputArchiveis constructed, and all arguments after thetensor_vecare forwarded to itssave_tomethod. For example, you can pass a filename, or anostream... code-block:: cpp
std::vectortorch::Tensor tensor_vec = { torch::randn({1, 2}),
torch::randn({3, 4}) }; torch::save(tensor_vec, “my_tensor_vec.pt”);
std::vectortorch::Tensor tensor_vec = { torch::randn({5, 6}),
torch::randn({7, 8}) }; std::ostringstream stream;
// Note that the same stream cannot be used in multiple torch::save(…)
// invocations, otherwise the header will be corrupted.
torch::save(tensor_vec, stream);
Load Functions#
-
template<typename Value, typename ...LoadFromArgs>
void torch::load(Value &value, LoadFromArgs&&... args)# Deserializes the given
value.There must be an overload of
operator>>betweenserialize::InputArchiveandValuefor this method to be well-formed. Currently, such an overload is provided for (subclasses of):torch::Tensor
To perform the serialization, a
serialize::InputArchiveis constructed, and all arguments after thevalueare forwarded to itsload_frommethod. For example, you can pass a filename, or anistream... code-block:: cpp
torch::nn::Linear model(3, 4);
torch::load(model, “model.pt”);
torch::optim::SGD sgd(model->parameters(), 0.9); // 0.9 is learning rate
std::istringstream stream(”…”);
torch::load(sgd, stream);
auto tensor = torch::ones({3, 4});
torch::load(tensor, “my_tensor.pt”);
-
template<typename ...LoadFromArgs>
void torch::load(std::vector<torch::Tensor> &tensor_vec, LoadFromArgs&&... args)# Deserializes the given
tensor_vecof typestd::vector<torch::Tensor>.To perform the serialization, a
serialize::InputArchiveis constructed, and all arguments after thevalueare forwarded to itsload_frommethod. For example, you can pass a filename, or anistream... code-block:: cpp
std::vectortorch::Tensor tensor_vec;
torch::load(tensor_vec, “my_tensor_vec.pt”);
std::vectortorch::Tensor tensor_vec;
std::istringstream stream(”…”);
torch::load(tensor_vec, stream);
Saving and Loading Tensors#
// Save a tensor
torch::Tensor tensor = torch::randn({2, 3});
torch::save(tensor, "tensor.pt");
// Load a tensor
torch::Tensor loaded;
torch::load(loaded, "tensor.pt");
Saving and Loading Modules#
// Define a model
struct Net : torch::nn::Module {
Net() {
fc1 = register_module("fc1", torch::nn::Linear(784, 64));
fc2 = register_module("fc2", torch::nn::Linear(64, 10));
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(fc1->forward(x));
return fc2->forward(x);
}
torch::nn::Linear fc1{nullptr}, fc2{nullptr};
};
// Save model
auto model = std::make_shared<Net>();
torch::save(model, "model.pt");
// Load model
auto loaded_model = std::make_shared<Net>();
torch::load(loaded_model, "model.pt");
Saving Optimizer State#
auto model = std::make_shared<Net>();
auto optimizer = torch::optim::Adam(model->parameters(), 0.001);
// Train...
// Save both model and optimizer
torch::save(model, "model.pt");
torch::save(optimizer, "optimizer.pt");
// Load both
torch::load(model, "model.pt");
torch::load(optimizer, "optimizer.pt");