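functorch.jacrev

functorch.jacrev(f, argnums=0)

Computes the Jacobian of f with respect to the arg(s) at index argnums using reverse-mode autodiff.

Parameters
- f: the function whose Jacobian is computed.
- argnums: an integer or tuple of integers selecting the argument(s) to differentiate with respect to. Default: 0.

Returns
Returns a function that takes in the same inputs as f and returns the Jacobian of f with respect to the arg(s) at argnums.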
A basic usage with a pointwise, unary operation will give a diagonal matrix as the Jacobian:

>>> import torch
>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)
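To see why the Jacobian of a pointwise operation is diagonal without relying on autodiff, the following minimal sketch approximates it with central finite differences in plain Python. The helper numerical_jacobian, the function pointwise_sin, and the step size eps are illustrative assumptions and are not part of functorch.

```python
import math

def numerical_jacobian(f, x, eps=1e-6):
    """Approximate J[i][j] = d f(x)[i] / d x[j] with central differences.

    Hypothetical helper for illustration only; not a functorch API.
    """
    n_out = len(f(x))
    jac = [[0.0] * len(x) for _ in range(n_out)]
    for j in range(len(x)):
        # Perturb only the j-th input in both directions
        x_plus, x_minus = list(x), list(x)
        x_plus[j] += eps
        x_minus[j] -= eps
        f_plus, f_minus = f(x_plus), f(x_minus)
        for i in range(n_out):
            jac[i][j] = (f_plus[i] - f_minus[i]) / (2 * eps)
    return jac

def pointwise_sin(x):
    # Output i depends only on input i, so off-diagonal entries vanish
    return [math.sin(v) for v in x]

x = [0.1, 0.7, -1.3]
jac = numerical_jacobian(pointwise_sin, x)

for i in range(len(x)):
    for j in range(len(x)):
        expected = math.cos(x[i]) if i == j else 0.0
        assert abs(jac[i][j] - expected) < 1e-6
```

Because each output element depends on exactly one input element, every off-diagonal entry is zero and the diagonal holds cos(x), matching torch.diag(torch.cos(x)).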
jacrev() can be composed with vmap to produce batched Jacobians:

>>> from functorch import jacrev, vmap
>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)
Additionally, jacrev() can be composed with itself to produce Hessians:

>>> from functorch import jacrev
>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
>>> assert torch.allclose(hessian, torch.diag(-x.sin()))
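The expected value in the assertion follows from differentiating f twice. As a worked check, with \delta_{ij} denoting the Kronecker delta:

```latex
f(x) = \sum_{k} \sin(x_k), \qquad
\frac{\partial f}{\partial x_i} = \cos(x_i), \qquad
\frac{\partial^2 f}{\partial x_i \, \partial x_j} = -\sin(x_i)\,\delta_{ij}
\;\;\Longrightarrow\;\;
H = \operatorname{diag}(-\sin x).
```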
By default, jacrev() computes the Jacobian with respect to the first input. However, it can compute the Jacobian with respect to a different argument by using argnums:

>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=1)(x, y)
>>> expected = torch.diag(2 * y)
>>> assert torch.allclose(jacobian, expected)
Additionally, passing a tuple to argnums will compute the Jacobian with respect to multiple arguments:

>>> from functorch import jacrev
>>> def f(x, y):
...     return x + y ** 2
...
>>> x, y = torch.randn(5), torch.randn(5)
>>> jacobian = jacrev(f, argnums=(0, 1))(x, y)
>>> expectedX = torch.diag(torch.ones_like(x))
>>> expectedY = torch.diag(2 * y)
>>> assert torch.allclose(jacobian[0], expectedX)
>>> assert torch.allclose(jacobian[1], expectedY)
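When the arguments have different shapes, each entry of the returned tuple has the output's shape followed by the shape of the corresponding input. The sketch below illustrates this with a matrix-vector product; the function name matvec and the sizes 3 and 5 are illustrative assumptions, not taken from the functorch documentation.

```python
import torch
from functorch import jacrev

def matvec(W, x):
    # Hypothetical example function: output has shape (3,)
    # for W of shape (3, 5) and x of shape (5,)
    return W @ x

W, x = torch.randn(3, 5), torch.randn(5)
jac_W, jac_x = jacrev(matvec, argnums=(0, 1))(W, x)

# Jacobian shape = output shape + input shape
assert jac_W.shape == (3, 3, 5)
assert jac_x.shape == (3, 5)

# The Jacobian of W @ x with respect to x is W itself
assert torch.allclose(jac_x, W)
```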
Note

Using PyTorch torch.no_grad together with jacrev.

Case 1: Using torch.no_grad inside a function:

>>> def f(x):
...     with torch.no_grad():
...         c = x ** 2
...     return x - c

In this case, jacrev(f)(x) will respect the inner torch.no_grad: c is treated as a constant and no gradient flows through it.

Case 2: Using jacrev inside a torch.no_grad context manager:

>>> with torch.no_grad():
...     jacrev(f)(x)

In this case, jacrev will respect the inner torch.no_grad, but not the outer one. This is because jacrev is a “function transform”: its result should not depend on the result of a context manager outside of f.
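Both cases can be checked numerically. The sketch below assumes a functorch version that behaves as this note describes; the expected identity matrix follows from treating c as a constant, so that the derivative of x - c with respect to x is the identity.

```python
import torch
from functorch import jacrev

def f(x):
    with torch.no_grad():
        c = x ** 2  # computed without gradient tracking
    return x - c

x = torch.randn(5)

# Case 1: the inner no_grad is respected, so c acts as a constant
jac_inner = jacrev(f)(x)
assert torch.allclose(jac_inner, torch.eye(5))

# Case 2: an outer no_grad does not change the result of jacrev
with torch.no_grad():
    jac_outer = jacrev(f)(x)
assert torch.allclose(jac_outer, jac_inner)
```

Without the inner torch.no_grad, the same function would instead have the Jacobian diag(1 - 2x).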