
torch.compiler.wrap_numpy

torch.compiler.wrap_numpy(fn)

Decorator that turns a function from np.ndarrays to np.ndarrays into a function from torch.Tensors to torch.Tensors.

This decorator is designed to be used with torch.compile() with fullgraph=True. It lets you compile a NumPy function as if it were a PyTorch function, so you can run NumPy code on CUDA or compute its gradients.

Note

This decorator does not work without torch.compile().

Example:

>>> import numpy as np
>>> import torch
>>> # Compile a NumPy function as a Tensor -> Tensor function
>>> @torch.compile(fullgraph=True)
... @torch.compiler.wrap_numpy
... def fn(a: np.ndarray):
...     return np.sum(a * a)
>>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
>>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
>>> out = fn(x)
>>> out.backward()
>>> print(x.grad)
tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')