torch.autograd.function.FunctionCtx.set_materialize_grads¶
- FunctionCtx.set_materialize_grads(value)[source]¶
Sets whether to materialize output grad tensors. Default is
True.This should be called only from inside the
forward()methodIf
True, undefined output grad tensors will be expanded to tensors full of zeros prior to calling thebackward()method.- Example::
>>> class SimpleFunc(Function): >>> @staticmethod >>> def forward(ctx, x): >>> return x.clone(), x.clone() >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): >>> return g1 + g2 # No check for None necessary >>> >>> # We modify SimpleFunc to handle non-materialized grad outputs >>> class Func(Function): >>> @staticmethod >>> def forward(ctx, x): >>> ctx.set_materialize_grads(False) >>> ctx.save_for_backward(x) >>> return x.clone(), x.clone() >>> >>> @staticmethod >>> @once_differentiable >>> def backward(ctx, g1, g2): >>> x, = ctx.saved_tensors >>> grad_input = torch.zeros_like(x) >>> if g1 is not None: # We must check for None now >>> grad_input += g1 >>> if g2 is not None: >>> grad_input += g2 >>> return grad_input >>> >>> a = torch.tensor(1., requires_grad=True) >>> b, _ = Func.apply(a) # induces g2 to be undefined