Extending PyTorch

In this note we’ll cover ways of extending torch.nn and torch.autograd, and of writing custom C extensions that use our C libraries.

Extending torch.autograd

Adding operations to autograd requires implementing a new Function subclass for each operation. Recall that Functions are what autograd uses to compute the results and gradients, and to encode the operation history. Every new function implements up to three methods:

  • __init__ (optional) - if your operation is parametrized by or uses objects other than Variables, you should pass them as arguments to __init__. For example, an AddConstant function takes a scalar to add, while Transpose requires specifying which two dimensions to swap. If your function doesn’t require any additional parameters, you can skip it.
  • forward() - the code that performs the operation. It can take as many arguments as you want, and some of them can be optional if you specify default values. Keep in mind that only Variables will be passed in here. You can return either a single Variable output, or a tuple of Variables if there are multiple outputs. Also, please refer to the docs of Function to find descriptions of useful methods that can be called only from forward().
  • backward() - the gradient formula. It will be given as many arguments as there were outputs, with each of them representing the gradient w.r.t. that output. It should return as many Tensors as there were inputs, with each of them containing the gradient w.r.t. the corresponding input. If an input didn’t require a gradient (see needs_input_grad), or was non-differentiable, you can return None for it. Also, if you have optional arguments to forward() you can return more gradients than there were inputs, as long as the extra ones are all None.

Below you can find code for a Linear function from torch.nn, with additional comments:

from torch.autograd import Function

# Inherit from Function
class Linear(Function):

    # bias is an optional argument
    def forward(self, input, weight, bias=None):
        self.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = self.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if self.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if self.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and self.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias
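The three expressions returned by backward() follow from the chain rule. The derivation below is a sketch in notation that is not part of the original listing: X is the input (N x I), W the weight (O x I), b the bias (length O), Y the output, L a scalar loss, and G the incoming grad_output, i.e. the gradient of L w.r.t. Y.

```latex
\begin{align*}
Y_{nj} &= \sum_{i} X_{ni} W_{ji} + b_j
  && \text{forward: } Y = X W^{\top} + \mathbf{1}_N b^{\top} \\
\frac{\partial L}{\partial X_{ni}} &= \sum_{j} G_{nj} W_{ji}
  && \Rightarrow\ \texttt{grad\_input} = G W \\
\frac{\partial L}{\partial W_{ji}} &= \sum_{n} G_{nj} X_{ni}
  && \Rightarrow\ \texttt{grad\_weight} = G^{\top} X \\
\frac{\partial L}{\partial b_{j}} &= \sum_{n} G_{nj}
  && \Rightarrow\ \texttt{grad\_bias} = G^{\top} \mathbf{1}_N
\end{align*}
```

Each gradient has the same shape as the input it belongs to, which is a quick sanity check when writing a backward() of your own.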

Now, to make it easier to use these custom ops, we recommend wrapping them in small helper functions:

def linear(input, weight, bias=None):
    # The first pair of parentheses creates a Function object. Any
    # arguments given here will be passed to __init__. The second pair
    # invokes the __call__ operator, which then uses forward() to compute
    # the result and return it.
    return Linear()(input, weight, bias)

You probably want to check whether the backward method you implemented actually computes the derivatives of your function. You can do so by comparing it with numerical approximations computed using small finite differences:

import torch
from torch.autograd import Variable, gradcheck

# gradcheck takes a tuple of tensors as input, checks whether the gradients
# evaluated at these tensors are close enough to numerical approximations,
# and returns True if they all satisfy this condition.
# Linear needs both an input and a weight; the optional bias is omitted here.
input = (Variable(torch.randn(20, 20).double(), requires_grad=True),
         Variable(torch.randn(30, 20).double(), requires_grad=True))
test = gradcheck(Linear(), input, eps=1e-6, atol=1e-4)
print(test)
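To make the idea of a finite-difference check concrete, here is a minimal plain-Python sketch that does not use torch at all. It is not how gradcheck is implemented. It only illustrates the principle: perturb one entry at a time, take a central difference, and compare against the analytic formulas from the Linear backward above. All helper names (matmul, numeric_grad, close, and so on) are made up for this illustration.

```python
# Plain-Python illustration of a finite-difference gradient check.
# Matrices are lists of rows; every helper name here is hypothetical.

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def linear_forward(inp, weight, bias):
    # Same computation as Linear.forward: input.mm(weight.t()) + bias
    out = matmul(inp, transpose(weight))
    return [[v + b for v, b in zip(row, bias)] for row in out]

def linear_backward(inp, weight, grad_output):
    # Same formulas as Linear.backward
    grad_input = matmul(grad_output, weight)
    grad_weight = matmul(transpose(grad_output), inp)
    grad_bias = [sum(col) for col in zip(*grad_output)]
    return grad_input, grad_weight, grad_bias

def loss(inp, weight, bias, grad_output):
    # Scalar whose gradient w.r.t. the output is exactly grad_output
    out = linear_forward(inp, weight, bias)
    return sum(o * g for ro, rg in zip(out, grad_output) for o, g in zip(ro, rg))

def numeric_grad(f, matrix, eps=1e-6):
    # Central difference: perturb each entry in place, then restore it
    grad = [[0.0] * len(row) for row in matrix]
    for i, row in enumerate(matrix):
        for j in range(len(row)):
            orig = row[j]
            row[j] = orig + eps
            plus = f()
            row[j] = orig - eps
            minus = f()
            row[j] = orig
            grad[i][j] = (plus - minus) / (2 * eps)
    return grad

def close(a, b, atol=1e-4):
    return all(abs(x - y) <= atol for ra, rb in zip(a, b) for x, y in zip(ra, rb))

inp = [[0.5, -1.0, 2.0], [1.5, 0.25, -0.75]]     # 2 x 3
weight = [[0.1, 0.2, -0.3], [0.4, -0.5, 0.6]]    # 2 x 3 (outputs x inputs)
bias = [0.05, -0.02]
grad_output = [[1.0, -2.0], [0.5, 3.0]]          # 2 x 2

f = lambda: loss(inp, weight, bias, grad_output)
g_in, g_w, g_b = linear_backward(inp, weight, grad_output)

ok_input = close(g_in, numeric_grad(f, inp))
ok_weight = close(g_w, numeric_grad(f, weight))
ok_bias = close([g_b], numeric_grad(f, [bias]))  # [bias] shares the same row object
print(ok_input, ok_weight, ok_bias)
```

If any of the three flags came out False, the corresponding analytic formula would disagree with the numerical estimate, which is the kind of mistake this check is meant to catch.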

Extending torch.nn

nn exports two kinds of interfaces - modules and their functional versions. You can extend it in both ways, but we recommend using modules for all kinds of layers that hold any parameters or buffers, and a functional form for parameter-less operations like activation functions, pooling, etc.

Adding a functional version of an operation is already fully covered in the section above.

Adding a Module

Since nn heavily utilizes autograd, adding a new Module requires implementing a Function that performs the operation and can compute the gradient. From now on let’s assume that we want to implement a Linear module and we have the function implemented as in the listing above. There’s very little code required to add this. There are two functions that need to be implemented:

  • __init__ (optional) - takes in arguments such as kernel sizes, numbers of features, etc. and initializes parameters and buffers.
  • forward() - instantiates a Function and uses it to perform the operation. It’s very similar to a functional wrapper shown above.

This is how a Linear module can be implemented:

import torch
import torch.nn as nn

# Keep a handle on the autograd Function from the listing above, because
# the class below reuses the name Linear.
LinearFunction = Linear

class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        # nn.Parameter is a special kind of Variable that gets
        # automatically registered as the Module's parameter once it's
        # assigned as an attribute. Parameters and buffers need to be
        # registered, or they won't appear in .parameters() (doesn't apply
        # to buffers), and won't be converted when e.g. .cuda() is called.
        # You can use .register_buffer() to register buffers.
        # nn.Parameters can never be volatile and, unlike Variables,
        # they require gradients by default.
        # The weight has shape (output_features, input_features), because
        # the function computes input.mm(weight.t()).
        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            # You should always register all possible parameters, but the
            # optional ones can be None if you want.
            self.register_parameter('bias', None)

        # Not a very smart way to initialize weights
        self.weight.data.uniform_(-0.1, 0.1)
        # The bias argument is a bool, so test the registered parameter.
        if self.bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction()(input, self.weight, self.bias)
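The registration rules described in the comments above can be observed directly. The following is a small sketch, assuming a reasonably recent PyTorch release. The Scale module and its attribute names (running_count, scratch) are hypothetical and exist only for this illustration. It avoids custom Functions so that it shows nothing but how parameters, buffers, and plain attributes are treated.

```python
import torch
import torch.nn as nn

class Scale(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, features, bias=True):
        super(Scale, self).__init__()
        # Assigning an nn.Parameter as an attribute registers it.
        self.weight = nn.Parameter(torch.ones(features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(features))
        else:
            # Registered under the name 'bias', but holding None.
            self.register_parameter('bias', None)
        # Buffers are registered explicitly.
        self.register_buffer('running_count', torch.zeros(1))
        # A plain tensor attribute is not registered at all.
        self.scratch = torch.zeros(features)

    def forward(self, input):
        out = input * self.weight
        if self.bias is not None:
            out = out + self.bias
        return out

with_bias = Scale(3)
without_bias = Scale(3, bias=False)

param_names = [name for name, _ in with_bias.named_parameters()]
param_names_nobias = [name for name, _ in without_bias.named_parameters()]
state_keys = sorted(with_bias.state_dict().keys())

print(param_names)         # registered parameters
print(param_names_nobias)  # a None parameter is skipped
print(state_keys)          # parameters and buffers, but not 'scratch'
```

Registering the optional bias as None keeps self.bias defined, so forward() can test it without needing hasattr checks.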

Writing custom C extensions

Coming soon. For now you can find an example at GitHub.