.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "advanced/numpy_extensions_tutorial.py"

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_advanced_numpy_extensions_tutorial.py:


Creating Extensions Using NumPy and SciPy
=========================================

**Author**: `Adam Paszke `_

**Updated by**: `Adam Dziedzic `_

In this tutorial, we shall go through two tasks:

1. Create a neural network layer with no parameters.

   -  This calls into **NumPy** as part of its implementation.

2. Create a neural network layer that has learnable weights.

   -  This calls into **SciPy** as part of its implementation.

.. code-block:: Python


    import torch
    from torch.autograd import Function

Parameter-less example
----------------------

This layer doesn’t do anything particularly useful or mathematically correct.

It is aptly named ``BadFFTFunction``.

**Layer Implementation**

.. code-block:: Python


    from numpy.fft import rfft2, irfft2


    class BadFFTFunction(Function):
        @staticmethod
        def forward(ctx, input):
            # detach so the tensor can be converted to a NumPy array
            numpy_input = input.detach().numpy()
            result = abs(rfft2(numpy_input))
            # cast the float64 NumPy result back to the input's dtype
            return torch.as_tensor(result, dtype=input.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            numpy_go = grad_output.detach().numpy()
            result = irfft2(numpy_go)
            return torch.as_tensor(result, dtype=grad_output.dtype)

    # since this layer does not have any parameters, we can
    # simply declare this as a function, rather than as an ``nn.Module`` class


    def incorrect_fft(input):
        return BadFFTFunction.apply(input)

**Example usage of the created layer:**

.. code-block:: Python


    input = torch.randn(8, 8, requires_grad=True)
    result = incorrect_fft(input)
    print(result)
    result.backward(torch.randn(result.size()))
    print(input)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    tensor([[ 5.6235,  6.7740,  8.4758,  8.1919,  2.1324],
            [ 0.6544,  1.2237,  8.7218,  8.8866,  7.3615],
            [15.6006,  9.3279, 11.7777,  3.5517,  6.5761],
            [12.6021,  7.7603, 11.4339, 11.5106,  7.6696],
            [ 0.8482,  4.9518,  2.0096,  3.3502,  0.7314],
            [12.6021,  7.8861, 10.5822,  8.0748,  7.6696],
            [15.6006,  4.0719,  4.0253,  3.4041,  6.5761],
            [ 0.6544, 14.1585,  7.4716, 10.2084,  7.3615]],
           grad_fn=)
    tensor([[ 1.4118,  0.4746,  2.5304, -0.3566, -0.2607,  0.7273,  0.2161,  1.1891],
            [ 1.5440,  0.6967,  0.0048, -0.0358, -0.7080,  1.3881, -0.4777,  0.3856],
            [-1.7932, -0.5466,  0.4830, -0.6864, -0.1794,  0.0249, -1.1434, -0.7422],
            [ 1.5912,  0.1402, -0.4471,  0.5496,  2.2869, -0.6024, -1.2530, -0.5738],
            [ 1.4356, -0.9356, -1.7598, -0.4742,  1.1706,  0.0035,  0.7622,  1.2726],
            [-0.7856,  0.3806,  1.4863, -0.9791,  1.0701, -0.0978,  1.6235,  0.8385],
            [ 0.3401,  0.6180, -1.7912,  0.0830,  0.0402,  0.8822,  0.0818, -0.6901],
            [ 0.0297, -0.1391, -0.2911, -2.1535, -1.6773,  0.1310, -1.6628,  0.9732]],
           requires_grad=True)

Parametrized example
--------------------

In deep learning literature, this layer is confusingly referred to as
convolution, although the actual operation is cross-correlation (the only
difference is that the filter is flipped for convolution, which is not the
case for cross-correlation).

Below is the implementation of a layer with learnable weights, where the
cross-correlation filter (kernel) holds the weights.

The backward pass computes the gradient with respect to the input, the
filter, and the bias.

.. code-block:: Python


    from numpy import flip
    import numpy as np
    from scipy.signal import convolve2d, correlate2d
    from torch.nn.modules.module import Module
    from torch.nn.parameter import Parameter


    class ScipyConv2dFunction(Function):
        @staticmethod
        def forward(ctx, input, filter, bias):
            # detach so we can cast to NumPy
            input, filter, bias = input.detach(), filter.detach(), bias.detach()
            result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
            result += bias.numpy()
            ctx.save_for_backward(input, filter, bias)
            return torch.as_tensor(result, dtype=input.dtype)

        @staticmethod
        def backward(ctx, grad_output):
            grad_output = grad_output.detach()
            input, filter, bias = ctx.saved_tensors
            grad_output = grad_output.numpy()
            grad_bias = np.sum(grad_output, keepdims=True)
            grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
            # the previous line can be expressed equivalently as:
            # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
            grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
            # cast each gradient to the dtype of the tensor it belongs to
            return (torch.as_tensor(grad_input, dtype=input.dtype),
                    torch.as_tensor(grad_filter, dtype=filter.dtype),
                    torch.as_tensor(grad_bias, dtype=bias.dtype))


    class ScipyConv2d(Module):
        def __init__(self, filter_width, filter_height):
            super().__init__()
            self.filter = Parameter(torch.randn(filter_width, filter_height))
            self.bias = Parameter(torch.randn(1, 1))

        def forward(self, input):
            return ScipyConv2dFunction.apply(input, self.filter, self.bias)

**Example usage:**

.. code-block:: Python


    module = ScipyConv2d(3, 3)
    print("Filter and bias: ", list(module.parameters()))
    input = torch.randn(10, 10, requires_grad=True)
    output = module(input)
    print("Output from the convolution: ", output)
    output.backward(torch.randn(8, 8))
    print("Gradient for the input map: ", input.grad)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Filter and bias:  [Parameter containing:
    tensor([[-0.3271, -2.0505, -0.2789],
            [-0.8936,  0.7400,  0.1147],
            [ 1.1680,  2.1045,  0.3120]], requires_grad=True), Parameter containing:
    tensor([[-1.7226]], requires_grad=True)]
    Output from the convolution:  tensor([[ 0.7337, -4.8225, -3.3873, -0.9430, -1.1647, -0.1573, -3.0916, -3.2917],
            [-2.2022, -0.5409,  5.2196,  2.6951, -0.1417, -3.4234,  1.7732, -0.7400],
            [-8.0354,  0.2321,  6.5192, -2.1820, -4.6774, -1.2630,  2.6269,  1.5937],
            [-2.4494, -2.1779, -7.3569, -4.6962,  0.6701, -1.4932, -5.5694, -2.8525],
            [-0.5736, -9.8618, -7.8595, -4.9695, -4.9647, -3.7938, -6.8478, -4.9518],
            [ 0.3105,  1.3747,  1.7348,  2.3005,  2.2008,  2.1138,  2.9398,  2.2056],
            [-3.4970,  1.5987,  2.5571,  7.9044,  3.4645,  1.0564,  1.1910, -1.2174],
            [-4.0318, -0.6134, -0.1584, -3.6865, -6.4534, -2.0046, -1.2485,  2.1961]],
           grad_fn=)
    Gradient for the input map:  tensor([[ 0.3849,  2.6197,  1.6585,  0.4479,  0.5675,  0.9848, -2.3317,  1.4914,  3.9742,  0.5174],
            [ 0.8583, -1.4710, -0.2439,  0.6531, -0.2770, -0.6077, -0.7289, -1.2579, -2.0810, -0.2572],
            [-1.3233,  0.8060, -2.8457, -6.2144, -2.9105, -3.2937,  0.9284,  1.4736, -5.5423, -0.8690],
            [ 1.8381, -3.9095, -5.5275, -0.6487, -1.2372, -0.2035,  7.4600,  3.6750,  1.4586,  0.1452],
            [-3.2526, -3.0768,  5.9892, 10.4888,  3.6597,  0.1161, -0.2891, -3.0995,  6.5086,  1.0216],
            [ 2.1941,  8.8648,  6.2531, -1.8756, -0.0673, -2.1077, -3.6290,  0.9347,  0.4615,  0.0633],
            [ 1.3973, -3.7327, -6.9368, -1.6595,  2.9632,  5.5094, -1.7956, -0.9187, -6.9448, -1.0489],
            [-2.8219, -3.5799,  4.5271,  1.4632, -2.8899, -1.2257,  1.1567, -5.1378, -2.0207, -0.2228],
            [ 1.1449,  3.7648, -0.3085, -5.4121, -2.6019,  0.4138,  1.9015, -0.9819,  0.9172,  0.1887],
            [-0.2935, -2.1723, -3.1014,  1.8279,  5.1196,  2.4966,  1.1638,  0.5475, -0.1992, -0.0432]])

**Check the gradients:**

.. code-block:: Python


    from torch.autograd.gradcheck import gradcheck

    moduleConv = ScipyConv2d(3, 3)

    input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
    test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
    print("Are the gradients correct: ", test)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Are the gradients correct:  True
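
**Cross-correlation versus convolution (sketch):**

The parametrized example relies on the fact that convolution is
cross-correlation with the filter flipped along both axes; the commented-out
line in ``ScipyConv2dFunction.backward`` uses the same identity. The sketch
below checks this in plain NumPy for ``'valid'`` mode only. The helpers
``corr2d_valid`` and ``conv2d_valid`` are hypothetical, written with explicit
loops to expose the index arithmetic; they are not SciPy's implementation of
``correlate2d`` or ``convolve2d``. The printed shape also shows why a 10x10
input and a 3x3 filter give the 8x8 output used in the example usage.

.. code-block:: Python


    import numpy as np


    def corr2d_valid(x, k):
        # Hypothetical helper: 'valid' cross-correlation (filter NOT flipped).
        # out[i, j] = sum over m, n of x[i + m, j + n] * k[m, n]
        kh, kw = k.shape
        # 'valid' keeps only positions where the filter fits entirely inside x
        out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
        out = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                out[i, j] = np.sum(x[i:i + kh, j:j + kw] * k)
        return out


    def conv2d_valid(x, k):
        # Hypothetical helper: 'valid' convolution written from its definition.
        # The input index runs backwards as the filter index (m, n) grows.
        kh, kw = k.shape
        out_h, out_w = x.shape[0] - kh + 1, x.shape[1] - kw + 1
        out = np.zeros((out_h, out_w))
        for i in range(out_h):
            for j in range(out_w):
                for m in range(kh):
                    for n in range(kw):
                        out[i, j] += x[i + kh - 1 - m, j + kw - 1 - n] * k[m, n]
        return out


    rng = np.random.default_rng(0)
    x = rng.standard_normal((10, 10))  # same input size as the example usage
    k = rng.standard_normal((3, 3))    # same filter size as ScipyConv2d(3, 3)
    flipped = k[::-1, ::-1]            # same as flip(flip(k, axis=0), axis=1)

    print(corr2d_valid(x, k).shape)    # (8, 8): 10 - 3 + 1 along each axis
    # convolution equals cross-correlation with the flipped filter
    print(np.allclose(conv2d_valid(x, k), corr2d_valid(x, flipped)))  # True
    # without the flip, the two generally differ for a non-symmetric filter
    print(np.allclose(conv2d_valid(x, k), corr2d_valid(x, k)))

A point-symmetric filter (for example, all ones) gives identical results under
both operations, which is one reason the naming mix-up is easy to overlook.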