functorch.grad_and_value
- functorch.grad_and_value(func, argnums=0, has_aux=False)
Returns a function that computes a tuple of the gradient and the primal (forward) computation.
- Parameters
func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If has_aux is True, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: (output, aux).
argnums (int or Tuple[int]) – Specifies the arguments to compute gradients with respect to. argnums can be a single integer or a tuple of integers. Default: 0.
has_aux (bool) – Flag indicating that func returns a tensor and other auxiliary objects: (output, aux). Default: False.
- Returns
Function that computes a tuple of the gradients with respect to its inputs and the forward computation. By default, the returned function outputs a tuple of the gradient tensor(s) with respect to the first argument and the primal computation. If has_aux is True, it returns a tuple of the gradients and a tuple of the forward computation's output and the auxiliary objects: (grad, (output, aux)). If argnums is a tuple of integers, it returns a tuple whose first element is a tuple of the gradients with respect to each argnums value and whose second element is the forward computation.
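The following is a minimal sketch of the three return shapes described above. It uses torch.func.grad_and_value, the replacement named in the warning, and assumes it accepts the same func, argnums and has_aux arguments as functorch.grad_and_value. The functions f, g and h are hypothetical examples chosen so the gradients are easy to check by hand.

```python
import torch
from torch.func import grad_and_value  # replacement for functorch.grad_and_value

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])

def f(x):
    # Hypothetical example: returns a single-element (0-dim) tensor, as required.
    return (x ** 2).sum()

# Default (argnums=0, has_aux=False): (gradient w.r.t. first argument, f(x)).
grad_x, value = grad_and_value(f)(x)
# grad_x equals 2 * x = [2., 4., 6.]; value equals 1 + 4 + 9 = 14.

def g(x, y):
    # Hypothetical example with two differentiable inputs.
    return (x * y).sum()

# argnums as a tuple: the first element is a tuple with one gradient per argnums entry.
(grad_gx, grad_gy), g_value = grad_and_value(g, argnums=(0, 1))(x, y)
# d/dx sum(x * y) = y and d/dy sum(x * y) = x; g_value equals 4 + 10 + 18 = 32.

def h(x):
    # Hypothetical example: returns (output, aux); aux is passed through, not differentiated.
    out = (x ** 2).sum()
    aux = x.mean()
    return out, aux

# has_aux=True: the result is (gradient, (output, aux)).
grad_hx, (h_value, h_aux) = grad_and_value(h, has_aux=True)(x)
# grad_hx equals 2 * x; h_value equals 14; h_aux equals the mean of x, 2.
```

On PyTorch versions that still ship functorch, replacing the import with `from functorch import grad_and_value` should produce the same results.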
See grad() for examples.
Warning
We’ve integrated functorch into PyTorch. As the final step of the integration, functorch.grad_and_value is deprecated as of PyTorch 2.0 and will be deleted in a future version of PyTorch >= 2.3. Please use torch.func.grad_and_value instead; see the PyTorch 2.0 release notes and/or the torch.func migration guide for more details: https://pytorch.org/docs/main/func.migrating.html