torch.func.grad¶
- torch.func.grad(func, argnums=0, has_aux=False)¶
- `grad` operator computes gradients of `func` with respect to the input(s) specified by `argnums`. This operator can be nested to compute higher-order gradients.
- Parameters
- func (Callable) – A Python function that takes one or more arguments. Must return a single-element Tensor. If `has_aux` is `True`, the function can instead return a tuple of a single-element Tensor and other auxiliary objects: `(output, aux)`.
- argnums (int or Tuple[int]) – Specifies which arguments to compute gradients with respect to. `argnums` can be a single integer or a tuple of integers (see the short sketch after this parameter list). Default: 0.
- has_aux (bool) – Flag indicating that `func` returns a tensor and other auxiliary objects: `(output, aux)`. Default: False.
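For instance, a minimal sketch of how `argnums` selects which input to differentiate (the two-argument function `f` below is a hypothetical example, not part of the API):

>>> import torch
>>> from torch.func import grad
>>> def f(x, y):
>>>     # scalar-valued function of two inputs
>>>     return (x * y).sum()
>>>
>>> x, y = torch.randn(3), torch.randn(3)
>>> # argnums=1 differentiates with respect to the second argument y; d/dy of (x * y).sum() is x
>>> grad_y = grad(f, argnums=1)(x, y)
>>> assert torch.allclose(grad_y, x)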
 
- Returns
- Function to compute gradients with respect to its inputs. By default, the output of the returned function is the gradient tensor(s) with respect to the first argument. If `has_aux` is `True`, a tuple of gradients and the output auxiliary objects is returned. If `argnums` is a tuple of integers, a tuple of output gradients with respect to each `argnums` value is returned.
- Return type
- Callable
Example of using `grad`:

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())

When composed with `vmap`, `grad` can be used to compute per-sample-gradients:

>>> from torch.func import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
>>>     # Very simple linear model with activation
>>>     assert feature_vec.dim() == 1
>>>     return feature_vec.dot(weights).relu()
>>>
>>> def compute_loss(weights, example, target):
>>>     y = model(weights, example)
>>>     return ((y - target) ** 2).mean()  # MSELoss
>>>
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)

Example of using `grad` with `has_aux` and `argnums`:

>>> from torch.func import grad
>>> def my_loss_func(y, y_pred):
>>>     loss_per_sample = (0.5 * y_pred - y) ** 2
>>>     loss = loss_per_sample.mean()
>>>     return loss, (y_pred, loss_per_sample)
>>>
>>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
>>> y_true = torch.rand(4)
>>> y_preds = torch.rand(4, requires_grad=True)
>>> out = fn(y_true, y_preds)
>>> # output is ((grads w.r.t. y_true, grads w.r.t. y_preds), (y_pred, loss_per_sample))

Note

Using PyTorch `torch.no_grad` together with `grad`.

Case 1: Using `torch.no_grad` inside a function:

>>> def f(x):
>>>     with torch.no_grad():
>>>         c = x ** 2
>>>     return x - c

In this case, `grad(f)(x)` will respect the inner `torch.no_grad`.

Case 2: Using `grad` inside the `torch.no_grad` context manager:

>>> with torch.no_grad():
>>>     grad(f)(x)

In this case, `grad` will respect the inner `torch.no_grad`, but not the outer one. This is because `grad` is a "function transform": its result should not depend on the result of a context manager outside of `f`.
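As a small illustrative check of Case 2 (a minimal sketch reusing the `f` defined above; the expected value follows from the behavior described in this note):

>>> import torch
>>> from torch.func import grad
>>> x = torch.randn([])
>>> with torch.no_grad():
>>>     dfdx = grad(f)(x)
>>> # The inner torch.no_grad keeps c = x ** 2 out of the gradient, so df/dx == 1,
>>> # and the outer torch.no_grad does not suppress the function transform.
>>> assert torch.allclose(dfdx, torch.ones_like(x))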