BasePruningMethod
class torch.nn.utils.prune.BasePruningMethod
Abstract base class for creation of new pruning techniques.
Provides a skeleton for customization that requires overriding methods such as compute_mask() and apply().
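As a minimal sketch of this customization, the subclass below overrides only compute_mask() and inherits apply() from the base class. The names EveryOtherPruning and every_other_unstructured are hypothetical and exist only for illustration. Setting PRUNING_TYPE = "unstructured" follows what PyTorch's built-in unstructured methods do; PyTorch reads this tag when a new method is stacked on earlier pruning of the same tensor.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Hypothetical method: prune every other entry of the flattened tensor."""

    # Same tag as PyTorch's built-in unstructured methods; consulted when this
    # method is combined with earlier pruning of the same tensor.
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        # Start from the existing mask so previously pruned entries stay pruned.
        mask = default_mask.clone(memory_format=torch.contiguous_format)
        # Zero out entries 0, 2, 4, ... of the flattened mask.
        mask.view(-1)[::2] = 0
        return mask


def every_other_unstructured(module, name):
    # Hypothetical helper. apply() is inherited from BasePruningMethod and
    # registers name + "_orig", name + "_mask" and the forward pre-hook.
    EveryOtherPruning.apply(module, name)
    return module


layer = nn.Linear(4, 2)  # weight has 2 * 4 = 8 entries
every_other_unstructured(layer, "weight")
print(int(layer.weight_mask.sum()))  # 4 entries are kept
```

Only the masking rule lives in compute_mask(); the reparametrization and hook bookkeeping come from the base class.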
classmethod apply(module, name, *args, importance_scores=None, **kwargs)
Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
 module (nn.Module) – module containing the tensor to prune.
 name (str) – parameter name within module on which pruning will act.
 args – arguments passed on to a subclass of BasePruningMethod.
 importance_scores (torch.Tensor) – tensor of importance scores (of same shape as the module parameter) used to compute the mask for pruning. The values in this tensor indicate the importance of the corresponding elements in the parameter being pruned. If unspecified or None, the parameter will be used in its place.
 kwargs – keyword arguments passed on to a subclass of BasePruningMethod.
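The sketch below shows what apply() does to a module and how importance_scores changes the outcome. It uses the built-in subclass prune.L1Unstructured, chosen only for illustration, on a layer whose weights are set by hand. Because the scores rank the entries in the opposite order from the weight magnitudes, the entry that gets pruned is the one with the lowest score, not the smallest weight.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(3, 1, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.tensor([[1.0, 2.0, 3.0]]))

# Scores deliberately rank the entries opposite to their magnitudes.
scores = torch.tensor([[3.0, 2.0, 1.0]])

# amount=1 (an int) prunes exactly one entry: the one with the lowest score.
method = prune.L1Unstructured.apply(
    layer, "weight", amount=1, importance_scores=scores
)

# apply() reparametrizes the layer:
#   weight_orig -> the original nn.Parameter
#   weight_mask -> a buffer holding the mask
#   weight      -> plain tensor, recomputed as weight_orig * weight_mask
print(layer.weight_mask)      # tensor([[1., 1., 0.]])
print(layer.weight.detach())  # tensor([[1., 2., 0.]])
```

Without importance_scores, the same call would have pruned the first entry, since 1.0 has the smallest magnitude.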
apply_mask(module)
Simply handles the multiplication between the parameter being pruned and the generated mask. Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
 module (nn.Module) – module containing the tensor to prune
- Returns
 pruned version of the input tensor
- Return type
 pruned_tensor (torch.Tensor)
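A small sketch of what apply_mask() computes. After training updates weight_orig (simulated here by an in-place add), the plain weight attribute is stale until the forward pre-hook, or a direct apply_mask() call, recomputes weight_orig * weight_mask. prune.RandomUnstructured is used only as a convenient built-in subclass.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)
layer = nn.Linear(4, 2)
method = prune.RandomUnstructured.apply(layer, "weight", amount=0.5)

# Simulate an optimizer step on the underlying (unpruned) parameter.
with torch.no_grad():
    layer.weight_orig.add_(1.0)

# apply_mask() rebuilds the pruned tensor from weight_orig and weight_mask.
fresh = method.apply_mask(layer)
print(torch.equal(fresh, layer.weight_orig * layer.weight_mask))  # True

# layer.weight still holds the value computed before the update ...
print(torch.equal(layer.weight, fresh))  # False
# ... until a forward pass triggers the pre-hook, which calls apply_mask().
layer(torch.randn(1, 4))
print(torch.equal(layer.weight, fresh))  # True
```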
abstract compute_mask(t, default_mask)
Computes and returns a mask for the input tensor t. Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a new mask to apply on top of the default_mask according to the specific pruning method recipe.
- Parameters
 t (torch.Tensor) – tensor representing the importance scores of the parameter to prune.
 default_mask (torch.Tensor) – base mask from previous pruning iterations, which needs to be respected after the new mask is applied. Same dims as t.
- Returns
 mask to apply to t, of same dims as t.
- Return type
 mask (torch.Tensor)
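A minimal sketch of a compute_mask() override that respects default_mask. ThresholdPruning and its threshold parameter are hypothetical; the method keeps entries whose score magnitude is at least the threshold, and multiplying by default_mask guarantees that entries pruned in an earlier iteration stay pruned. The method is called directly here, without a module, to show its inputs and output.

```python
import torch
import torch.nn.utils.prune as prune


class ThresholdPruning(prune.BasePruningMethod):
    """Hypothetical method: prune entries whose score magnitude is below a threshold."""

    PRUNING_TYPE = "unstructured"

    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def compute_mask(self, t, default_mask):
        # t holds the importance scores; keep entries with |score| >= threshold.
        keep = (t.abs() >= self.threshold).to(default_mask.dtype)
        # Multiplying by default_mask keeps earlier pruning decisions in place.
        return default_mask * keep


scores = torch.tensor([0.1, -0.5, 2.0, 0.01])
previous = torch.tensor([1.0, 1.0, 0.0, 1.0])  # entry 2 was pruned earlier
mask = ThresholdPruning(threshold=0.2).compute_mask(scores, default_mask=previous)
print(mask)  # tensor([0., 1., 0., 0.])
```

Entry 2 has the largest score but stays pruned, because default_mask already excluded it.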
prune(t, default_mask=None, importance_scores=None)
Computes and returns a pruned version of input tensor t according to the pruning rule specified in compute_mask().
- Parameters
 t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
 importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the t that is being pruned. If unspecified or None, the tensor t will be used in its place.
 default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. Considered when determining which portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- Returns
 pruned version of tensor t.
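Unlike apply(), prune() works on a bare tensor and leaves no hooks or buffers behind. The sketch below uses the built-in prune.L1Unstructured, chosen only for illustration, with amount=2, so the two lowest-ranked entries are zeroed; passing importance_scores changes which entries are ranked lowest.

```python
import torch
import torch.nn.utils.prune as prune

t = torch.tensor([0.5, -3.0, 0.1, 2.0])
method = prune.L1Unstructured(amount=2)  # prune the 2 lowest-ranked entries

# No module is involved: prune() returns t * compute_mask(...).
print(method.prune(t).tolist())  # [0.0, -3.0, 0.0, 2.0]

# With importance_scores, the ranking comes from the scores instead of |t|.
scores = torch.tensor([4.0, 3.0, 2.0, 1.0])
print(method.prune(t, importance_scores=scores).tolist())  # [0.5, -3.0, 0.0, 0.0]
```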
remove(module)
Removes the pruning reparameterization from a module. The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
Note
Pruning itself is NOT undone or reversed!
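The sketch below illustrates the note using the module-level helper prune.remove(module, name). The helper finds the pruning hook registered for that tensor, calls this method on it, and also unregisters the hook. After removal, weight is an ordinary nn.Parameter again, but its pruned entries are still zero. The hand-set weights are an assumption made so the result is deterministic.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

layer = nn.Linear(4, 2, bias=False)
with torch.no_grad():
    layer.weight.copy_(torch.arange(1.0, 9.0).view(2, 4))

# Prune the 3 smallest-magnitude weights (1, 2 and 3).
prune.l1_unstructured(layer, "weight", amount=3)

# Make the pruning permanent: drops weight_orig, weight_mask and the hook.
prune.remove(layer, "weight")

print(isinstance(layer.weight, nn.Parameter))  # True
print([n for n, _ in layer.named_buffers()])   # []
print(layer.weight.tolist())  # [[0.0, 0.0, 0.0, 4.0], [5.0, 6.0, 7.0, 8.0]]
```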
classmethod