torch.func.vjp
- torch.func.vjp(func, *primals, has_aux=False)
Standing for the vector-Jacobian product, returns a tuple containing the results of func applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of func with respect to primals times cotangents.

- Parameters
func (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
primals (Tensors) – Positional arguments to func that must all be Tensors. The returned function will also be computing the derivative with respect to these arguments.

has_aux (bool) – Flag indicating that func returns a (output, aux) tuple where the first element is the output of the function to be differentiated and the second element is other auxiliary objects that will not be differentiated. Default: False.
- Returns
Returns a (output, vjp_fn) tuple containing the output of func applied to primals and a function that computes the vjp of func with respect to all primals using the cotangents passed to the returned function. If has_aux is True, then instead returns a (output, vjp_fn, aux) tuple. The returned vjp_fn function will return a tuple of each VJP.
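As a minimal sketch of has_aux=True (this example is not part of the original reference; the function f and its aux dictionary are illustrative), the auxiliary value is returned alongside the output and the vjp function but is not differentiated:

>>> x = torch.randn([5])
>>> def f(x):
>>>     out = x.sin().sum()
>>>     aux = {"mean": x.mean().detach()}  # bookkeeping only; not differentiated
>>>     return out, aux
>>>
>>> (output, vjpfunc, aux) = torch.func.vjp(f, x, has_aux=True)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, x.cos())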
When used in simple cases, vjp() behaves the same as grad():

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.0))[0]
>>> assert torch.allclose(grad, torch.func.grad(f)(x))
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
vjp() can even support outputs being Python structs:

>>> x = torch.randn([5])
>>> f = lambda x: {"first": x.sin(), "second": x.cos()}
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> cotangents = {"first": torch.ones([5]), "second": torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = torch.func.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
primals are the positional arguments for f. All kwargs use their default value:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = torch.func.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.0))
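To differentiate with respect to a value a function normally takes as a keyword argument, one option is to bind it as an extra positional primal through a small wrapper. This is a sketch, not part of the original reference; the wrapper and the scale tensor below are illustrative:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> scale = torch.tensor(2.0)
>>> wrapper = lambda x, scale: f(x, scale=scale)
>>> (_, vjpfunc) = torch.func.vjp(wrapper, x, scale)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.0))
>>> assert torch.allclose(vjps[1], x.sum())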
Note
Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, vjp(f)(x) will respect the inner torch.no_grad.

Case 2: Using vjp inside the torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f)(x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.
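As a concrete sketch of Case 2 (this example is not part of the original note; the assertion is illustrative), calling vjp under an outer torch.no_grad still produces gradients for the f defined above, while the no_grad inside f is still respected, so the x ** 2 term contributes nothing:

>>> x = torch.randn([5])
>>> with torch.no_grad():
>>>     (_, vjpfunc) = torch.func.vjp(f, x)
>>>     vjps = vjpfunc(torch.ones_like(x))
>>> # d(x - c)/dx == 1 because c was computed under the inner no_grad
>>> assert torch.allclose(vjps[0], torch.ones_like(x))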