torch.nn.utils.fuse_linear_bn_weights
- torch.nn.utils.fuse_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
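- Fuse linear module parameters and BatchNorm module parameters into new linear module parameters, so that a single linear layer reproduces a Linear layer followed by BatchNorm in evaluation mode (that is, with BatchNorm using its running statistics). With scale = bn_w / sqrt(bn_rv + bn_eps), the fused weight is linear_w with each output row multiplied by the corresponding entry of scale, and the fused bias is (linear_b - bn_rm) * scale + bn_b.
- Parameters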
- linear_w (torch.Tensor) – Linear weight, of shape (out_features, in_features).
- linear_b (Optional[torch.Tensor]) – Linear bias. If None, a zero bias is used.
- bn_rm (torch.Tensor) – BatchNorm running mean.
- bn_rv (torch.Tensor) – BatchNorm running variance.
- bn_eps (float) – BatchNorm epsilon.
- bn_w (torch.Tensor) – BatchNorm weight (the learnable scale).
- bn_b (torch.Tensor) – BatchNorm bias (the learnable shift).
 
- Returns
- Fused linear weight and bias. 
- Return type
- Tuple[torch.nn.Parameter, torch.nn.Parameter]