functorch.combine_state_for_ensemble
functorch.combine_state_for_ensemble(models) → func, params, buffers

Prepares a list of torch.nn.Modules for ensembling with vmap().

Given a list of M nn.Modules of the same class, stacks all of their parameters and buffers together to make params and buffers. Each parameter and buffer in the result has an additional leading dimension of size M.

combine_state_for_ensemble() also returns func, a functional version of one of the models in models. Because every tensor in params and buffers carries that extra dimension of size M, you cannot run func(params, buffers, *args, **kwargs) directly; you probably want vmap(func, ...)(params, buffers, *args, **kwargs) instead, which maps func over that dimension.

Here's an example of how to ensemble over a very simple model:
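```python
import torch
from functorch import combine_state_for_ensemble, vmap

num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, in_features)

# Stack the parameters and buffers of all models along a new leading dimension.
fmodel, params, buffers = combine_state_for_ensemble(models)

# Map over dim 0 of params and buffers; pass the same data to every model.
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
```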
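The claim that vmap over the stacked dimension is equivalent to running each model on its own can be checked against a plain Python loop. The sketch below assumes PyTorch 2.x, where functorch's functionality lives under torch.func; it swaps in torch.func.stack_module_state and torch.func.functional_call in place of combine_state_for_ensemble (a substitute technique, not the function documented on this page), and fmodel and base are hypothetical names chosen for illustration.

```python
import copy

import torch
from torch.func import functional_call, stack_module_state, vmap

num_models, batch_size = 5, 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for _ in range(num_models)]

# Assumption: PyTorch 2.x. stack_module_state stacks every parameter and buffer
# of the M models along a new leading dimension, returning two dicts.
params, buffers = stack_module_state(models)

# A stateless "skeleton" of one model on the meta device; functional_call
# replaces its (empty) weights with the tensors passed in.
base = copy.deepcopy(models[0]).to("meta")

def fmodel(params, buffers, x):
    # Hypothetical helper: runs `base` with the given parameters and buffers.
    return functional_call(base, (params, buffers), (x,))

data = torch.randn(batch_size, in_features)

# Map over dim 0 of params and buffers; the same data is shared by all models.
output = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)

# Reference result: run each model separately and stack the outputs.
expected = torch.stack([m(data) for m in models])

print(params["weight"].shape, output.shape)
print(torch.allclose(output, expected, atol=1e-5))
```

The vmapped call produces one output per model in a single batched computation, while the loop runs the M models one after another; the two should agree up to floating-point rounding.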