Custom Autograd Functions#
PyTorch allows you to define custom autograd functions with custom forward and backward implementations.
Function Base Class#
-
template<class T>
struct Function#

To use custom autograd operations, implement a Function subclass with static forward and backward functions:

forward can take as many arguments as you want and should return either a variable list or a Variable. Use of any direct Variable arguments will be registered in the graph, but no vectors, sets, or other data structures will be traversed. You can use std::optional<Tensor> as one of the arguments, and it will be registered as a variable in the graph if the argument has a value. It should take a pointer to torch::autograd::AutogradContext as the first argument. Variables can be saved in the ctx using ctx->save_for_backward (see torch::autograd::AutogradContext::save_for_backward), and other data can be saved in the ctx->saved_data map (see torch::autograd::AutogradContext::saved_data) in the form of <std::string, at::IValue> pairs.

backward should take a pointer to torch::autograd::AutogradContext and a variable list containing as many Variables as there were outputs from forward as arguments. It should return as many Variables as there were inputs, with each of them containing the gradient w.r.t. its corresponding input. Variables saved in forward can be accessed with ctx->get_saved_variables (see torch::autograd::AutogradContext::get_saved_variables), and other saved data can be accessed from ctx->saved_data.

To enable compiled autograd support (torch.compile for backward) for your custom autograd operation, set MyFunction::is_traceable (see the Function::is_traceable notes below).

For example:
class MyFunction : public Function<MyFunction> {
 public:
  static constexpr bool is_traceable = true;

  static variable_list forward(AutogradContext *ctx, int n, Variable var) {
    // Save data for backward in context
    ctx->saved_data["n"] = n;
    var.mul_(n);
    // Mark var as modified by inplace operation
    ctx->mark_dirty({var});
    return {var};
  }

  static variable_list backward(AutogradContext *ctx, variable_list grad_output) {
    // Use data saved in forward
    auto n = ctx->saved_data["n"].toInt();
    return {grad_output[0] * n};
  }
};
To use MyFunction:

Variable x;
auto y = MyFunction::apply(6, x);

// Example backward call
y[0].sum().backward();
Public Static Attributes
-
static constexpr bool is_traceable = false#
AutogradContext#
-
struct AutogradContext#
Context to save information during forward that can be accessed in backward in custom autograd operations (see torch::autograd::Function for details).

Public Functions
-
AutogradContext() = default#
-
AutogradContext(const AutogradContext &other) = delete#
-
AutogradContext &operator=(const AutogradContext &other) = delete#
-
AutogradContext(AutogradContext &&other) = delete#
-
AutogradContext &operator=(AutogradContext &&other) = delete#
-
~AutogradContext() = default#
-
AutogradContext(PackedArgs &packed_args)#
-
void save_for_backward(variable_list to_save)#
Saves the list of variables for a future call to backward.

This should be called at most once from inside of forward.
-
void mark_dirty(const variable_list &inputs)#
Marks variables in the list as modified in an in-place operation.
This should be called at most once from inside of forward, and all arguments should be inputs.
-
void mark_non_differentiable(const variable_list &outputs)#
Marks outputs in the list as not requiring gradients.
This should be called at most once from inside of forward, and all arguments should be outputs.
-
void set_materialize_grads(bool value)#
-
variable_list get_saved_variables() const#
Get the list of variables that were saved in forward using save_for_backward().

Before returning them to the user, a check is made to ensure that they were not modified by any in-place operations.
-
const std::unordered_set<at::TensorImpl*> &get_and_bump_dirty() const#
-
const std::unordered_set<at::TensorImpl*> &get_non_differentiable() const#
-
bool needs_input_grad(size_t output_edge_index) const#
Exposes the Node's task_should_compute_output method to the C++ custom autograd Function as needs_input_grad.
-
bool needs_input_grad(std::initializer_list<IndexRange> idxs) const#
Public Members
-
ska::flat_hash_map<std::string, at::IValue> saved_data#
Can be used to save non-variable data for backward.
Creating Custom Functions#
To create a custom autograd function, inherit from torch::autograd::Function
and implement the static forward and backward methods:
Example:
class MyReLU : public torch::autograd::Function<MyReLU> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
torch::Tensor input) {
ctx->save_for_backward({input});
return input.clamp_min(0);
}
static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto input = saved[0];
auto grad_output = grad_outputs[0];
auto grad_input = grad_output * (input > 0).to(grad_output.dtype());
return {grad_input};
}
};
// Usage
auto output = MyReLU::apply(input);
Custom Kernels and AutoDispatchBelowADInplaceOrView#
For users implementing custom kernels who want to redispatch below Autograd dispatch
keys, use at::AutoDispatchBelowADInplaceOrView instead of InferenceMode:
class ROIAlignFunction : public torch::autograd::Function<ROIAlignFunction> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
int64_t sampling_ratio,
bool aligned) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["aligned"] = aligned;
ctx->saved_data["input_shape"] = input.sizes();
ctx->save_for_backward({rois});
at::AutoDispatchBelowADInplaceOrView guard;
auto result = roi_align(
input, rois, spatial_scale, pooled_height,
pooled_width, sampling_ratio, aligned);
return {result};
}
};
For customized inplace and view kernels, see the custom kernel tutorial for more details.