torch.autograd.function.FunctionCtx.save_for_backward
FunctionCtx.save_for_backward(*tensors)

Saves given tensors for a future call to backward().

This should be called at most once, and only from inside the forward() method. It should only be called with input or output tensors.

In backward(), saved tensors can be accessed through the saved_tensors attribute. Before returning them to the user, a check is made to ensure they weren't used in any in-place operation that modified their content.

Arguments can also be None; passing None is a no-op.

See Extending torch.autograd for more details on how to use this method.
Example::
>>> import torch
>>> from torch.autograd import Function
>>>
>>> class Func(Function):
>>>     @staticmethod
>>>     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
>>>         w = x * y * z
>>>         out = x * y + y * z + w
>>>         ctx.save_for_backward(x, y, out)
>>>         ctx.z = z  # z is not a tensor
>>>         ctx.w = w  # w is neither input nor output
>>>         return out
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_out):
>>>         x, y, out = ctx.saved_tensors
>>>         z = ctx.z
>>>         gx = grad_out * (y + y * z)
>>>         gy = grad_out * (x + z + x * z)
>>>         gz = None  # z is a non-tensor argument, so it gets no gradient
>>>         return gx, gy, gz
>>>
>>> a = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> b = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> c = 4
>>> d = Func.apply(a, b, c)
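The example stops before any gradients flow. As a hedged continuation (the backward() call and the printed values below are not part of the original snippet, but follow directly from the formulas in Func.backward with a=1, b=2, c=4):

>>> d.backward()
>>> a.grad  # dout/dx = y + y*z = 2 + 2*4
tensor(10., dtype=torch.float64)
>>> b.grad  # dout/dy = x + z + x*z = 1 + 4 + 1*4
tensor(9., dtype=torch.float64)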
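The in-place check described above can also be observed with this Func: mutating a tensor that forward() saved makes the saved_tensors unpacking in backward() fail. A minimal sketch, assuming standard autograd version-counter behavior (the exact RuntimeError message varies across PyTorch releases):

>>> e = torch.tensor(1., requires_grad=True, dtype=torch.double)
>>> f = torch.tensor(2., requires_grad=True, dtype=torch.double)
>>> g = Func.apply(e, f, 4)
>>> g.add_(1)  # in-place edit of the output, which was saved for backward
>>> g.backward()  # raises RuntimeError: a saved tensor was modified by an inplace operation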