Custom GPU Kernels via Triton
PyTorch/XLA supports Triton kernels, so custom, high-performance GPU kernels can run as part of a PyTorch/XLA model. Triton is a language and compiler for GPU programming that lets developers write custom kernels for the operations in deep learning models.
Given a Triton kernel defined as follows:
```python
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,  # *Pointer* to first input vector.
    y_ptr,  # *Pointer* to second input vector.
    output_ptr,  # *Pointer* to output vector.
    n_elements,  # Size of the vector.
    BLOCK_SIZE: tl.constexpr,  # Number of elements each program should process.
    # NOTE: `constexpr` so it can be used as a shape value.
):
    # Triton add kernel from https://github.com/openai/triton/blob/main/python/tutorials/01-vector-add.py#L28
    pid = tl.program_id(axis=0)  # Index of this program instance in the 1D grid.
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements  # Skip lanes that fall past the end of the vectors.
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.store(output_ptr + offsets, output, mask=mask)
```
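To see how `tl.program_id`, `BLOCK_SIZE`, and the mask cooperate, here is a minimal CPU-only sketch that emulates the launch of `add_kernel` sequentially in plain Python. It is an illustration under simplifying assumptions, not how Triton executes: on a GPU the program instances run in parallel and each one processes its whole block at once. The helpers `cdiv` and `emulate_add_kernel` are hypothetical names; `cdiv` performs the same ceiling division as `triton.cdiv`.

```python
from typing import List


def cdiv(a: int, b: int) -> int:
    """Ceiling division, the same arithmetic triton.cdiv performs."""
    return (a + b - 1) // b


def emulate_add_kernel(x: List[int], y: List[int], block_size: int) -> List[int]:
    """Hypothetical helper: sequentially mimic the add_kernel launch on the CPU."""
    n_elements = len(x)
    output = [0] * n_elements
    grid = cdiv(n_elements, block_size)  # Number of program instances to launch.
    for pid in range(grid):  # On a GPU these programs run in parallel.
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        mask = [off < n_elements for off in offsets]
        for off, valid in zip(offsets, mask):
            if valid:  # Masked-off lanes neither load nor store.
                output[off] = x[off] + y[off]
    return output


# 20 elements with BLOCK_SIZE=8 -> 3 programs; the last one masks off 4 lanes.
x = list(range(20))
y = list(range(20))
print(cdiv(len(x), 8))                   # 3
print(emulate_add_kernel(x, y, 8)[-3:])  # [34, 36, 38]
```

The mask is what makes the last, partially filled block safe: without it, the third program would read and write offsets 20 to 23, past the end of the 20-element inputs.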
We can make this kernel part of the PyTorch/XLA execution graph as follows:
```python
import torch
import torch_xla
import torch_xla.experimental.triton as xla_triton
import triton
import triton.language as tl

size = 16
x = torch.arange(size, dtype=torch.int64).to('xla')
y = torch.arange(size, dtype=torch.int64).to('xla')
output = torch.empty_like(x)
block_size = 8
grid = (triton.cdiv(size, block_size),)

# triton_call takes the same arguments as the triton.jit function, in addition
# to the kernel itself and the grid that is used to execute the kernel.
# All the tl.constexpr terms are passed as kwargs at the end.
payload = xla_triton.triton_call(
    x, y, output, size, kernel=add_kernel, grid=grid, BLOCK_SIZE=block_size)

# To make the Triton kernel part of the PyTorch/XLA graph, we create a
# custom call node with the expected inputs, the payload from triton_call,
# the output shapes, and the output dtypes. The payload already contains
# information about how the GPU buffers will be loaded when this node is executed.
# The call returns a list with one tensor per requested output shape,
# so output[0] holds x + y.
output = torch_xla._XLAC._xla_gpu_custom_call([x, y], payload,
                                              [output.shape], [torch.int64])
```
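The calling convention described in the comments above (runtime arguments passed positionally in kernel order, `tl.constexpr` parameters passed as keyword arguments) can be illustrated with a small stdlib-only sketch. It does not reproduce what `triton_call` does internally; `Constexpr`, `bind_kernel_args`, and `add_kernel_signature` are hypothetical names, and `Constexpr` merely stands in for `tl.constexpr` so the example runs without a GPU.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


class Constexpr:
    """Hypothetical marker standing in for tl.constexpr in this CPU-only sketch."""


def bind_kernel_args(kernel: Callable, *args: Any,
                     **kwargs: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Map a triton_call-style argument list onto a kernel signature.

    Positional values fill the non-constexpr parameters in declaration order;
    Constexpr-annotated parameters must be supplied as keyword arguments.
    """
    params = inspect.signature(kernel).parameters
    runtime = [n for n, p in params.items() if p.annotation is not Constexpr]
    constexpr = [n for n, p in params.items() if p.annotation is Constexpr]
    if len(args) != len(runtime):
        raise TypeError(f"expected {len(runtime)} positional args, got {len(args)}")
    missing = [n for n in constexpr if n not in kwargs]
    if missing:
        raise TypeError(f"missing constexpr kwargs: {missing}")
    return dict(zip(runtime, args)), {n: kwargs[n] for n in constexpr}


# Same parameter list as add_kernel, with an empty Python body.
def add_kernel_signature(x_ptr, y_ptr, output_ptr, n_elements,
                         BLOCK_SIZE: Constexpr):
    pass


runtime, meta = bind_kernel_args(add_kernel_signature, "x", "y", "output", 16,
                                 BLOCK_SIZE=8)
print(runtime)  # {'x_ptr': 'x', 'y_ptr': 'y', 'output_ptr': 'output', 'n_elements': 16}
print(meta)     # {'BLOCK_SIZE': 8}
```

Binding the example's arguments this way makes the correspondence explicit: `x`, `y`, `output`, and `size` fill `x_ptr`, `y_ptr`, `output_ptr`, and `n_elements`, while `BLOCK_SIZE` travels as a keyword argument.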
For more complex kernels, you can also refer to the Triton Flash Attention kernel test in PyTorch/XLA.
Dependencies
The Triton integration depends on the `triton` package to function. This code is tested with `triton==2.3.0`. To install:

```shell
pip install --no-deps triton==2.3.0
```