torch.func.vjp

torch.func.vjp(func, *primals, has_aux=False)

Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.

Parameters
- func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors. 
- primals (Tensors) – Positional arguments to func that must all be Tensors. The returned function will also compute the derivative with respect to these arguments.
- has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
 
Returns
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
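The has_aux=True form is not covered by the examples below; here is a minimal sketch of it (the function and its auxiliary value are illustrative, not taken from the original documentation):

>>> x = torch.randn([5])
>>> def f(x):
>>>     out = x.sin().sum()
>>>     return out, x.cos()  # second element is aux: returned as-is, not differentiated
>>>
>>> output, vjpfunc, aux = torch.func.vjp(f, x, has_aux=True)
>>> vjps = vjpfunc(torch.tensor(1.))  # cotangent matches the differentiated output only
>>> assert torch.allclose(vjps[0], x.cos())
>>> assert torch.allclose(aux, x.cos())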
When used in simple cases, vjp() behaves the same as grad():

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))

However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

vjp() can even support outputs being Python structs:

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())

The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))

primals are the positional arguments for f. All kwargs use their default value:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))

Note

Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, vjp(f)(x) will respect the inner torch.no_grad.

Case 2: Using vjp inside the torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f)(x)

In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.
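As a sanity check of the vector-Jacobian product relationship described at the top of this page, the result of the returned function can be compared against an explicit Jacobian computed with torch.func.jacrev. This is a sketch only; the function, shapes, and cotangent below are illustrative:

>>> x = torch.randn([3])
>>> v = torch.randn([3])  # cotangent vector
>>> f = lambda x: x ** 2
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> J = torch.func.jacrev(f)(x)  # explicit 3x3 Jacobian of f at x
>>> assert torch.allclose(vjpfunc(v)[0], v @ J)  # the VJP equals v @ J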