PyTorch: Defining New autograd Functions

A third-order polynomial, trained to predict \(y=\sin(x)\) from \(-\pi\) to \(\pi\) by minimizing squared Euclidean distance. Instead of writing the polynomial as \(y=a+bx+cx^2+dx^3\), we write the polynomial as \(y=a+b P_3(c+dx)\), where \(P_3(x)=\frac{1}{2}\left(5x^3-3x\right)\) is the Legendre polynomial of degree three.

This implementation computes the forward pass using operations on PyTorch Tensors, and uses PyTorch autograd to compute gradients.

Here we implement our own custom autograd Function to compute \(P_3\). Differentiating \(P_3(x)=\frac{1}{2}\left(5x^3-3x\right)\) term by term gives \(P_3'(x)=\frac{1}{2}\left(15x^2-3\right)=\frac{3}{2}\left(5x^2-1\right)\), which is the derivative the backward pass must apply via the chain rule.

import torch
import math


class LegendrePolynomial3(torch.autograd.Function):
    """
    We can implement our own custom autograd Functions by subclassing
    torch.autograd.Function and implementing the forward and backward passes
    which operate on Tensors.
    """

    @staticmethod
    def forward(input):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. Check out `Extending torch.autograd <https://docs.pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd>`_
        for further details.
        """
        return 0.5 * (5 * input ** 3 - 3 * input)

    @staticmethod
    def setup_context(ctx, inputs, output):
        """
        Store input for use in the backward pass using ``ctx.save_for_backward``.
        Other objects can be stored directly as attributes on the ctx object,
        such as ``ctx.my_object = my_object``.
        """
        input, = inputs
        ctx.save_for_backward(input)

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        input, = ctx.saved_tensors
        return grad_output * 1.5 * (5 * input ** 2 - 1)
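

# Sanity check (a sketch, not part of the original script): torch.autograd.gradcheck
# compares the hand-written backward against numerically computed gradients.
# gradcheck expects double-precision inputs with requires_grad=True.
check_input = torch.randn(8, dtype=torch.double, requires_grad=True)
assert torch.autograd.gradcheck(LegendrePolynomial3.apply, (check_input,))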


dtype = torch.float
# Run on the current accelerator (e.g. a CUDA GPU) if one is available,
# otherwise fall back to the CPU.
device = (
    torch.accelerator.current_accelerator().type
    if torch.accelerator.is_available()
    else "cpu"
)

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For this example, we need
# 4 weights: y = a + b * P3(c + d * x), these weights need to be initialized
# not too far from the correct result to ensure convergence.
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
b = torch.full((), -1.0, device=device, dtype=dtype, requires_grad=True)
c = torch.full((), 0.0, device=device, dtype=dtype, requires_grad=True)
d = torch.full((), 0.3, device=device, dtype=dtype, requires_grad=True)

learning_rate = 5e-6
for t in range(2000):
    # To apply our Function, we use the Function.apply method, and alias it as 'P3'.
    P3 = LegendrePolynomial3.apply

    # Forward pass: compute predicted y using operations on Tensors; we compute
    # P3 using our custom autograd operation.
    y_pred = a + b * P3(c + d * x)

    # Compute and print loss
    loss = (y_pred - y).pow(2).sum()
    if t % 100 == 99:
        print(t, loss.item())

    # Use autograd to compute the backward pass.
    loss.backward()

    # Update weights using gradient descent
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # Manually zero the gradients after updating weights
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

print(f'Result: y = {a.item()} + {b.item()} * P3({c.item()} + {d.item()} x)')
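
As a quick sanity check on the custom Function itself (a sketch, not part of the original script), we can compare its value and its autograd gradient at a test point against the closed forms \(P_3(x)=\frac{1}{2}\left(5x^3-3x\right)\) and \(P_3'(x)=\frac{3}{2}\left(5x^2-1\right)\):

x_test = torch.tensor(0.7, device=device, dtype=dtype, requires_grad=True)
y_test = LegendrePolynomial3.apply(x_test)
y_test.backward()
# The forward value should match 0.5 * (5x^3 - 3x) ...
assert torch.isclose(y_test, 0.5 * (5 * x_test.detach() ** 3 - 3 * x_test.detach()))
# ... and the gradient should match 1.5 * (5x^2 - 1).
assert torch.isclose(x_test.grad, 1.5 * (5 * x_test.detach() ** 2 - 1))

The manual update inside ``torch.no_grad()`` is plain gradient descent; it could equivalently be written with ``torch.optim.SGD([a, b, c, d], lr=learning_rate)``, calling ``optimizer.step()`` and ``optimizer.zero_grad()`` in place of the hand-written block.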