torch.compiler.wrap_numpy
torch.compiler.wrap_numpy(fn)
Decorator that turns a function from np.ndarrays to np.ndarrays into a function from torch.Tensors to torch.Tensors. It is designed to be used with torch.compile() with fullgraph=True. It lets you compile a NumPy function as if it were a PyTorch function, so you can run NumPy code on CUDA or compute its gradients.

Note

This decorator does not work without torch.compile().

Example:
>>> import numpy as np
>>> import torch
>>> # Compile a NumPy function as a Tensor -> Tensor function
>>> @torch.compile(fullgraph=True)
... @torch.compiler.wrap_numpy
... def fn(a: np.ndarray):
...     return np.sum(a * a)
>>> # Execute the NumPy function using Tensors on CUDA and compute the gradients
>>> x = torch.arange(6, dtype=torch.float32, device="cuda", requires_grad=True)
>>> out = fn(x)
>>> out.backward()
>>> print(x.grad)
tensor([ 0.,  2.,  4.,  6.,  8., 10.], device='cuda:0')