Control Flow - Scan#
Created On: Feb 14, 2026 | Last Updated On: Feb 14, 2026
torch.scan is a structured control flow operator that performs an inclusive scan with a combine function.
It is commonly used for cumulative operations like cumsum, cumprod, or more general recurrences.
Logically, it can be seen as implemented as follows:

def scan(
    combine_fn: Callable[[PyTree, PyTree], tuple[PyTree, PyTree]],
    init: PyTree,
    xs: PyTree,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[PyTree, PyTree]:
    # reverse=True iterates over dim from the last slice to the first;
    # the forward case is shown here.
    carry = init
    ys = []
    for i in range(xs.size(dim)):
        x_slice = xs.select(dim, i)
        carry, y = combine_fn(carry, x_slice)
        ys.append(y)
    return carry, torch.stack(ys)
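As a sanity check, the loop above can be transcribed directly into eager Python and compared against torch.cumsum for an additive combine function. This is a sketch of the semantics, not the actual implementation; reference_scan is a name introduced here:

```python
import torch

def reference_scan(combine_fn, init, xs, dim=0):
    # Direct transcription of the loop above (reverse handling omitted).
    carry = init
    ys = []
    for i in range(xs.size(dim)):
        carry, y = combine_fn(carry, xs.select(dim, i))
        ys.append(y)
    return carry, torch.stack(ys)

xs = torch.arange(5, dtype=torch.float32)
final, ys = reference_scan(lambda c, x: (c + x, c + x), torch.zeros(()), xs)
print(final)  # tensor(10.)
# For an additive combine_fn, the stacked ys are exactly the inclusive cumsum.
torch.testing.assert_close(ys, torch.cumsum(xs, dim=0))
```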
Warning
torch.scan is a prototype feature in PyTorch. You may run into miscompiles.
Read more about feature classification at:
https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Examples#
Below is an example that uses scan to compute a cumulative sum:
import torch
from torch._higher_order_ops import scan

def add(carry: torch.Tensor, x: torch.Tensor):
    next_carry = carry + x
    y = next_carry.clone()  # clone to avoid output-output aliasing
    return next_carry, y

init = torch.zeros(1)
xs = torch.arange(5, dtype=torch.float32)

final_carry, cumsum = scan(add, init=init, xs=xs)
print(final_carry)
print(cumsum)
tensor([10.])
tensor([[ 0.],
[ 1.],
[ 3.],
[ 6.],
[10.]])
We can export a model that uses scan for further transformations and deployment. This example uses dynamic shapes to allow a variable sequence length:

class ScanModule(torch.nn.Module):
    def forward(self, xs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        def combine_fn(carry, x):
            next_carry = carry + x
            return next_carry, next_carry.clone()

        init = torch.zeros_like(xs[0])
        return scan(combine_fn, init=init, xs=xs)

mod = ScanModule()
inp = torch.randn(5, 3)
ep = torch.export.export(
    mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}}
)
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[s83, 3]"):
            # File: /tmp/ipykernel_316/1825554836.py:7 in forward, code: init = torch.zeros_like(xs[0])
            select: "f32[3]" = torch.ops.aten.select.int(xs, 0, 0)
            zeros_like: "f32[3]" = torch.ops.aten.zeros_like.default(select, pin_memory = False); select = None

            # File: <eval_with_key>.8:8 in forward, code: scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_leaves_init_0_], [l_leaves_xs_0_], []); scan_combine_fn_0 = l_leaves_init_0_ = l_leaves_xs_0_ = None
            scan_combine_graph_0 = self.scan_combine_graph_0
            scan = torch.ops.higher_order.scan(scan_combine_graph_0, [zeros_like], [xs], ()); scan_combine_graph_0 = zeros_like = xs = None
            getitem: "f32[3]" = scan[0]
            getitem_1: "f32[s83, 3]" = scan[1]; scan = None
            return (getitem, getitem_1)

        class scan_combine_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "f32[3]", arg1_1: "f32[3]"):
                # File: <eval_with_key>.9:5 in forward, code: next_carry = child + child_1; child = child_1 = None
                add: "f32[3]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None

                # File: <eval_with_key>.9:6 in forward, code: child_2 = next_carry.clone()
                clone: "f32[3]" = torch.ops.aten.clone.default(add)
                return [add, clone]

Graph signature:
    # inputs
    xs: USER_INPUT

    # outputs
    getitem: USER_OUTPUT
    getitem_1: USER_OUTPUT

Range constraints: {s83: VR[2, int_oo]}
Notice that the combine function becomes a sub-graph attribute of the top-level graph module.
Restrictions#
- combine_fn must return tensors with the same metadata (shape, dtype) for next_carry as init.
- combine_fn must not mutate its inputs in place. A clone before mutation is required.
- combine_fn must not mutate Python variables (e.g., a list or dict) created outside the function.
- combine_fn's outputs cannot alias any of its inputs. A clone is required.
API Reference#
- torch._higher_order_ops.scan.scan(combine_fn, init, xs, *, dim=0, reverse=False)[source]#
Performs an inclusive scan with a combine function.
Warning
torch.scan is a prototype feature in PyTorch. You may run into miscompiles. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
- Parameters:
  combine_fn (Callable) – A binary callable with type (Tensor, Tensor) -> (Tensor, Tensor), or, if xs is a pytree, (pytree, pytree) -> (pytree, pytree). The first input to combine_fn is the previous or initial scan carry, and the second input is a slice of the input along dim. The first output of combine_fn is the next scan carry, and the second output represents a slice of the output. This function must be pure, i.e., no lifted arguments are supported at the moment and it may not have any side effects.
  init (torch.Tensor or pytree with tensor leaves) – The initial scan carry: a tensor, or a nested pytree of tensors. init is expected to have the same pytree structure as the first output element (i.e. the carry) of combine_fn.
  xs (torch.Tensor or pytree with tensor leaves) – The input tensor, or a nested pytree of tensors.
- Kwargs:
  dim (int): the dimension to scan over; default 0.
  reverse (bool): whether the scan should be reversed with respect to dim; default False.
- Returns:
- final_carry (torch.Tensor or pytree with tensor leaves),
the final carry of the scan operation with same pytree structure as init.
- out (torch.Tensor or pytree with tensor leaves),
each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration.
- Return type:
  tuple[PyTree, PyTree]
- Restrictions:
- The combine_fn shouldn’t have any aliasing between input-input, input-output, or output-output. E.g., returning a view
of, or the same tensor as, an input is not supported. As a workaround, clone the output to avoid aliasing.
- The combine_fn shouldn’t mutate any inputs. We’ll remove the mutation restriction for inference soon. Please file an issue
if input mutation support for training is needed.
- The combine_fn’s init carry should match the next_carry in pytree structure and in tensor metadata.
Example:
def add(x: torch.Tensor, y: torch.Tensor):
    next_carry = x + y
    # clone the output to avoid output-output aliasing
    return next_carry, next_carry.clone()

i0 = torch.zeros(1)
xs = torch.arange(5)

# returns torch.tensor([10.]), torch.tensor([[0.], [1.], [3.], [6.], [10.]])
last_carry, cumsum = scan(add, init=i0, xs=xs)