
Control Flow - Map

Created On: Feb 14, 2026 | Last Updated On: Feb 14, 2026

torch.map is a structured control flow operator that applies a function over the leading dimension of input tensors. For a single input tensor, it behaves logically like the following loop:

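from typing import Any, Callable

import torch

def map(
    f: Callable[..., torch.Tensor],
    xs: torch.Tensor,
    *args: Any,
) -> torch.Tensor:
    out = []
    # Iterate over the leading dimension of xs.
    for idx in range(xs.size(0)):
        # select(0, idx) takes slice idx and drops dim 0.
        xs_sliced = xs.select(0, idx)
        # The extra *args are passed unchanged to every step.
        out.append(f(xs_sliced, *args))
    # Stack the per-step results along a new leading dimension.
    return torch.stack(out)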
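The loop above handles a single tensor. Because xs may also be a pytree (a nested dict or list of tensors), the sketch below extends the same semantics: it flattens xs with the private torch.utils._pytree helpers, slices every leaf at the same index, calls f on the rebuilt structure, and stacks each output leaf. map_reference is a hypothetical name; this illustrates the documented semantics under those assumptions and is not PyTorch's implementation.

```python
from typing import Any, Callable

import torch
import torch.utils._pytree as pytree


def map_reference(f: Callable[..., Any], xs: Any, *args: Any) -> Any:
    """Hypothetical eager reference for map's semantics on a pytree of tensors."""
    flat_xs, xs_spec = pytree.tree_flatten(xs)
    n = flat_xs[0].size(0)
    step_outputs = []
    for idx in range(n):
        # Slice every mapped tensor at the same index, then rebuild the structure.
        sliced = pytree.tree_unflatten([x.select(0, idx) for x in flat_xs], xs_spec)
        step_outputs.append(f(sliced, *args))
    # Flatten each step's output and stack matching leaves along a new dim 0.
    flat_steps = [pytree.tree_flatten(out)[0] for out in step_outputs]
    out_spec = pytree.tree_flatten(step_outputs[0])[1]
    stacked = [torch.stack(leaves) for leaves in zip(*flat_steps)]
    return pytree.tree_unflatten(stacked, out_spec)


xs = {"a": torch.randn(3, 2), "b": torch.randn(3, 2)}
scale = torch.tensor(2.0)
out = map_reference(lambda d, s: (d["a"] * s, d["a"] + d["b"]), xs, scale)
print(out[0].shape, out[1].shape)  # torch.Size([3, 2]) torch.Size([3, 2])
```

Each step here returns a tuple, so the result is a tuple whose leaves were stacked independently.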

Warning

torch._higher_order_ops.map is a prototype feature in PyTorch. You may run into miscompiles. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Examples

Below is an example that uses map to apply a function over a batch:

import torch
from torch._higher_order_ops import map

def f(x):
    return x.sin() + x.cos()

xs = torch.randn(3, 4, 5)  # batch of 3 tensors, each 4x5
# Applies f to each of the 3 slices
result = map(f, xs)  # returns tensor of shape [3, 4, 5]
print(result)
tensor([[[-0.5316,  0.6268,  0.8421,  1.0386,  0.4735],
         [ 1.3890, -0.5484,  1.0015,  1.1960,  1.3527],
         [ 0.9428,  1.4092,  0.6069, -1.1853, -0.6351],
         [ 0.6303,  0.7859,  1.4136,  1.2687,  1.3567]],

        [[ 0.4728,  1.4138, -1.3743,  1.3723,  1.2608],
         [-1.0398,  1.2727,  1.2964, -0.8071,  1.3535],
         [ 1.2383,  1.2231,  1.1476,  0.6243,  0.2824],
         [ 0.3139,  0.5240,  1.4049,  1.4076, -0.7916]],

        [[ 1.0722, -0.2128,  1.4060, -1.1296,  1.1801],
         [-0.3149,  1.4080,  1.4141,  1.4051, -0.5227],
         [ 0.7779,  1.3482,  1.3971,  0.6301, -1.2667],
         [ 1.3709,  1.2926,  0.8846, -0.6284,  0.0842]]])

We can export a model that uses map for further transformation and deployment. This example marks the batch dimension as dynamic so that the exported program accepts a variable batch size:

class MapModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        def body_fn(x):
            return x.sin() + x.cos()

        return map(body_fn, xs)

mod = MapModule()
inp = torch.randn(3, 4)
ep = torch.export.export(mod, (inp,), dynamic_shapes={"xs": {0: torch.export.Dim.DYNAMIC}})
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, xs: "f32[s83, 4]"):
            # File: <eval_with_key>.8:7 in forward, code: map_impl = torch.ops.higher_order.map_impl(map_body_0, [l_flat_xs_0_], []);  map_body_0 = l_flat_xs_0_ = None
            body_graph_0 = self.body_graph_0
            map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], []);  body_graph_0 = xs = None
            getitem: "f32[s83, 4]" = map_impl[0];  map_impl = None
            return (getitem,)
    
        class body_graph_0(torch.nn.Module):
            def forward(self, xs: "f32[4]"):
                # File: <eval_with_key>.9:5 in forward, code: sin = child.sin()
                sin: "f32[4]" = torch.ops.aten.sin.default(xs)
    
                # File: <eval_with_key>.9:6 in forward, code: cos = child.cos();  child = None
                cos: "f32[4]" = torch.ops.aten.cos.default(xs);  xs = None
    
                # File: <eval_with_key>.9:7 in forward, code: add = sin + cos;  sin = cos = None
                add: "f32[4]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
                return (add,)
    
Graph signature: 
    # inputs
    xs: USER_INPUT
    
    # outputs
    getitem: USER_OUTPUT
    
Range constraints: {s83: VR[2, int_oo]}

Notice that torch.map is lowered to torch.ops.higher_order.map_impl, and the body function becomes a subgraph (body_graph_0) stored as an attribute of the top-level graph module.

Restrictions

  • Mapped xs can only consist of tensors.

  • All tensors in xs must share the same leading dimension size, and that size must be non-zero.

  • The body function must not mutate inputs.

API Reference

torch._higher_order_ops.map.map(f, xs, *args)

Performs a map of f over xs. Intuitively, the semantics are:

out = []
for idx in range(xs.size(0)):
    xs_sliced = xs.select(0, idx)
    out.append(f(xs_sliced, *args))
torch.stack(out)

Warning

torch._higher_order_ops.map is a prototype feature in PyTorch. It currently does not support autograd and you may run into miscompiles. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype

Parameters:
  • f (Callable) – a callable that takes an input x, which is either a single Tensor or a nested dict or list of tensors, followed by any additional inputs.

  • xs (Any | Tensor) – the inputs to be mapped over. map iterates over the first dimension of each tensor in xs and applies f to each slice.

  • *args (TypeVarTuple) – additional arguments passed to every step of f. They can also be omitted: map automatically figures out which tensors f reads (its read dependencies), such as tensors captured from an enclosing scope.

Returns:

the outputs of every step of f, stacked along a new leading dimension

Example:

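import torch

def f(x):
    # x is a list holding one slice (shape [3]) of each tensor in xs;
    # const1 and const2 are read from the enclosing scope, not passed via *args.
    return x[0] + x[1] + const1 + const2


xs = [torch.randn(2, 3), torch.randn(2, 3)]
const1 = torch.randn(2, 3)
const2 = torch.randn(2, 3)
# Each step returns shape [2, 3] (the [3] slices broadcast against const1),
# so the stacked result has shape [2, 2, 3]
torch._higher_order_ops.map(f, xs)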