torch.nn.utils.prune.custom_from_mask
- torch.nn.utils.prune.custom_from_mask(module, name, mask)
- Prune the tensor corresponding to the parameter called `name` in `module` by applying the pre-computed mask in `mask`. Modifies the module in place (and also returns the modified module) by:
  - adding a named buffer called `name+'_mask'` corresponding to the binary mask applied to the parameter `name` by the pruning method;
  - replacing the parameter `name` by its pruned version, while the original (unpruned) parameter is stored in a new parameter named `name+'_orig'`.
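The following is a minimal sketch of what the in-place modification looks like from the outside, assuming a recent PyTorch release. The `nn.Linear(4, 2)` layer and the 2x4 mask that keeps only the first two input columns are illustrative choices. It lists the parameters and buffers after pruning and checks that `weight` is the product of `weight_orig` and `weight_mask`.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)

# Illustrative mask: keep only the first two input columns of the 2x4 weight.
mask = torch.tensor([[1, 1, 0, 0],
                     [1, 1, 0, 0]])
prune.custom_from_mask(layer, name="weight", mask=mask)

# The unpruned tensor is now the parameter "weight_orig" ...
param_names = [n for n, _ in layer.named_parameters()]
# ... and the mask is registered as the buffer "weight_mask".
buffer_names = [n for n, _ in layer.named_buffers()]

print(sorted(param_names))  # ['bias', 'weight_orig']
print(buffer_names)         # ['weight_mask']

# "weight" is now a plain tensor attribute equal to weight_orig * weight_mask.
print(torch.equal(layer.weight, layer.weight_orig * layer.weight_mask))  # True
print(prune.is_pruned(layer))  # True
```

Note that the integer mask passed in is stored as a floating-point buffer matching the parameter's dtype, as the `bias_mask` output in the original example also shows.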
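- Parameters
  - module (nn.Module): module containing the tensor to prune
  - name (str): parameter name within `module` on which pruning will act
  - mask (Tensor): binary mask to be applied to the parameter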
- Returns
  - modified (i.e., pruned) version of the input module
- Return type
  - module (nn.Module)
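Because the module is modified in place, the returned object is the same module that was passed in. The short check below illustrates this with an arbitrary bias mask, assuming a recent PyTorch release. It also confirms that the pruned bias entry is zero during the forward pass: with an all-zero input, a linear layer's output equals its bias.

```python
import torch
from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(3, 3)
returned = prune.custom_from_mask(layer, name="bias", mask=torch.tensor([1, 0, 1]))

# custom_from_mask mutates `layer` and hands the same object back.
print(returned is layer)  # True

# With a zero input the output equals the bias, so the pruned
# entry (index 1) is zero in the forward pass as well.
out = layer(torch.zeros(1, 3))
print(bool(out[0, 1] == 0))  # True
```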
- Examples

```python
>>> import torch
>>> from torch import nn
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
...     nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
```
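The mask can also be derived from the parameter itself. The sketch below uses a median-magnitude threshold. This is only an illustrative criterion and not something `custom_from_mask` computes. The sketch also shows that, because the forward pass recomputes `weight` as `weight_orig * weight_mask`, entries removed by the mask receive zero gradient in `weight_orig`. Layer sizes, the random input, and the squared-output loss are arbitrary assumptions.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(8, 4)

# Illustrative criterion: keep weights whose magnitude is at least the median.
magnitudes = layer.weight.detach().abs()
threshold = magnitudes.median()
mask = (magnitudes >= threshold).to(torch.float32)
prune.custom_from_mask(layer, name="weight", mask=mask)

# Gradients reach weight_orig through weight = weight_orig * weight_mask,
# so every masked-out entry gets a zero gradient.
loss = layer(torch.randn(16, 8)).pow(2).sum()
loss.backward()
grad = layer.weight_orig.grad
print(bool((grad[mask == 0] == 0).all()))  # True
```

Since the masked entries never receive a gradient update, they stay at zero in `weight` throughout training, even though `weight_orig` still holds their original values.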