Control Flow - Map#
Created On: Feb 14, 2026 | Last Updated On: Feb 14, 2026
torch.map is a structured control flow operator that applies a function over the leading dimension of
input tensors. It can logically be seen as implemented as follows:
def map(
    f: Callable[[PyTree, ...], PyTree],
    xs: Union[PyTree, torch.Tensor],
    *args,
):
    out = []
    for idx in range(xs.size(0)):
        xs_sliced = xs.select(0, idx)
        out.append(f(xs_sliced, *args))
    return torch.stack(out)
Warning
torch._higher_order_ops.map is a prototype feature in PyTorch. You may run into miscompiles.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Examples#
Below is an example that uses map to apply a function over a batch:
import torch
from torch._higher_order_ops import map

def f(x):
    return x.sin() + x.cos()

xs = torch.randn(3, 4, 5)  # batch of 3 tensors, each 4x5

# Applies f to each of the 3 slices
result = map(f, xs)  # returns tensor of shape [3, 4, 5]
print(result)
tensor([[[ 0.0980, 0.2002, 1.0139, 1.2563, -1.2498],
         [ 1.2818, 0.1662, 0.1792, 1.1624, 0.5724],
         [ 0.7773, -1.1511, 1.3047, 1.4084, 0.9477],
         [ 1.2964, -0.9956, 0.4071, 1.1873, 1.3374]],

        [[ 0.9802, 1.2030, -0.7410, 1.3835, 0.8395],
         [-1.2970, 0.0224, 1.1744, -0.1499, 0.2326],
         [ 0.6750, 0.6006, 0.4395, 0.6112, 1.4117],
         [-0.4250, 1.1176, -1.3966, 0.6573, -1.2978]],

        [[ 0.3239, -0.0327, 1.2138, -0.1304, 1.0815],
         [-0.8843, 1.3024, 0.1844, 0.7046, 1.0315],
         [ 0.9116, -0.8725, 1.4142, 0.9328, 1.2899],
         [ 0.4364, -0.9016, -1.2707, 0.6944, 1.3802]]])
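As a sanity check on the semantics, the same result can be reproduced with an explicit Python loop over the leading dimension. This minimal sketch reuses f, xs, and result from the example above:

# Equivalent eager-mode loop: slice along dim 0, apply f, and stack the results.
expected = torch.stack([f(xs[i]) for i in range(xs.size(0))])
torch.testing.assert_close(result, expected)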
We can export a model that uses map for further transformations and deployment. This example uses dynamic shapes to allow a variable batch size:
class MapModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        def body_fn(x):
            return x.sin() + x.cos()

        return map(body_fn, xs)

mod = MapModule()
inp = torch.randn(3, 4)
ep = torch.export.export(mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}})
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[s83, 4]"):
            # File: <eval_with_key>.8:7 in forward, code: map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_flat_xs_0_], []); map_body_0 = l_flat_xs_0_ = None
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], []); body_graph_0 = xs = None
            getitem: "f32[s83, 4]" = map_impl[0]; map_impl = None
            return (getitem,)

        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]"):
                # File: <eval_with_key>.9:5 in forward, code: sin = child.sin()
                sin: "f32[4]" = torch.ops.aten.sin.default(xs)

                # File: <eval_with_key>.9:6 in forward, code: cos = child.cos(); child = None
                cos: "f32[4]" = torch.ops.aten.cos.default(xs); xs = None

                # File: <eval_with_key>.9:7 in forward, code: add = sin + cos; sin = cos = None
                add: "f32[4]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
                return (add,)

Graph signature:
    # inputs
    xs: USER_INPUT

    # outputs
    getitem: USER_OUTPUT

Range constraints: {s83: VR[2, int_oo]}
Notice that torch.map is lowered to torch.ops.higher_order.map_impl, and the body function becomes a
sub-graph attribute of the top-level graph module.
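Because the leading dimension was marked dynamic, the exported program also accepts other batch sizes. A minimal check, reusing the ep produced above (the batch size 7 is arbitrary):

out = ep.module()(torch.randn(7, 4))
print(out.shape)  # torch.Size([7, 4])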
Restrictions#
- Mapped xs can only consist of tensors.
- Leading dimensions of all tensors in xs must be consistent and non-zero.
- The body function must not mutate inputs.
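The following sketch shows a call that satisfies these restrictions: xs is a list of tensors sharing the same non-zero leading dimension, an extra bias tensor is passed through *args, and the body returns a new tensor instead of mutating its inputs. The names body and bias are illustrative only:

import torch
from torch._higher_order_ops import map

def body(x, bias):
    # x is one slice of the mapped pytree: here a list of two [4]-shaped tensors
    return x[0] * x[1] + bias

# Both mapped tensors share the same non-zero leading dimension (3).
xs = [torch.randn(3, 4), torch.randn(3, 4)]
bias = torch.randn(4)  # passed unchanged to every step via *args

out = map(body, xs, bias)  # shape [3, 4]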
API Reference#
- torch._higher_order_ops.map.map(f, xs, *args)[source]#
Performs a map of f with xs. Intuitively, you can think of the semantics as:

out = []
for idx in range(xs.size(0)):
    xs_sliced = xs.select(0, idx)
    out.append(f(xs_sliced, *args))
return torch.stack(out)
Warning
torch._higher_order_ops.map is a prototype feature in PyTorch. It currently does not support autograd and you may run into miscompiles. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

- Parameters:
f (Callable) – a callable that takes an input x, which can be either a single tensor or a pytree of tensors (e.g. a nested dict or list of tensors), plus any additional arguments.
xs (Any | Tensor) – the inputs to be mapped over. We iterate over the first dimension of each tensor and apply f to each slice.
*args (TypeVarTuple) – additional arguments passed to every step of f. They can also be omitted, in which case map automatically figures out the read dependencies.
- Returns:
the stacked output for each step of f
Example:
def f(xs):
    return xs[0] + xs[1] + const1 + const2

xs = [torch.randn(2, 3), torch.randn(2, 3)]
const1 = torch.randn(2, 3)
const2 = torch.randn(2, 3)

# returns a tensor of shape [2, 2, 3]
torch._higher_order_ops.map(f, xs)