RandomStructured
- class torch.nn.utils.prune.RandomStructured(amount, dim=-1)
- Prune entire (currently unpruned) channels in a tensor at random.
- Parameters
- amount (int or float) – quantity of parameters to prune. If float, it should be between 0.0 and 1.0 and represents the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
- dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.
 
 - classmethod apply(module, name, amount, dim=-1)
- Add pruning on the fly and reparametrization of a tensor.
- Adds the forward pre-hook that enables pruning on the fly and the reparametrization of a tensor in terms of the original tensor and the pruning mask.
- Parameters
- module (nn.Module) – module containing the tensor to prune
- name (str) – parameter name within module on which pruning will act.
- amount (int or float) – quantity of parameters to prune. If float, it should be between 0.0 and 1.0 and represents the fraction of parameters to prune. If int, it represents the absolute number of parameters to prune.
- dim (int, optional) – index of the dim along which we define channels to prune. Default: -1.
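As an illustration of what apply() leaves behind on a module, here is a minimal sketch. The nn.Linear layer and its sizes are hypothetical, chosen only so that the weight has four rows to prune along dim=0.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

torch.manual_seed(0)  # only to make the random choice repeatable

# Hypothetical layer: weight has shape (out_features=4, in_features=3).
linear = nn.Linear(3, 4)

# Zero out half of the output channels (rows of the weight, dim=0).
prune.RandomStructured.apply(linear, name="weight", amount=0.5, dim=0)

param_names = [n for n, _ in linear.named_parameters()]
buffer_names = [n for n, _ in linear.named_buffers()]
print(param_names)   # the original tensor is now registered as 'weight_orig'
print(buffer_names)  # the mask is registered as the buffer 'weight_mask'

# Each row of the mask is either all ones (kept) or all zeros (pruned).
zeroed_rows = int((linear.weight_mask.sum(dim=1) == 0).sum())
print(zeroed_rows)
```

After the call, linear.weight is no longer a parameter but the product of weight_orig and weight_mask, recomputed by the forward pre-hook before each forward pass.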
 
 
 - apply_mask(module)
- Handles the multiplication between the parameter being pruned and the generated mask.
- Fetches the mask and the original tensor from the module and returns the pruned version of the tensor.
- Parameters
- module (nn.Module) – module containing the tensor to prune 
- Returns
- pruned version of the input tensor 
- Return type
- pruned_tensor (torch.Tensor) 
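A short sketch of apply_mask() in use. It assumes that apply() returns the pruning-method instance it registers as a hook, which is how current PyTorch releases appear to behave. The layer is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

linear = nn.Linear(3, 4)  # hypothetical layer, for illustration only

# Assumption: apply() returns the method instance it registered as a hook.
method = prune.RandomStructured.apply(linear, name="weight", amount=1, dim=0)

pruned = method.apply_mask(linear)

# Same result as multiplying the stored original tensor by the stored mask.
same = torch.equal(pruned, linear.weight_orig * linear.weight_mask)
print(same)
```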
 
 - compute_mask(t, default_mask)
- Computes and returns a mask for the input tensor t.
- Starting from a base default_mask (which should be a mask of ones if the tensor has not been pruned yet), generate a random mask to apply on top of the default_mask by randomly zeroing out channels along the specified dim of the tensor.
- Parameters
- t (torch.Tensor) – tensor representing the parameter to prune
- default_mask (torch.Tensor) – base mask from previous pruning iterations, which needs to be respected after the new mask is applied. Same dims as t.
 
- Returns
- mask to apply to t, of same dims as t
- Return type
- mask (torch.Tensor) 
- Raises
- IndexError – if self.dim >= len(t.shape)
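The following sketch calls compute_mask() directly on a standalone tensor, without any module. The tensor shape and the amount are arbitrary choices for illustration. It also shows the IndexError raised for an out-of-range dim.

```python
import torch
import torch.nn.utils.prune as prune

torch.manual_seed(0)

t = torch.randn(4, 3)  # stand-in for a parameter tensor
method = prune.RandomStructured(amount=2, dim=0)

# First round: nothing has been pruned yet, so the base mask is all ones.
mask = method.compute_mask(t, default_mask=torch.ones_like(t))
pruned_rows = int((mask.sum(dim=1) == 0).sum())
print(mask.shape == t.shape, pruned_rows)

# A dim that does not exist in t is rejected.
try:
    prune.RandomStructured(amount=1, dim=2).compute_mask(t, torch.ones_like(t))
    raised = False
except IndexError:
    raised = True
print(raised)
```

With an int amount and dim=0, the count refers to whole rows of t, so the mask here has two rows of zeros.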
 
 - prune(t, default_mask=None, importance_scores=None)
- Computes and returns a pruned version of input tensor t, according to the pruning rule specified in compute_mask().
- Parameters
- t (torch.Tensor) – tensor to prune (of same dimensions as default_mask).
- default_mask (torch.Tensor, optional) – mask from previous pruning iteration, if any. To be considered when determining what portion of the tensor pruning should act on. If None, defaults to a mask of ones.
- importance_scores (torch.Tensor) – tensor of importance scores (of same shape as t) used to compute the mask for pruning t. The values in this tensor indicate the importance of the corresponding elements in the t that is being pruned. If unspecified or None, the tensor t will be used in its place.
 
- Returns
- pruned version of tensor t.
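A minimal sketch of prune() on a plain tensor. The input values are chosen to contain no zeros, so that a pruned column is easy to recognize. Shape and amount are illustrative only.

```python
import torch
import torch.nn.utils.prune as prune

# No zeros in the input, so any all-zero column must come from pruning.
t = torch.arange(1.0, 13.0).reshape(3, 4)
method = prune.RandomStructured(amount=1, dim=1)

pruned = method.prune(t)  # default_mask=None is treated as a mask of ones

zero_cols = int((pruned == 0).all(dim=0).sum())
kept = pruned != 0
print(zero_cols)
print(torch.equal(pruned[kept], t[kept]))  # surviving entries are unchanged
```

Unlike apply(), this leaves no hook, parameter, or buffer anywhere: it returns a new tensor and does not modify t.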
 
 - remove(module)
- Remove the pruning reparameterization from a module.
- The pruned parameter named name remains permanently pruned, and the parameter named name+'_orig' is removed from the parameter list. Similarly, the buffer named name+'_mask' is removed from the buffers.
- Note
- Pruning itself is NOT undone or reversed!
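The effect of removal can be sketched as follows. This example uses the module-level helper torch.nn.utils.prune.remove(module, name) rather than the instance method. To the best of my understanding, the helper calls the instance method and also unregisters the forward pre-hook, but treat that as an assumption. The layer is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

linear = nn.Linear(3, 4)  # hypothetical layer
prune.RandomStructured.apply(linear, name="weight", amount=2, dim=0)
pruned_weight = linear.weight.detach().clone()

# Module-level helper; it looks up the pruning hook registered for "weight".
prune.remove(linear, "weight")

param_names = [n for n, _ in linear.named_parameters()]
buffer_names = [n for n, _ in linear.named_buffers()]
print(param_names)   # 'weight' is an ordinary parameter again
print(buffer_names)  # 'weight_mask' is no longer present

# The zeros stay: removing the reparameterization does not restore values.
unchanged = torch.equal(linear.weight.detach(), pruned_weight)
print(unchanged)
```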