enable_grad
class torch.enable_grad(orig_func=None)
Context-manager that enables gradient calculation.

Enables gradient calculation, if it has been disabled via no_grad or set_grad_enabled.

This context manager is thread local; it will not affect computation in other threads.

Also functions as a decorator.

Note
enable_grad is one of several mechanisms that can enable or disable gradients locally; see Locally disabling gradient computation for more information on how they compare.

Note
This API does not apply to forward-mode AD.

Example::
    >>> import torch
    >>> x = torch.tensor([1.], requires_grad=True)
    >>> with torch.no_grad():
    ...     with torch.enable_grad():
    ...         y = x * 2
    >>> y.requires_grad
    True
    >>> y.backward()
    >>> x.grad
    tensor([2.])
    >>> @torch.enable_grad()
    ... def doubler(x):
    ...     return x * 2
    >>> with torch.no_grad():
    ...     z = doubler(x)
    >>> z.requires_grad
    True
    >>> @torch.enable_grad()
    ... def tripler(x):
    ...     return x * 3
    >>> with torch.no_grad():
    ...     z = tripler(x)
    >>> z.requires_grad
    True
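
The sketch below illustrates two of the properties described above: enable_grad overrides a mode disabled via set_grad_enabled and restores the enclosing mode on exit, and the mode is thread local. The helper names `grad_mode_trace` and `grad_mode_in_worker` are hypothetical and exist only for this illustration; the sketch assumes a plain `threading.Thread` worker and inspects the mode with `torch.is_grad_enabled()`.

```python
import threading

import torch


def grad_mode_trace():
    """Record torch.is_grad_enabled() at several points of nested contexts."""
    trace = {}
    with torch.set_grad_enabled(False):
        trace["inside_disabled"] = torch.is_grad_enabled()
        with torch.enable_grad():
            trace["inside_enable_grad"] = torch.is_grad_enabled()
        # Leaving enable_grad restores the enclosing (disabled) mode.
        trace["after_enable_grad"] = torch.is_grad_enabled()
    return trace


def grad_mode_in_worker():
    """Enable gradients in a worker thread while the main thread has them disabled."""
    seen = {}

    def worker():
        # This affects only the worker thread's grad mode.
        with torch.enable_grad():
            seen["worker"] = torch.is_grad_enabled()

    with torch.no_grad():
        t = threading.Thread(target=worker)
        t.start()
        t.join()
        # The worker's enable_grad did not change this thread's mode.
        seen["main"] = torch.is_grad_enabled()
    return seen
```

Calling `grad_mode_trace()` shows the mode flipping to enabled only inside the enable_grad block, and `grad_mode_in_worker()` shows the main thread staying in no_grad mode while the worker runs with gradients enabled.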