Speed Up Tracing with @assume_pure

This document explains how to use torch_xla.experimental.assume_pure to eliminate lazy tensor tracing overhead. See this blog post for a primer on how lazy tensor tracing (operation recording) works.

Background and motivation

PyTorch/XLA’s lazy tensor tracing ensures correct execution by recording an operation graph (lazy tensor IR) when running PyTorch operations. For complex models, this tracing overhead can exceed the execution time of the graph, leading to performance bottlenecks. When training a model, the layers in the model must be re-traced on every training step, because there’s no guarantee that the layers will do the same thing in different training steps. As an extreme example, a layer’s forward() function may call random.random() and decide what code to run based on the resulting pseudorandom number.

Re-tracing can introduce unnecessary overhead. In many cases, the layers in your model will do exactly the same thing when given the same input tensor shapes. In other words, given the same input, the function returns the same output. Often, the layers also will not perform side effects such as saving the tensor to a file or adding it to a global list. Such functions are called “pure functions”.
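To make the distinction concrete, here is a minimal sketch in plain Python. The function names are hypothetical and chosen only for illustration; they are not part of PyTorch/XLA.

```python
import random


def scale_pure(values, factor):
    # Pure: the output depends only on the arguments, and nothing
    # outside the function is read or modified.
    return [v * factor for v in values]


call_log = []  # global state touched by the impure variant below


def scale_with_side_effect(values, factor):
    # Impure: appends to a global list on every call (a side effect).
    result = [v * factor for v in values]
    call_log.append(result)
    return result


def scale_randomly(values):
    # Impure: the output depends on hidden random state,
    # not only on the inputs.
    factor = random.random()
    return [v * factor for v in values]
```

Only a function like scale_pure is a safe candidate for @assume_pure. The other two either change global state or may behave differently from one call to the next.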

Any PyTorch/XLA function decorated with @assume_pure will only be traced once for each unique input tensor shape and dtype combination. PyTorch/XLA will cache the traced computation instead of repeatedly tracing the same operations.
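The caching behavior described above can be pictured with a toy decorator. This is a conceptual sketch in plain Python, not the PyTorch/XLA implementation: FakeTensor, toy_assume_pure, and layer are hypothetical names, and running the function body stands in for the expensive tracing step. The cache key also includes non-tensor argument values, since passing different values triggers re-tracing.

```python
from dataclasses import dataclass
from functools import wraps


@dataclass(frozen=True)
class FakeTensor:
    # Hypothetical stand-in for an XLA tensor: metadata only, no data.
    shape: tuple
    dtype: str


def toy_assume_pure(fn):
    """Toy decorator: runs `fn` once per unique argument signature."""
    cache = {}

    @wraps(fn)
    def wrapper(*args):
        # Tensors contribute their shape and dtype to the key;
        # non-tensor arguments contribute their value.
        key = tuple(
            ("tensor", a.shape, a.dtype) if isinstance(a, FakeTensor)
            else ("value", a)
            for a in args
        )
        if key not in cache:
            wrapper.trace_count += 1
            cache[key] = fn(*args)  # stands in for tracing
        return cache[key]

    wrapper.trace_count = 0
    return wrapper


@toy_assume_pure
def layer(x, scale):
    # Pretend this string is the traced computation.
    return f"graph(shape={x.shape}, dtype={x.dtype}, scale={scale})"


# Ten calls with the same shape, dtype, and scale: traced once.
for _ in range(10):
    layer(FakeTensor((8, 16), "float32"), 2)
print(layer.trace_count)  # prints 1

# A new shape, a new dtype, or a new non-tensor value each adds one trace.
layer(FakeTensor((4, 16), "float32"), 2)
layer(FakeTensor((8, 16), "bfloat16"), 2)
layer(FakeTensor((8, 16), "float32"), 3)
print(layer.trace_count)  # prints 4
```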

How to use @assume_pure

Using @assume_pure with a function

If you know your function is pure, decorate your function with @assume_pure:

import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure

@assume_pure
def do_some_math(
    # You can pass any number of XLA tensors.
    a: torch.Tensor,
    b: torch.Tensor,

    # Non-tensor arguments are also supported, and passing different values will
    # trigger re-tracing and caching more computations.
    c: int,
):
    # Evaluate some pure expressions.
    return a @ b + c

# Simulate a training loop.
# Even if we run this function ten times, it will only be traced once.
for i in range(10):
    v = do_some_math(
        torch.tensor([1.0], device='xla'),
        torch.tensor([2.0], device='xla'),
        c=42,
    )
    print(v)

Using @assume_pure with an nn.Module

If you have a pure nn.Module, i.e. its forward behavior depends only on the input arguments and the model parameters, you can use torch.func.functional_call to convert the module into a pure function and pass that to assume_pure:

import torch
import torch.nn as nn
from torch.func import functional_call
from torch_xla.experimental.assume_pure import assume_pure

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)
    def forward(self, x):
        return self.linear(x)

# Create module and move to XLA device
module = MyModule()
module = module.to('xla')

# Convert module's forward pass into a pure function
pure_forward = lambda params, buffers, x: functional_call(module, (params, buffers), (x,))

# Wrap the pure function with @assume_pure
cached_forward = assume_pure(pure_forward)

# Simulate a training loop
# Even if we run the model ten times, its forward function will only be traced once.
params = dict(module.named_parameters())
buffers = dict(module.named_buffers())
for i in range(10):
    x = torch.randn(5, 10, device='xla')
    y = cached_forward(params, buffers, x)
    print(y)
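The reason for going through functional_call is that it takes the parameters and buffers as explicit arguments instead of reading them from the module, so the wrapped function's output is determined by its inputs alone. The following plain-Python sketch illustrates that idea under simplifying assumptions: ToyLinear and toy_functional_call are hypothetical stand-ins, not the torch.func implementation.

```python
class ToyLinear:
    # Hypothetical stand-in for an nn.Module with a single parameter.
    def __init__(self, weight):
        self.weight = weight

    def forward(self, x):
        # Reads self.weight: state that is not in the argument list.
        return [self.weight * v for v in x]


def toy_functional_call(module, params, x):
    # Temporarily swap in the supplied parameters, run forward,
    # then restore the module's own values.
    saved = {name: getattr(module, name) for name in params}
    try:
        for name, value in params.items():
            setattr(module, name, value)
        return module.forward(x)
    finally:
        for name, value in saved.items():
            setattr(module, name, value)


module = ToyLinear(weight=2.0)


def pure_forward(params, x):
    # Everything that affects the output is passed in explicitly.
    return toy_functional_call(module, params, x)


print(pure_forward({"weight": 2.0}, [1.0, 2.0]))  # prints [2.0, 4.0]
print(pure_forward({"weight": 3.0}, [1.0, 2.0]))  # prints [3.0, 6.0]
print(module.weight)  # prints 2.0, the module itself is unchanged
```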

Benchmarks

The unit tests contain a benchmark that traces an example 100-layer decoder-only language model:

~/pytorch/xla
❯ TESTBRIDGE_TEST_ONLY=test_trace_transformer_with_spda_attention python3 test/test_assume_pure.py --benchmark_iterations 100
[...]
No `@assume_pure` time: 140.1342 ms
`@assume_pure` time: 24.1658 ms

In this run, the version with @assume_pure is about 5.8 times faster (24.2 ms versus 140.1 ms).

Importantly, the @assume_pure running time does not scale with increasing complexity inside the model. That’s because we only trace the model once, paying a fixed up-front cost, and then later runs will reuse the cached XLA computation.
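A toy cost model can make this scaling argument concrete. The sketch below is illustrative only: the function name and both per-operation costs are made-up assumptions, not measurements from PyTorch/XLA.

```python
def total_tracing_ms(num_layers, num_steps, cached,
                     per_layer_trace_ms=1.0, cache_lookup_ms=0.2):
    """Toy cost model; both default costs are made-up numbers."""
    one_full_trace = num_layers * per_layer_trace_ms
    if cached:
        # Trace once up front, then pay only a lookup on each later step.
        return one_full_trace + (num_steps - 1) * cache_lookup_ms
    # Without caching, every step re-traces every layer.
    return num_steps * one_full_trace


for num_layers in (10, 100):
    uncached = total_tracing_ms(num_layers, num_steps=1000, cached=False)
    cached = total_tracing_ms(num_layers, num_steps=1000, cached=True)
    print(f"{num_layers} layers: uncached {uncached:.0f} ms, "
          f"cached {cached:.1f} ms")
```

Under these assumed costs, going from 10 to 100 layers multiplies the uncached total by ten, while the cached total grows only by the larger one-time trace.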

Limitations

Currently, all operations in a function wrapped with @assume_pure must be PyTorch upstream operations (e.g. torch.einsum, torch.sin, …), or these PyTorch/XLA operations:

  • torch_xla.experimental.assume_pure (recursive assume_pure)

  • torch_xla.distributed.spmd.mark_sharding

More PyTorch/XLA operations (e.g. flash_attention) will be supported in the future.
