functorch.vjp
functorch.vjp(f, *primals)

Standing for the vector-Jacobian product, returns a tuple containing the results of f applied to primals and a function that, when given cotangents, computes the reverse-mode Jacobian of f with respect to primals times cotangents.

- Parameters
f (Callable) – A Python function that takes one or more arguments. Must return one or more Tensors.
primals (Tensors) – Positional arguments to f that must all be Tensors. The returned function will also compute the derivative with respect to these arguments.
- Returns
Returns a tuple containing the output of f applied to primals and a function that computes the vjp of f with respect to all primals using the cotangents passed to the returned function. The returned function will return a tuple of each VJP.
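To make "the reverse-mode Jacobian of f with respect to primals times cotangents" concrete, here is a small hedged sketch that checks the returned function against an explicitly materialized Jacobian (functorch.jacrev is used purely for illustration; it is not required by vjp()):

>>> import torch
>>> import functorch
>>> x = torch.randn([3])
>>> f = lambda x: x.sin()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> v = torch.randn([3])          # cotangent, same shape as the output of f
>>> J = functorch.jacrev(f)(x)    # explicit 3 x 3 Jacobian of f at x
>>> # the vjp with respect to x equals v^T J (written here as v @ J)
>>> assert torch.allclose(vjpfunc(v)[0], v @ J)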
When used in simple cases, vjp() behaves the same as grad():

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> grad = vjpfunc(torch.tensor(1.))[0]
>>> assert torch.allclose(grad, functorch.grad(f)(x))
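Because the returned function computes "Jacobian times cotangents", it is linear in the cotangent. A hedged sketch of what that means for a scalar-output function (scaling the cotangent simply scales the gradient):

>>> x = torch.randn([5])
>>> f = lambda x: x.sin().sum()
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> # a cotangent of 2. instead of 1. scales the gradient by 2
>>> assert torch.allclose(vjpfunc(torch.tensor(2.))[0], 2 * x.cos())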
However, vjp() can support functions with multiple outputs by passing in the cotangents for each of the outputs:

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc((torch.ones([5]), torch.ones([5])))
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
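As a hedged illustration of how each output's cotangent weights that output's contribution to the result (the cotangent names v1 and v2 below are introduced just for this sketch):

>>> x = torch.randn([5])
>>> f = lambda x: (x.sin(), x.cos())
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> v1, v2 = torch.randn([5]), torch.randn([5])
>>> vjps = vjpfunc((v1, v2))
>>> # v1 weights d(sin x)/dx = cos x, and v2 weights d(cos x)/dx = -sin x
>>> assert torch.allclose(vjps[0], v1 * x.cos() - v2 * x.sin())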
vjp() can even support outputs being Python structs; the cotangents mirror the structure of the output (a dict here, so the cotangents are a dict as well):

>>> x = torch.randn([5])
>>> f = lambda x: {'first': x.sin(), 'second': x.cos()}
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> cotangents = {'first': torch.ones([5]), 'second': torch.ones([5])}
>>> vjps = vjpfunc(cotangents)
>>> assert torch.allclose(vjps[0], x.cos() + -x.sin())
The function returned by vjp() will compute the partials with respect to each of the primals:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = functorch.vjp(torch.matmul, x, y)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 2
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))
>>> assert torch.allclose(vjps[1], torch.matmul(x.transpose(0, 1), cotangents))
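If you only want the partial with respect to one of the arguments, a common pattern (a hedged sketch, not a separate vjp() feature) is to close over the others so that only the remaining argument is a primal:

>>> x, y = torch.randn([5, 4]), torch.randn([4, 5])
>>> (_, vjpfunc) = functorch.vjp(lambda x: torch.matmul(x, y), x)
>>> cotangents = torch.randn([5, 5])
>>> vjps = vjpfunc(cotangents)
>>> assert len(vjps) == 1   # only the partial with respect to x is returned
>>> assert torch.allclose(vjps[0], torch.matmul(cotangents, y.transpose(0, 1)))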
primals are the positional arguments for f. All kwargs use their default value:

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = functorch.vjp(f, x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 4.))
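If a keyword argument should take a non-default value, one option (a hedged sketch, not part of the vjp() API itself) is to bind it with a lambda before handing the function to vjp():

>>> x = torch.randn([5])
>>> def f(x, scale=4.):
>>>     return x * scale
>>>
>>> (_, vjpfunc) = functorch.vjp(lambda x: f(x, scale=2.), x)
>>> vjps = vjpfunc(torch.ones_like(x))
>>> assert torch.allclose(vjps[0], torch.full(x.shape, 2.))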
Note
Using PyTorch torch.no_grad together with vjp.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
In this case, vjp(f, x) will respect the inner torch.no_grad.

Case 2: Using vjp inside a torch.no_grad context manager:

>>> with torch.no_grad():
>>>     vjp(f, x)
In this case, vjp will respect the inner torch.no_grad, but not the outer one. This is because vjp is a “function transform”: its result should not depend on the result of a context manager outside of f.
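A hedged sketch of what Case 2 means in practice, assuming the behaviour described above (gradients are still computed even though vjp runs under the outer torch.no_grad, while c stays a constant because of the inner one):

>>> x = torch.randn([5])
>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c
>>>
>>> with torch.no_grad():
>>>     (_, vjpfunc) = functorch.vjp(f, x)
>>>     grads = vjpfunc(torch.ones_like(x))
>>> # d(x - c)/dx == 1 everywhere, since c is treated as a constant
>>> assert torch.allclose(grads[0], torch.ones_like(x))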