CustomFromMask
- class torch.nn.utils.prune.CustomFromMask(mask)
- classmethod apply(module, name, mask)
- Add pruning on the fly and reparametrization of a tensor.
- Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
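A minimal sketch of calling apply directly on a small nn.Linear layer. The layer sizes and the mask pattern are illustrative assumptions, not part of the API. After the call, the original tensor is kept as a parameter named weight_orig, the mask is stored in a buffer named weight_mask, and weight becomes their elementwise product.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

layer = nn.Linear(4, 3)  # weight has shape (3, 4)

# Illustrative mask: drop every connection coming from the first input feature.
mask = torch.ones(3, 4)
mask[:, 0] = 0

# Reparametrize layer.weight in terms of weight_orig and weight_mask,
# and register the forward pre-hook that recomputes weight before each forward.
prune.CustomFromMask.apply(layer, "weight", mask)

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = sorted(name for name, _ in layer.named_buffers())
print(param_names)   # the original weight now lives under weight_orig
print(buffer_names)  # the mask is a buffer, not a trainable parameter
print(layer.weight)  # the first column is zeroed by the mask
```

In everyday code, the functional wrapper prune.custom_from_mask(module, name, mask) is the more common entry point; it delegates to this classmethod.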
- apply_mask(module)
- Simply handles the multiplication between the parameter being pruned and the generated mask.
- Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
- module (nn.Module) – module containing the tensor to prune 
- Returns
- pruned version of the input tensor 
- Return type
- torch.Tensor
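A short sketch that checks what apply_mask computes, assuming an illustrative nn.Linear layer and a mask that prunes the first output unit. It relies on apply returning the pruning-method instance that is registered as the forward pre-hook, and shows that the hook refreshes weight from weight_orig and weight_mask on the next forward pass.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

layer = nn.Linear(4, 3)
mask = torch.ones(3, 4)
mask[0, :] = 0  # illustrative: prune every weight of the first output unit

# apply() returns the pruning-method instance that was registered as the hook.
method = prune.CustomFromMask.apply(layer, "weight", mask)

# apply_mask fetches weight_orig and weight_mask from the module
# and multiplies them elementwise.
pruned = method.apply_mask(layer)
expected = layer.weight_orig * layer.weight_mask

# The forward pre-hook calls apply_mask before every forward pass, so an
# in-place change to weight_orig is reflected in layer.weight after a forward.
with torch.no_grad():
    layer.weight_orig.add_(1.0)
_ = layer(torch.randn(2, 4))
```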
 
- prune(t, default_mask=None, importance_scores=None)
- Compute and return a pruned version of input tensor t according to the pruning rule specified in compute_mask().
- Parameters
- t (torch.Tensor) – tensor to prune (of the same dimensions as default_mask).
- importance_scores (torch.Tensor, optional) – tensor of importance scores (of the same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements of t. If unspecified or None, t itself is used in its place.
- default_mask (torch.Tensor, optional) – mask from the previous pruning iteration, if any, to be taken into account when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.
 
- Returns
- pruned version of tensor t.
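A minimal sketch of prune on a standalone tensor, with no module involved. The tensor values, the mask, and the "previous" mask standing in for an earlier pruning round are all illustrative assumptions. It shows that with no default_mask only the CustomFromMask mask applies, and that a supplied default_mask (assumed to have the same shape as the mask) is combined with it elementwise.

```python
import torch
from torch.nn.utils import prune

t = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 0.0, 1.0],
                     [0.0, 1.0, 0.0]])
method = prune.CustomFromMask(mask)

# No default_mask: it defaults to a mask of ones, so only `mask` takes effect.
result_default = method.prune(t)
print(result_default)  # values [[1, 0, 3], [0, 5, 0]]

# Hypothetical mask from an earlier pruning round: an entry survives only
# if neither mask zeroes it out.
previous = torch.tensor([[1.0, 1.0, 0.0],
                         [1.0, 1.0, 1.0]])
result_combined = method.prune(t, default_mask=previous)
print(result_combined)  # values [[1, 0, 0], [0, 5, 0]]
```

Because CustomFromMask uses the fixed mask passed to its constructor, passing importance_scores here should not change which entries are kept; the scores matter for methods such as magnitude-based pruning that derive the mask from them.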
 
- remove(module)
- Remove the pruning reparameterization from a module.
- The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
- Note
- Pruning itself is NOT undone or reversed!
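A minimal sketch of making pruning permanent, assuming an illustrative nn.Linear layer pruned with a hand-written mask. It uses the functional prune.remove(module, name), which calls this remove method and also deletes the forward pre-hook from the module. Afterwards weight is an ordinary parameter again, weight_orig and weight_mask are gone, and the pruned entries stay zero.

```python
import torch
import torch.nn as nn
from torch.nn.utils import prune

layer = nn.Linear(4, 3)
mask = torch.ones(3, 4)
mask[:, -1] = 0  # illustrative: prune connections from the last input feature

prune.custom_from_mask(layer, "weight", mask)
pruned_weight = layer.weight.detach().clone()  # snapshot of the pruned values

# Find the pruning hook for "weight", call its remove(module), drop the hook.
prune.remove(layer, "weight")

param_names = sorted(name for name, _ in layer.named_parameters())
buffer_names = [name for name, _ in layer.named_buffers()]
print(param_names)   # weight is a regular parameter again
print(buffer_names)  # weight_mask has been removed
print(layer.weight)  # the last column is still zero: pruning is not undone
```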