torch.onnx.ops#
ONNX operators as native torch.fx operators.
This module provides a set of functions to create ONNX operators in the FX graph which are exportable to ONNX.
Symbolic Operators#
Operators that can be used to create any ONNX op in the FX graph symbolically.
These operators do not perform actual computation. It is recommended that you use them
inside an if torch.onnx.is_in_onnx_export() block.
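For instance, a minimal sketch of this guard pattern (the module, the domain::op name, and the eager fallback below are illustrative assumptions, not part of the API):

import torch

class GuardedOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.onnx.is_in_onnx_export():
            # During torch.onnx.export(..., dynamo=True), emit a symbolic ONNX node
            # in place of the eager computation.
            return torch.onnx.ops.symbolic(
                "custom_domain::MyOp",  # illustrative domain::op name
                (x,),
                dtype=x.dtype,
                shape=x.shape,
                version=1,
            )
        # Outside of export, run an ordinary eager implementation
        # (a relu here, standing in for the real computation).
        return torch.nn.functional.relu(x)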
- torch.onnx.ops.symbolic(domain_op, /, inputs, attrs=None, *, dtype, shape, version=None, metadata_props=None)[source]#
Create a symbolic FX operator to represent an arbitrary ONNX operator.
This function is used to create a symbolic operator with a single output. To create an operator with multiple outputs, use symbolic_multi_out().
You may use if torch.onnx.is_in_onnx_export() to conditionally enable the symbolic logic only during torch.onnx.export().
Example:
class CustomOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normal torch operators can interleave with the symbolic ops during ONNX export
        x = x + 1

        # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
        # The output tensor will have the specified dtype and shape
        val = torch.onnx.ops.symbolic(
            "custom_domain::CustomOp",
            (x,),
            dict(attr_key="attr_value"),
            dtype=x.dtype,
            shape=x.shape,
            version=1,
        )

        # The result of the symbolic op can be used in normal torch operations during ONNX export
        return torch.nn.functional.relu(val)

# You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).
- Parameters
domain_op (str) – The domain and operator name, separated by “::”. For example, “custom_domain::CustomOp”.
inputs (Sequence[torch.Tensor | None]) – The input tensors to the operator.
attrs (dict[str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool]] | None) – The attributes of the operator. The keys are attribute names and the values are attribute values. Valid attribute types are int, float, str, bool, and lists of int, float, str, and bool. Tensor attributes are unsupported.
dtype (torch.dtype | int) – The data type of the output tensor. This can be either a torch.dtype or an integer representing the ONNX data type.
shape (Sequence[int | torch.SymInt]) – The shape of the output tensor. This can be a list of integers or SymInt values.
version (int | None) – The version of the opset used for the operator.
metadata_props (dict[str, str] | None) – Metadata properties for the ONNX node. This is a dictionary of str-str pairs.
- Returns
The output tensor of the operator.
- Return type
Tensor
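Since dtype also accepts an ONNX data type integer, a minimal sketch (the op name is hypothetical; 1 is the ONNX TensorProto code for FLOAT) would be:

# Sketch: passing the ONNX TensorProto enum value instead of a torch.dtype.
val = torch.onnx.ops.symbolic(
    "custom_domain::MyOp",  # hypothetical op
    (x,),
    dtype=1,                # 1 == ONNX TensorProto.FLOAT
    shape=x.shape,
    version=1,
)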
- torch.onnx.ops.symbolic_multi_out(domain_op, /, inputs, attrs=None, *, dtypes, shapes, version=None, metadata_props=None)[source]#
Create a symbolic FX operator to represent an arbitrary ONNX operator with multiple outputs.
You may use if torch.onnx.is_in_onnx_export() to conditionally enable the symbolic logic only during torch.onnx.export().
Example:
class CustomOp(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normal torch operators can interleave with the symbolic ops during ONNX export
        x = x + 1

        # Create a symbolic ONNX operator with the name "CustomOp" in the "custom_domain" domain.
        # The output tensors will have the specified dtypes and shapes
        (out1, out2) = torch.onnx.ops.symbolic_multi_out(
            "custom_domain::CustomOp",
            (x,),
            dict(attr_key="attr_value"),
            dtypes=(x.dtype, torch.float32),
            shapes=(x.shape, [1, 2, 3]),
            version=1,
        )

        # The result of the symbolic op can be used in normal torch operations during ONNX export
        return torch.nn.functional.relu(out1 + out2)

# You may then export this model to ONNX using torch.onnx.export(..., dynamo=True).
- Parameters
domain_op (str) – The domain and operator name, separated by “::”. For example, “custom_domain::CustomOp”.
inputs (Sequence[torch.Tensor | None]) – The input tensors to the operator.
attrs (dict[str, int | float | str | bool | Sequence[int] | Sequence[float] | Sequence[str] | Sequence[bool]] | None) – The attributes of the operator. The keys are attribute names and the values are attribute values. Valid attribute types are int, float, str, bool, and lists of int, float, str, and bool. Tensor attributes are unsupported.
dtypes (Sequence[torch.dtype | int]) – The data types of the output tensors. This can be a list of torch.dtype or integers representing the ONNX data types. The length of this list must be the number of outputs.
shapes (Sequence[Sequence[int | torch.SymInt]]) – The shapes of the output tensors. This can be a list of lists of integers or SymInt values. The length of this list must be the number of outputs.
version (int | None) – The version of the opset used for the operator.
metadata_props (dict[str, str] | None) – Metadata properties for the ONNX node. This is a dictionary of str-str pairs.
- Returns
A list of output tensors of the operator.
- Return type
Sequence[torch.Tensor]
ONNX Operators#
The following operators are implemented as native PyTorch ops and can be exported as
ONNX operators. They can be used natively in an nn.Module.
For example, you can define a module:
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data
):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids_data,
)
and export it to ONNX using:
model = Model()

input_data = torch.rand(2, 3, 4, 8)
position_ids_data = torch.randint(0, 50, (2, 3)).long()
sin_cache_data = torch.rand(50, 4)
cos_cache_data = torch.rand(50, 4)
dynamic_shapes = {
"input_data": {0: torch.export.Dim.DYNAMIC},
"cos_cache_data": None,
"sin_cache_data": None,
"position_ids_data": {0: torch.export.Dim.DYNAMIC},
}
onnx_program = torch.onnx.export(
model,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
dynamic_shapes=dynamic_shapes,
dynamo=True,
opset_version=23,
)
Printing the ONNX program will show the ONNX operators used in the graph:
<...>
graph(
name=main_graph,
inputs=(
%"input_data"<FLOAT,[s0,3,4,8]>,
%"cos_cache_data"<FLOAT,[50,4]>,
%"sin_cache_data"<FLOAT,[50,4]>,
%"position_ids_data"<INT64,[s0,3]>
),
outputs=(
%"rotary_embedding"<FLOAT,[s0,3,4,8]>
),
) {
0 | # rotary_embedding
%"rotary_embedding"<FLOAT,[s0,3,4,8]> ⬅️ ::RotaryEmbedding(%"input_data", %"cos_cache_data", %"sin_cache_data", %"position_ids_data")
return %"rotary_embedding"<FLOAT,[s0,3,4,8]>
}
with the corresponding ExportedProgram:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, input_data: "f32[s0, 3, 4, 8]", cos_cache_data: "f32[50, 4]", sin_cache_data: "f32[50, 4]", position_ids_data: "i64[s0, 3]"):
rotary_embedding: "f32[s0, 3, 4, 8]" = torch.ops.onnx.RotaryEmbedding.opset23(input_data, cos_cache_data, sin_cache_data, position_ids_data); input_data = cos_cache_data = sin_cache_data = position_ids_data = None
return (rotary_embedding,)
- torch.onnx.ops.rotary_embedding(X, cos_cache, sin_cache, position_ids=None, *, interleaved=False, num_heads=0, rotary_embedding_dim=0)[source]#
RotaryEmbedding op in ONNX.
https://onnx.ai/onnx/operators/onnx__RotaryEmbedding.html
RotaryEmbedding is the implementation of rotary positional embeddings (RoPE) based on the paper https://arxiv.org/pdf/2104.09864. The key advantage of RoPE is that it allows the model to understand both the absolute position of a token and the relative distances between tokens. This is achieved through a rotational mechanism where the extent of rotation is computed based on the token’s absolute position (position_ids).
The rotational mechanism is defined by sine and cosine functions that are used to represent the rotation angles. For each token in the sequence, its positional embedding is computed by rotating its embedding vector. This is done by splitting the embedding vector either into two halves or interleaving every alternate token and applying the rotation matrix to each half of the embedding vector. The rotation matrix is parameterized by the token’s position in the sequence. The rotated halves of the embedding vector are concatenated to form the final positional embedding for each token. The rotated positional embeddings are used in the self-attention mechanism. The rotation ensures that the model captures both absolute and relative positional information.
- Parameters
X (torch.Tensor) – The input tensor representing the token embeddings. 4D tensor with shape (batch_size, num_heads, sequence_length, head_size) or 3D tensor with shape (batch_size, sequence_length, hidden_size). For cases with a 4D input tensor, head_size has to be even. For cases with a 3D input tensor, num_heads attribute must be provided and hidden_size must be an even multiple of num_heads where hidden_size = num_heads * head_size
cos_cache (torch.Tensor) – The cosine values for the rotation. 2D tensor with shape (max_position_id_plus_1, head_size / 2) for full rotation or (max_position_id_plus_1, rotary_embedding_dim / 2) for partial rotation when position_ids are provided. 3D tensor with shape (batch_size, sequence_length, head_size / 2) for full rotation or (batch_size, sequence_length, rotary_embedding_dim / 2) for partial rotation when position_ids are not provided. max_position_id_plus_1 is a parameter to the model.
sin_cache (torch.Tensor) – The sine values for the rotation. 2D tensor with shape (max_position_id_plus_1, head_size / 2) for full rotation or (max_position_id_plus_1, rotary_embedding_dim / 2) for partial rotation when position_ids are provided. 3D tensor with shape (batch_size, sequence_length, head_size / 2) for full rotation or (batch_size, sequence_length, rotary_embedding_dim / 2) for partial rotation when position_ids are not provided. max_position_id_plus_1 is a parameter to the model.
position_ids (torch.Tensor | None) – The position indices for the tokens. 2D tensor with shape (batch_size, sequence_length).
interleaved (bool) – Rotate using the interleaved pattern. Defaults to False.
num_heads (int) – Number of attention heads. Must be provided when input is a 3D tensor.
rotary_embedding_dim (int) – Rotary embedding dimension used to apply partial rotary embeddings.
- Returns
Tensor with same shape as input.
- Return type
Tensor
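As a rough illustration of the half-split (non-interleaved) rotation described above, the following sketch shows how the cos/sin caches are typically applied to a 4D input when position_ids are provided; it is a reference for intuition under those assumptions, not the operator's exact kernel:

import torch

def rope_half_split_sketch(x, cos_cache, sin_cache, position_ids):
    # x: (batch, num_heads, seq_len, head_size); caches: (max_pos, head_size / 2)
    cos = cos_cache[position_ids].unsqueeze(1)  # (batch, 1, seq_len, head_size / 2)
    sin = sin_cache[position_ids].unsqueeze(1)
    x1, x2 = x.chunk(2, dim=-1)  # split the head dimension into two halves
    # Rotate each (x1, x2) pair by the position-dependent angle encoded in cos/sin.
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)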
ONNX to ATen Decomposition Table#
You can use torch.onnx.ops.aten_decompositions() to obtain a decomposition table
that decomposes the ONNX operators defined above into ATen operators.
class Model(torch.nn.Module):
def forward(
self, input_data, cos_cache_data, sin_cache_data, position_ids_data
):
return torch.onnx.ops.rotary_embedding(
input_data,
cos_cache_data,
sin_cache_data,
position_ids_data,
)
model = Model()
ep = torch.export.export(
model,
(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)
# The program can be decomposed into aten ops
ep_decomposed = ep.run_decompositions(torch.onnx.ops.aten_decompositions())
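As a quick sanity check (a sketch, assuming the example inputs defined earlier are still in scope), you can print the decomposed graph and verify that it still matches the eager model numerically:

# The onnx.RotaryEmbedding node should now be expressed in terms of ATen ops.
print(ep_decomposed.graph_module.graph)

# The decomposed program should produce the same values as the original model.
torch.testing.assert_close(
    model(input_data, cos_cache_data, sin_cache_data, position_ids_data),
    ep_decomposed.module()(input_data, cos_cache_data, sin_cache_data, position_ids_data),
)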