Optimizing Repeated Layers with scan and scan_layers

This is a guide for using scan and scan_layers in PyTorch/XLA.
When should you use this

Consider using scan_layers if you have a model with many homogeneous (same shape, same logic) layers, for example LLMs. These models can be slow to compile. scan_layers is a drop-in replacement for a for loop over homogeneous layers, such as a stack of decoder layers. scan_layers traces the first layer and reuses the compiled result for all subsequent layers, significantly reducing the model compile time.
scan, on the other hand, is a lower-level higher-order op modeled after jax.lax.scan. Its primary purpose is to implement scan_layers under the hood. However, you may also find it useful for programming loop logic where the loop itself has a first-class representation in the compiler (specifically, the XLA while op).
scan_layers example

Typically, a transformer model passes the input embedding through a sequence of homogeneous decoder layers:

def run_decoder_layers(self, hidden_states):
    for decoder_layer in self.layers:
        hidden_states = decoder_layer(hidden_states)
    return hidden_states
When this function is lowered into an HLO graph, the for loop is unrolled into a flat list of operations, resulting in long compile times. To reduce compile times, replace the for loop with scan_layers, as shown in decoder_with_scan.py:

def run_decoder_layers(self, hidden_states):
    from torch_xla.experimental.scan_layers import scan_layers
    return scan_layers(self.layers, hidden_states)
You can train this decoder model by running the following command from the root directory of a pytorch/xla source checkout:
python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan
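If you are adapting your own model rather than the example decoder, a minimal self-contained sketch of wiring scan_layers into a module might look like the following. The module and layer choices here are illustrative only (they are not taken from decoder_with_scan.py); the key requirement is that all layers are homogeneous.

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan_layers import scan_layers

class TinyStack(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, hidden_size: int = 128, num_layers: int = 8):
        super().__init__()
        # All layers must be homogeneous: same structure and parameter shapes.
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size) for _ in range(num_layers)
        )

    def forward(self, hidden_states):
        # Only the first layer is traced; the compiled computation is
        # reused for the remaining layers.
        return scan_layers(self.layers, hidden_states)

device = xm.xla_device()
model = TinyStack().to(device)
out = model(torch.randn(4, 128, device=device))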
scan example

scan takes a combine function and applies that function over the leading dimension of tensors while carrying along state:

def scan(
    fn: Callable[[Carry, X], tuple[Carry, Y]],
    init: Carry,
    xs: X,
) -> tuple[Carry, Y]:
    ...
Use it to loop over the leading dimension of tensors efficiently. If xs is a single tensor, this function is roughly equivalent to the following Python code:
def scan(fn, init, xs):
    ys = []
    carry = init
    for i in range(xs.size(0)):
        carry, y = fn(carry, xs[i])
        ys.append(y)
    return carry, torch.stack(ys, dim=0)
Under the hood, scan is implemented efficiently by lowering the loop into an XLA while operation. This ensures that only one iteration of the loop is compiled by XLA.
scan_examples.py contains some example code showing how to use scan. In that file, scan_example_cumsum uses scan to implement a cumulative sum, and scan_example_pytree demonstrates how to pass PyTrees to scan.
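For orientation, here is a minimal sketch of a cumulative sum written with scan. It is consistent with the scan_example_cumsum output shown further down, though the variable names are illustrative rather than copied from that file.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()

def step(carry, x):
    new_carry = carry + x
    # The second return value is stacked along the leading dimension into ys.
    return new_carry, new_carry

init = torch.zeros(1, device=device)
xs = torch.tensor([[1.0], [2.0], [3.0]], device=device)
final_sum, history = scan(step, init, xs)
# final_sum is tensor([6.]); history holds the running sums [1.], [3.], [6.]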
You can run the examples with:
python3 examples/scan/scan_examples.py
The output should look something like the following:
Running example: scan_example_cumsum
Final sum: tensor([6.], device='xla:0')
History of sums tensor([[1.],
[3.],
[6.]], device='xla:0')
Running example: scan_example_pytree
Final carry: {'sum': tensor([15.], device='xla:0'), 'count': tensor([5.], device='xla:0')}
Means over time: tensor([[1.0000],
[1.5000],
[2.0000],
[2.5000],
[3.0000]], device='xla:0')
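Passing a PyTree means using the same nested structure (for example, a dict of tensors) for the carry in init, inside the combine function, and in its return value. Below is a minimal sketch consistent with the scan_example_pytree output above; it is an illustration, not the exact code from scan_examples.py.

import torch
import torch_xla.core.xla_model as xm
from torch_xla.experimental.scan import scan

device = xm.xla_device()

def step(carry, x):
    # The carry is a dict of tensors; its structure must stay the same
    # across iterations.
    new_carry = {'sum': carry['sum'] + x, 'count': carry['count'] + 1}
    mean = new_carry['sum'] / new_carry['count']
    return new_carry, mean

init = {'sum': torch.zeros(1, device=device),
        'count': torch.zeros(1, device=device)}
xs = torch.arange(1.0, 6.0, device=device).unsqueeze(-1)
final_carry, means = scan(step, init, xs)
# final_carry['sum'] is tensor([15.]); means runs from 1.0 to 3.0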
Use the scan cache

Because scan uses AOTAutograd to figure out the backward pass of the input function/module repeatedly on every iteration, it is easy to become tracing-bound compared to a for loop implementation. With the scan cache, you can indicate that the provided function (or layers, if using scan_layers) is pure in order to activate the caching mechanism. The cache significantly reduces the tracing overhead, especially when the number of iterations is large. The following command trains an example decoder; it finishes in 9m35s without the cache and 4m59s with the cache.
time python3 examples/train_decoder_only_base.py scan.decoder_with_scan.DecoderWithScan --hidden-size 128 --num-layers 2 --num-attention-heads 8 --num-key-value-heads 4 --intermediate-size 512 --num-steps 500 --print-metrics --is-decoder-layer-pure
To enable the cache, simply set is_fn_pure (or is_layer_pure if using scan_layers) to True. For example:

final_carry, ys = scan(fn, init_scan, xs_scan, is_fn_pure=True)
scan_layers(layers, input_data, is_layer_pure=True)
Limitations

AOTAutograd compatibility requirement

The functions/modules passed to scan and scan_layers must be AOTAutograd traceable. In particular, as of PyTorch/XLA 2.6, scan and scan_layers cannot trace functions with custom Pallas kernels. That means if your decoder uses, for example, flash attention, then it is incompatible with scan. We are working on supporting this important use case.
Compile time experiments

To demonstrate the compile time savings, we'll train a simple decoder with many layers on a single TPU chip, comparing the for loop implementation with scan_layers.
Run the for loop implementation:
❯ python3 examples/train_decoder_only_base.py \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 02m57s694ms418.595us
ValueRate: 02s112ms586.097us / second
Rate: 0.054285 / second
Percentiles: 1%=023ms113.470us; 5%=023ms113.470us; 10%=023ms113.470us; 20%=023ms113.470us; 50%=54s644ms733.284us; 80%=01m03s028ms571.841us; 90%=01m03s028ms571.841us; 95%=01m03s028ms571.841us;
99%=01m03s028ms571.841us
Run the scan_layers implementation:
❯ python3 examples/train_decoder_only_base.py \
scan.decoder_with_scan.DecoderWithScan \
--hidden-size 256 \
--num-layers 50 \
--num-attention-heads 4 \
--num-key-value-heads 2 \
--intermediate-size 2048 \
--num-steps 5 \
--print-metrics
...
Metric: CompileTime
TotalSamples: 3
Accumulator: 29s996ms941.409us
ValueRate: 02s529ms591.388us / second
Rate: 0.158152 / second
Percentiles: 1%=018ms636.571us; 5%=018ms636.571us; 10%=018ms636.571us; 20%=018ms636.571us; 50%=11s983ms003.171us; 80%=18s995ms301.667us; 90%=18s995ms301.667us; 95%=18s995ms301.667us;
99%=18s995ms301.667us
The maximum compile time dropped from 1m03s to 19s by switching to scan_layers.
References

See https://github.com/pytorch/xla/issues/7253 for the design of scan and scan_layers.

See the function doc comments of scan and scan_layers for details on how to use them.