functorch.jvp
functorch.jvp(func, primals, tangents, *, strict=False, has_aux=False)
Standing for the Jacobian-vector product, returns a tuple containing the output of func(*primals) and the “Jacobian of func evaluated at primals” times tangents. This is also known as forward-mode autodiff.
- Parameters
func (function) – A Python function that takes one or more arguments, one of which must be a Tensor, and returns one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The derivative is computed with respect to these arguments.
tangents (Tensors) – The “vector” for which the Jacobian-vector product is computed. Must have the same structure and sizes as the inputs to func.
has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns
Returns a (output, jvp_out) tuple containing the output of func evaluated at primals and the Jacobian-vector product. If has_aux is True, then instead returns a (output, jvp_out, aux) tuple.
Warning
PyTorch’s forward-mode AD operator coverage is incomplete at the moment. You may see this API error out with “forward-mode AD not implemented for operator X”. If so, please file a bug report and we will prioritize it.
jvp is useful when you wish to compute gradients of a function R^1 -> R^N:
>>> import torch
>>> from functorch import jvp
>>> x = torch.randn([])
>>> f = lambda x: x * torch.tensor([1., 2., 3])
>>> value, grad = jvp(f, (x,), (torch.tensor(1.),))
>>> assert torch.allclose(value, f(x))
>>> assert torch.allclose(grad, torch.tensor([1., 2, 3]))
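The tangent does not have to be a unit vector. As a rough sketch (not one of the original examples), the snippet below checks jvp against the closed-form derivative of torch.sin and against the full Jacobian computed with torch.autograd.functional.jacobian. The fallback import from torch.func is an assumption for newer PyTorch installs where the standalone functorch package is unavailable.

```python
import torch

try:
    from functorch import jvp
except ImportError:
    # Assumption: newer PyTorch releases expose the same function here.
    from torch.func import jvp

x = torch.randn(3)
v = torch.tensor([1.0, 0.0, 2.0])  # arbitrary tangent with the same shape as x

# value is sin(x); jvp_out is the Jacobian of sin at x applied to v.
value, jvp_out = jvp(torch.sin, (x,), (v,))

# For an elementwise function the Jacobian is diagonal, so J @ v == cos(x) * v.
J = torch.autograd.functional.jacobian(torch.sin, x)

assert torch.allclose(value, torch.sin(x))
assert torch.allclose(jvp_out, torch.cos(x) * v, atol=1e-6)
assert torch.allclose(jvp_out, J @ v, atol=1e-6)
```

Building the full Jacobian here is only for comparison; jvp itself never materializes it.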
jvp() can support functions with multiple inputs by passing in the tangents for each of the inputs:
>>> import torch
>>> from functorch import jvp
>>> x = torch.randn(5)
>>> y = torch.randn(5)
>>> f = lambda x, y: (x * y)
>>> _, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
>>> assert torch.allclose(output, x + y)
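The has_aux flag described under Parameters can be illustrated with a minimal sketch. The function name f_with_aux is hypothetical, and the fallback import from torch.func is an assumption for installs without the standalone functorch package. Here the auxiliary value is a single Tensor, which is returned alongside the output without being differentiated.

```python
import torch

try:
    from functorch import jvp
except ImportError:
    # Assumption: newer PyTorch releases expose the same function here.
    from torch.func import jvp


def f_with_aux(x):
    # Hypothetical example function: returns (output, aux).
    y = x.sin()
    return y.sum(), y  # only y.sum() is differentiated; y is passed through


x = torch.randn(4)
v = torch.ones(4)

# With has_aux=True the result is an (output, jvp_out, aux) tuple.
output, jvp_out, aux = jvp(f_with_aux, (x,), (v,), has_aux=True)

assert torch.allclose(output, x.sin().sum(), atol=1e-6)
# d/dx sum(sin(x)) applied to a tangent of ones is sum(cos(x)).
assert torch.allclose(jvp_out, x.cos().sum(), atol=1e-6)
assert torch.allclose(aux, x.sin())
```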