torch.fx.experimental#

Created On: Feb 07, 2024 | Last Updated On: Apr 03, 2026

Warning

These APIs are experimental and subject to change without notice.

class torch.fx.experimental.sym_node.DynamicInt(val)#

User API for marking dynamic integers in torch.compile. Intended to be compatible with both compile and eager mode.

Example usage:

fn = torch.compile(f)
x = DynamicInt(4)
fn(x)  # compiles x as a dynamic integer input; returns f(4)
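
A slightly fuller sketch of the same pattern (hedged; f is a hypothetical function, and the eager-mode call relies only on the compatibility claim above):

import torch
from torch.fx.experimental.sym_node import DynamicInt

def f(n):
    return torch.ones(n).sum()

fn = torch.compile(f)
n = DynamicInt(4)
fn(n)  # compiles n as a dynamic integer input; returns f(4)
f(n)   # eager mode: n behaves like the plain int 4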

torch.fx.experimental.sym_node#

torch.fx.experimental.symbolic_shapes#

ShapeEnv

DimDynamic

Controls how to perform symbol allocation for a dimension.

StrictMinMaxConstraint

For clients: the size at this dimension must be within 'vr' (which specifies a lower and upper bound, inclusive-inclusive) AND it must be non-negative and should not be 0 or 1 (but see the NB in the full docstring).

RelaxedUnspecConstraint

For clients: no explicit constraint; constraint is whatever is implicitly inferred by guards from tracing.

EqualityConstraint

Represent and decide various kinds of equality constraints between input sources.

SymbolicContext

Data structure specifying how we should create symbols in create_symbolic_sizes_strides_storage_offset; e.g., should they be static or dynamic.

StatelessSymbolicContext

Create symbols in create_symbolic_sizes_strides_storage_offset via a symbolic_context determination as given by DimDynamic and DimConstraint.
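
For illustration, a minimal sketch of allocating dynamic symbols for a tensor through a StatelessSymbolicContext; the FakeTensorMode.from_tensor call and the dynamic_sizes field reflect how this API is commonly used inside PyTorch, but treat the exact signatures as assumptions:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import (
    DimDynamic, ShapeEnv, StatelessSymbolicContext,
)

t = torch.randn(2, 3)
mode = FakeTensorMode(shape_env=ShapeEnv())
# Request a fresh dynamic symbol for every dimension of t.
ctx = StatelessSymbolicContext(dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim())
fake = mode.from_tensor(t, symbolic_context=ctx)
print(fake.shape)  # symbolic sizes such as (s0, s1) rather than (2, 3)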

StatefulSymbolicContext

Create symbols in create_symbolic_sizes_strides_storage_offset via a symbolic_context determination as given by a cache of Source:Symbol.

SubclassSymbolicContext

The correct symbolic context for a given inner tensor of a traceable tensor subclass may differ from that of the outer symbolic context.

DimConstraints

Custom solver for a system of constraints on symbolic dimensions.

ShapeEnvSettings

Encapsulates all shape env settings that could potentially affect FakeTensor dispatch.

ConvertIntKey

CallMethodKey

PropagateUnbackedSymInts

DivideByKey

InnerTensorKey

Specialization

This class is used in multi-graph compilation contexts where we generate multiple specialized graphs and dispatch to the appropriate one at runtime.

is_concrete_int

Utility to check if the underlying object in a SymInt is a concrete value.

is_concrete_bool

Utility to check if the underlying object in a SymBool is a concrete value.

is_concrete_float

Utility to check if the underlying object in a SymFloat is a concrete value.

has_free_symbols

Faster version of bool(free_symbols(val))

has_free_unbacked_symbols

Faster version of bool(free_unbacked_symbols(val))

guard_or_true

Try to evaluate a guard on a; if a data-dependent error is encountered, return True instead.

guard_or_false

Try to evaluate a guard on a; if a data-dependent error is encountered, return False instead.
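
A sketch of the intended use (maybe_pad and its shape logic are hypothetical):

from torch.fx.experimental.symbolic_shapes import guard_or_false

def maybe_pad(x):
    # x.shape[0] may be an unbacked SymInt, so evaluating the comparison
    # could raise a data-dependent error; guard_or_false then falls back
    # to False and we take the general path.
    if guard_or_false(x.shape[0] == 0):
        return x.new_zeros(1, *x.shape[1:])
    return x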

guard_size_oblivious

Perform a guard on a symbolic boolean expression in a size oblivious way.

sym_and

Like and, but for symbolic expressions, without bool casting.

sym_eq

Like ==, but when run on list/tuple, it will recursively test equality and use sym_and to join the results together, without guarding.

sym_or

Like or, but for symbolic expressions, without bool casting.
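
A small illustration (hedged): on plain Python values these behave like their builtin counterparts, while on SymInts the combined expression is built without a bool cast and hence without a guard.

from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq, sym_or

a, b = 3, 3
sym_eq((a, b), (b, a))                  # True, tested element-wise
sym_and(a == b, sym_or(a > 0, b > 0))   # True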

constrain_range

Applies a constraint that the passed in SymInt must lie between min-max inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning that it can be used on unbacked SymInts).
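
A sketch of the typical pattern, where an unbacked SymInt arises from a data-dependent .item() call under tracing (f and idx are hypothetical):

import torch
from torch.fx.experimental.symbolic_shapes import constrain_range

def f(idx):
    n = idx.sum().item()                  # unbacked SymInt under tracing
    constrain_range(n, min=0, max=1024)   # narrows the range; no guard added
    return torch.zeros(n)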

constrain_unify

Given two SymInts, constrain them so that they must be equal.

canonicalize_bool_expr

Canonicalize a boolean expression by transforming it into a lt / le inequality and moving all the non-constant terms to the rhs.

statically_known_true

Returns True if x can be simplified to a constant and is True.

statically_known_false

Returns True if x can be simplified to a constant and is False.
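
A sketch of how this differs from a plain guard (x is a hypothetical tensor whose last dimension may be symbolic):

from torch.fx.experimental.symbolic_shapes import statically_known_true

def maybe_specialize(x):
    # True only if the expression simplifies to a constant without adding
    # a guard; otherwise False, even when guarding would have returned True.
    if statically_known_true(x.shape[-1] == 4):
        return x.view(-1, 4)  # safe to specialize without guarding
    return x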

has_static_value

User-code friendly utility to check if a value is static or dynamic.

lru_cache

check_consistent

Test that two "meta" values (typically either Tensor or SymInt) have the same values, e.g., after retracing.

compute_unbacked_bindings

After having run fake tensor propagation and producing example_value result, traverse example_value looking for freshly bound unbacked symbols and record their paths for later.

rebind_unbacked

Suppose we are retracing a pre-existing FX graph that previously had fake tensor propagation (and therefore unbacked SymInts).

resolve_unbacked_bindings

When we do fake tensor prop, we often allocate new unbacked SymInts.

is_accessor_node

Helper function to determine if a node is trying to access a symbolic integer such as size, stride, offset or item.

cast_symbool_to_symint_guardless

Converts a SymBool or bool to a SymInt or int without introducing guards.

create_contiguous

error

eval_guards

eval_is_non_overlapping_and_dense

find_symbol_binding_fx_nodes

Find all nodes in an FX graph that bind sympy Symbols.

free_symbols

Recursively collect all free symbols from a value.

free_unbacked_symbols

Like free_symbols, but filtered to only report unbacked symbols

fx_placeholder_targets

fx_placeholder_vals

guard_bool

guard_float

guard_int

guard_scalar

Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float.

guarding_hint_or_throw

Return a concrete hint for a symbolic value, for use in guarding decisions.

has_guarding_hint

Check if a symbolic value has a hint available for guarding.

has_symbolic_sizes_strides

is_nested_int

is_symbol_binding_fx_node

Check if a given FX node is a symbol binding node.

is_symbolic

optimization_hint

Return a concrete hint for a symbolic integer, for use in optimization decisions.

expect_true

log_lru_cache_stats

torch.fx.experimental.proxy_tensor#

make_fx

Given a function f, return a new function which when executed with valid arguments to f, returns an FX GraphModule representing the set of operations that were executed during the course of execution.
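
A minimal usage sketch:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.sin(x) + x

gm = make_fx(f)(torch.randn(3))  # trace f by executing it on an example input
print(gm.code)                   # the ATen-level ops recorded during execution
gm(torch.randn(3))               # the resulting GraphModule is itself callable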

handle_sym_dispatch

Call into the currently active proxy tracing mode to do a SymInt/SymFloat/SymBool dispatch trace on a function that operates on these arguments.

get_proxy_mode

Return the currently active proxy tracing mode, or None if we are not currently tracing.

maybe_enable_thunkify

Within this context manager, if you are doing make_fx tracing, we will thunkify all SymNode compute and avoid tracing it into the graph unless it is actually needed.

maybe_disable_thunkify

Within a context, disable thunkification.

thunkify

Delays computation of f until it is called again; also caches the result.

track_tensor

track_tensor_tree

decompose

disable_autocast_cache

disable_proxy_modes_tracing

dispatch_trace

extract_val

fake_signature

FX gets confused by varargs; this de-confuses it.

fetch_object_proxy

fetch_sym_proxy

has_proxy_slot

is_sym_node

maybe_handle_decomp

proxy_call

set_meta

set_original_aten_op

set_proxy_slot

snapshot_fake

torch.fx.experimental.optimization#

extract_subgraph

Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph.

matches_module_pattern

modules_to_mkldnn

For each node, if it's a module that can be preconverted into MKLDNN, then we do so and create a mapping to allow us to convert from the MKLDNN version of the module to the original.

optimize_for_inference

Performs a set of optimization passes to optimize a model for the purposes of inference.
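
A minimal sketch (hedged; assumes a CPU build, and that the default pass configuration applies conv/bn fusion and the MKLDNN conversion described in the neighboring entries):

import torch
import torch.nn as nn
from torch.fx.experimental.optimization import optimize_for_inference

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
opt = optimize_for_inference(model)
out = opt(torch.randn(1, 3, 32, 32))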

remove_dropout

Removes all dropout layers from the module.

replace_node_module

reset_modules

Maps each module that's been changed with modules_to_mkldnn back to its original.

use_mkl_length

This is a heuristic that can be passed into optimize_for_inference; it determines whether a subgraph should be run in MKL by checking if there are more than 2 nodes in it.

torch.fx.experimental.recording#

torch.fx.experimental.unification.core#

reify

Replace variables of expression with substitution

>>> x, y = var(), var()
>>> e = (1, x, (3, y))
>>> s = {x: 2, y: 4}
>>> reify(e, s)
(1, 2, (3, 4))
>>> e = {1: x, 3: (y, 5)}
>>> reify(e, s)
{1: 2, 3: (4, 5)}

torch.fx.experimental.unification.multipledispatch.utils#

typename

Get the name of a type.

expand_tuples

>>> expand_tuples([1, (2, 3)])
[(1, 2), (1, 3)]

groupby

Group a collection by a key function

>>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"]
>>> groupby(len, names)  # doctest: +SKIP
{3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
>>> iseven = lambda x: x % 2 == 0
>>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8])  # doctest: +SKIP
{False: [1, 3, 5, 7], True: [2, 4, 6, 8]}

raises

reverse_dict

Reverses the direction of a dependence dict.

torch.fx.experimental.unification.unification_tools#

assoc

Return a new dict with a new key/value pair.

assoc_in

Return a new dict with a new, potentially nested, key/value pair.

dissoc

Return a new dict with the given key(s) removed.

first

The first element in a sequence

groupby

Group a collection by a key function

keyfilter

Filter items in dictionary by key

keymap

Apply function to keys of dictionary

merge

Merge a collection of dictionaries

merge_with

Merge dictionaries and apply function to combined values

update_in

Update value in a (potentially) nested dictionary

valfilter

Filter items in dictionary by value

valmap

Apply function to values of dictionary

itemfilter

Filter items in dictionary by item

itemmap

Apply function to items of dictionary
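
These mirror the corresponding toolz utilities; a few illustrative calls, assuming toolz-compatible signatures:

from torch.fx.experimental.unification.unification_tools import (
    assoc, keymap, merge_with, valfilter,
)

d = {"a": 1, "b": 2}
assoc(d, "c", 3)                 # {'a': 1, 'b': 2, 'c': 3}
keymap(str.upper, d)             # {'A': 1, 'B': 2}
merge_with(sum, d, {"b": 10})    # {'a': 1, 'b': 12}
valfilter(lambda v: v > 1, d)    # {'b': 2}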

torch.fx.experimental.migrate_gradual_types.transform_to_z3#

transform_algebraic_expression

Transforms an algebraic expression to z3 format.

:param expr: An expression that is either a dimension variable or an algebraic expression

transform_all_constraints

Given a trace, generates constraints and transforms them to z3 format

transform_all_constraints_trace_time

Takes a node and a graph and generates two sets of constraints.

transform_dimension

Takes a dimension variable or a number and transforms it to a tuple according to our scheme.

:param dimension: The dimension to be transformed
:param counter: variable tracking

transform_to_z3

transform_var

Transforms tensor variables to a format understood by z3.

:param tensor: Tensor variable or a tensor type potentially with variable dimensions

evaluate_conditional_with_constraints

Given an IR and a node representing a conditional, evaluate the conditional and its negation.

:param tracer_root: Tracer root for module instances
:param node: The node to be evaluated

torch.fx.experimental.migrate_gradual_types.constraint#

torch.fx.experimental.migrate_gradual_types.constraint_generator#

adaptive_inference_rule

assert_inference_rule

batchnorm_inference_rule

bmm_inference_rule

Constraints that match the input to a size-3 tensor and switch the dimensions according to the rules of batch matrix multiplication.

embedding_inference_rule

The output shape differs from the input shape in the last dimension

embedding_inference_rule_functional

eq_inference_rule

equality_inference_rule

We generate the constraint: input = output

expand_inference_rule

We generate the same constraints as for tensor addition, but we constrain the rank of this expression to equal len(n.args[1:]) so that only those cases are considered for the output.

full_inference_rule

gt_inference_rule

lt_inference_rule

masked_fill_inference_rule

Similar to addition.

neq_inference_rule

Translates to inconsistent in gradual types.

tensor_inference_rule

If the tensor is a scalar, we will skip it since we do not support scalars yet.

torch_dim_inference_rule

torch_linear_inference_rule

type_inference_rule

We generate the constraint: input = output

view_inference_rule

Similar to reshape but with an extra condition on the strides

register_inference_rule

transpose_inference_rule

Can be considered as a sequence of two index selects, so we generate constraints accordingly

torch.fx.experimental.migrate_gradual_types.constraint_transformation#

apply_padding

We consider the possibility that one input has fewer dimensions than the other, so we apply padding to the broadcasted results.

calc_last_two_dims

Generates constraints for the last two dimensions of a convolution or a maxpool output.

:param constraint: CalcConv or CalcMaxPool
:param d: The list of output dimensions

create_equality_constraints_for_broadcasting

Create equality constraints for when no broadcasting occurs.

:param e1: Input 1
:param e2: Input 2
:param e11: Broadcasted input 1
:param e12: Broadcasted input 2
:param d1: Variables that store dimensions for e1
:param d2: Variables that store dimensions for e2
:param d11: Variables that store dimensions for e11
:param d12: Variables that store dimensions for e12

is_target_div_by_dim

Generate constraints to check if the target dimensions are divisible by the input dimensions.

:param target: Target dimensions
:param dim: Input dimensions

no_broadcast_dim_with_index

:param d1: Input 1

register_transformation_rule

transform_constraint

Transforms a constraint into a simpler constraint.

transform_get_item

Generate an equality of the form t = [a1, ..., an], then generate constraints that check whether the given index is valid for this particular tensor size.

transform_get_item_tensor

When the index is a tuple, the output will be a tensor. TODO: check whether this is the case for all HF models.

transform_index_select

The constraints consider the given tensor size, check if the index is valid, and if so, generate a constraint replacing the input dimension with the required dimension.

transform_transpose

Similar to a sequence of two index-selects

valid_index

Given a list of dimensions, checks if an index is valid in the list

valid_index_tensor

If the slice instances exceed the length of the dimensions, this is a type error, so we return False.

is_dim_div_by_target

Generate constraints to check if the input dimensions are divisible by the target dimensions.

:param target: Target dimensions
:param dim: Input dimensions

torch.fx.experimental.graph_gradual_typechecker#

adaptiveavgpool2d_check

adaptiveavgpool2d_inference_rule

The input and output sizes should be the same except for the last two dimensions taken from the input, which represent width and height

all_eq

For operations where the input shape is equal to the output shape

bn2d_inference_rule

Given a BatchNorm2D instance and a node, check the following conditions:

- the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4)
- the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4')
- t is consistent with t'
- x_2 is consistent with the module's num_features
- x_2' is consistent with the module's num_features

Output type: the more precise type of t and t'

calculate_out_dimension

For calculating h_out and w_out according to the Conv2d documentation.

conv_refinement_rule

The equality constraints are between the first dimension of the input and output

conv_rule

Represents the output in terms of an algebraic expression w.r.t. the input when possible.

element_wise_eq

For element-wise operations; handles broadcasting.

expand_to_tensor_dim

Expand a type to the desired tensor dimension if possible; raise an error otherwise.

first_two_eq

For operations where the first two dimensions of the input and output shape are equal

register_algebraic_expressions_inference_rule

register_inference_rule

register_refinement_rule

transpose_inference_rule

We check that dimensions for the transpose operations are within range of the tensor type of the node

torch.fx.experimental.meta_tracer#

torch.fx.experimental.accelerator_partitioner#

check_dependency

Given a partition, check if there is a circular dependency on this partition using BFS.

combine_two_partitions

Given a list of partitions and two of its partitions, combine the two into a new partition, append it to the list of partitions, and remove the original two from the list.

reorganize_partitions

Given a list of partitions, reorganize the partition ID, parents, and children of each partition.

reset_partition_device

set_parents_and_children

Given a list of partitions, mark parents and children for each partition

torch.fx.experimental.debug#

set_trace

Sets a breakpoint in gm's generated Python code.

torch.fx.experimental.merge_matmul#

are_nodes_independent

Check if all of the given nodes are pairwise-data independent.

may_depend_on

Determine if one node depends on another in a torch.fx.Graph.

merge_matmul

A graph transformation that merges matrix multiplication operations that share the same right-hand side operand into one large matrix multiplication.
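
A sketch of the transformation (the module below is illustrative):

import torch
from torch.fx.experimental.merge_matmul import merge_matmul

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.rhs = torch.nn.Parameter(torch.randn(8, 4))

    def forward(self, x, y):
        # Two matmuls that share the same right-hand side operand...
        return torch.matmul(x, self.rhs), torch.matmul(y, self.rhs)

gm = merge_matmul(M())  # ...can be rewritten into one larger matmul plus a split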

torch.fx.experimental.unification.match#

edge

A should be checked before B. Ties are broken by tie_breaker, which defaults to hash.

match

ordering

A sane ordering of signatures to check, first to last: a topological sort of the edges given by edge and supercedes.

supercedes

A is a more specific match than B.

torch.fx.experimental.unification.more#

reify_object

Reify a Python object with a substitution.

unifiable

Register standard unify and reify operations on a class. This uses the type and __dict__ or __slots__ attributes to define the nature of the term.

unify_object

Unify two Python objects. Unifies their type and __dict__ attributes.

torch.fx.experimental.unification.multipledispatch.conflict#

ambiguities

All signature pairs such that A is ambiguous with B

ambiguous

A is consistent with B but neither is strictly more specific

consistent

It is possible for an argument list to satisfy both A and B

edge

A should be checked before B. Ties are broken by tie_breaker, which defaults to hash.

ordering

A sane ordering of signatures to check, first to last: a topological sort of the edges given by edge and supercedes.

super_signature

A signature that would break ambiguities

supercedes

A is consistent and strictly more specific than B

torch.fx.experimental.unification.multipledispatch.core#

dispatch

Dispatch function on the types of the inputs. Supports dispatch on all non-keyword arguments.
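
A minimal sketch of dispatching on argument types (behavior mirrors the multipledispatch package this module vendors):

from torch.fx.experimental.unification.multipledispatch.core import dispatch

@dispatch(int)
def f(x):
    return x + 1

@dispatch(str)
def f(x):
    return x.upper()

f(1)      # 2
f("abc")  # 'ABC'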

ismethod

Is func a method? Note that this has to work as the method is defined but before the class is defined.

torch.fx.experimental.unification.multipledispatch.dispatcher#

ambiguity_warn

Raise warning when ambiguity is detected.

halt_ordering

Deprecated interface to temporarily disable ordering.

restart_ordering

Deprecated interface to temporarily resume ordering.

source

str_signature

String representation of type signature

>>> str_signature((int, float))
'int, float'

variadic_signature_matches

variadic_signature_matches_iter

Check if a set of input types matches a variadic signature.

warning_text

The text for ambiguity warnings

torch.fx.experimental.unification.multipledispatch.variadic#

isvariadic

Check whether the type obj is variadic.

torch.fx.experimental.unification.utils#

freeze

Freeze container to hashable form

>>> freeze(1)
1
>>> freeze([1, 2])
(1, 2)
>>> freeze({1: 2})  # doctest: +SKIP
frozenset([(1, 2)])

hashable

raises

reverse_dict

Reverses the direction of a dependence dict.

transitive_get

Transitive dict.get

>>> d = {1: 2, 2: 3, 3: 4}
>>> d.get(1)
2
>>> transitive_get(1, d)
4

xfail

torch.fx.experimental.unification.variable#

var

variables

Context manager for logic variables

vars

torch.fx.experimental.unify_refinements#

check_for_type_equality

An equality check to be used in fixed-point computations.

infer_symbolic_types

Calls our symbolic inferencer twice.

infer_symbolic_types_single_pass

Calls our symbolic inferencer once.

substitute_all_types

Apply the most general unifier to all types in a graph until reaching a fixed point.

substitute_solution_one_type

Apply the most general unifier to a type

unify_eq

Apply unification to a set of equality constraints

torch.fx.experimental.validator#