Control Flow - While Loop#
Created On: Feb 14, 2026 | Last Updated On: Feb 14, 2026
torch.while_loop is a structured control flow operator that runs a body function while a condition is true.
Logically, it behaves as if implemented as follows:
def while_loop(
    cond_fn: Callable[..., bool],
    body_fn: Callable[..., tuple],
    carried_inputs: tuple,
):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val
Warning
torch.while_loop is a prototype feature in PyTorch. It has limited support for input and output types.
Please look forward to a more stable implementation in a future version of PyTorch.
Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
Examples#
Below is a basic example that uses while_loop to iterate until a condition is met:
import torch
from torch._higher_order_ops import while_loop
class M(torch.nn.Module):
    def cond_fn(self, iter_count, x):
        return iter_count.sum() > 0

    def body_fn(self, iter_count, x):
        return iter_count - 1, x * 2

    def forward(self, init_iter, init_x):
        final_iter, final_x = while_loop(self.cond_fn, self.body_fn, (init_iter, init_x))
        return final_iter, final_x

m = M()
We can eagerly run the model and expect the results to vary based on the input values:
_, final_x = m(torch.tensor([3]), torch.ones(3))
assert torch.equal(final_x, torch.ones(3) * 2**3)
_, final_x = m(torch.tensor([10]), torch.ones(3))
assert torch.equal(final_x, torch.ones(3) * 2**10)
We can export the model for further transformations and deployment. This gives us an exported program that preserves the while_loop structure:
ep = torch.export.export(M(), (torch.tensor([10]), torch.ones(3)))
print(ep)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, init_iter: "i64[1]", init_x: "f32[3]"):
            # File: <eval_with_key>.13:9 in forward, code: while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_args_2_0_, l_args_2_1_), ()); cond_fn_0 = body_fn_0 = l_args_2_0_ = l_args_2_1_ = None
            while_loop_cond_graph_0 = self.while_loop_cond_graph_0
            while_loop_body_graph_0 = self.while_loop_body_graph_0
            while_loop = torch.ops.higher_order.while_loop(while_loop_cond_graph_0, while_loop_body_graph_0, (init_iter, init_x), ());  while_loop_cond_graph_0 = while_loop_body_graph_0 = init_iter = init_x = None
            getitem: "i64[1]" = while_loop[0]
            getitem_1: "f32[3]" = while_loop[1];  while_loop = None
            return (getitem, getitem_1)

        class while_loop_cond_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "i64[1]", arg1_1: "f32[3]"):
                # File: <eval_with_key>.14:5 in forward, code: sum_1 = child.sum(); child = None
                sum_1: "i64[]" = torch.ops.aten.sum.default(arg0_1);  arg0_1 = None
                # File: <eval_with_key>.14:6 in forward, code: gt = sum_1 > 0; sum_1 = None
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0);  sum_1 = None
                return gt

        class while_loop_body_graph_0(torch.nn.Module):
            def forward(self, arg0_1: "i64[1]", arg1_1: "f32[3]"):
                # File: <eval_with_key>.15:5 in forward, code: child = child_2 - 1; child_2 = None
                sub: "i64[1]" = torch.ops.aten.sub.Tensor(arg0_1, 1);  arg0_1 = None
                # File: <eval_with_key>.15:6 in forward, code: child_4 = child_3 * 2; child_3 = None
                mul: "f32[3]" = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
                return (sub, mul)

Graph signature:
    # inputs
    init_iter: USER_INPUT
    init_x: USER_INPUT

    # outputs
    getitem: USER_OUTPUT
    getitem_1: USER_OUTPUT

Range constraints: {}
Notice that both the condition and body functions become sub-graph attributes of the top-level graph module.
Restrictions#
- body_fn must return tensors or integers with the same metadata (shape, dtype) as the inputs.
- body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before mutation is required.
- body_fn and cond_fn must not mutate Python variables (e.g., list/dict) created outside the function.
- The outputs of body_fn and cond_fn cannot alias any of the inputs. A clone is required.
API Reference#
- torch._higher_order_ops.while_loop.while_loop(cond_fn, body_fn, carried_inputs)[source]#
Run body_fn(*carried_inputs) while cond_fn(*carried_inputs) returns a True scalar tensor. Returns the output of body_fn, or the initial carried_inputs if the loop never runs.

Warning
torch.while_loop is a prototype feature in PyTorch. It has limited support for input and output types and doesn’t support training currently. Please look forward to a more stable implementation in a future version of PyTorch. Read more about feature classification at: https://pytorch.org/blog/pytorch-feature-classification-changes/#prototype
while_loop is a structured control flow operator. It preserves the loop semantics across torch.compile and torch.export.
while_loop is equivalent to the following:
def while_loop(cond_fn, body_fn, carried_inputs):
    val = carried_inputs
    while cond_fn(*val):
        val = body_fn(*val)
    return val
- Parameters:
cond_fn (Callable) – A callable function that returns a boolean scalar tensor or a Python boolean.
body_fn (Callable) – A callable function that takes the same inputs as cond_fn and returns a tuple of tensors or ints.
carried_inputs (Tuple of possibly nested dict/list/tuple of tensors or ints) – A tuple of inputs to cond_fn and body_fn. It is also the initial value of the state that is carried across iterations. Note that when an integer is passed as a carry, the corresponding output of while_loop will be an int with an unknown value, because the number of iterations while_loop will run is not known statically.
Example 1:
def cond_fn(iter, x):
    return iter.sum() < 10

def body_fn(iter, x):
    return iter + 1, x.sin()

while_loop(cond_fn, body_fn, (torch.zeros(1), torch.randn(3, 4)))
Example 2:
def cond_fn(int_iter, x):
    return 2 * int_iter < x.shape[0]

def body_fn(int_iter, x):
    return int_iter + 1, x + int_iter

while_loop(cond_fn, body_fn, (0, torch.randn(3, 4)))
Restrictions:
body_fn must return tensors or ints with the same metadata (e.g., shape, dtype) as the inputs.
body_fn and cond_fn must not in-place mutate the carried_inputs. A clone before the mutation is required.
body_fn and cond_fn must not mutate Python variables (e.g., list/dict) created outside of the function.
body_fn and cond_fn’s output cannot alias any of the inputs. A clone is required.
Warning
Temporary limitations:
while_loop only supports inference right now. Autograd will be supported in the future.