
(beta) Using TORCH_LOGS python API with torch.compile

Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

import logging

This tutorial introduces the TORCH_LOGS environment variable and its Python API, torch._logging.set_logs. It then shows how to use them to observe the phases of torch.compile.

Note

This tutorial requires PyTorch 2.2.0 or later.
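A script can stop early on an older release by comparing torch.__version__ against this requirement. Below is a minimal sketch. The helper name meets_min_version is hypothetical and is not part of PyTorch. It assumes a version string that starts with MAJOR.MINOR, and it compares only those two numbers.

```python
import re


def meets_min_version(version: str, minimum: tuple = (2, 2)) -> bool:
    """Return True if a PyTorch-style version string is at least `minimum`.

    Hypothetical helper: only the leading MAJOR.MINOR numbers are compared,
    so suffixes such as "+cu121" or ".dev20241112" are ignored.
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major_minor = (int(match.group(1)), int(match.group(2)))
    # Tuple comparison orders (2, 10) after (2, 2), unlike string comparison.
    return major_minor >= minimum
```

For example, call meets_min_version(torch.__version__) after importing torch, and exit with a message if it returns False.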

Setup

In this example, we’ll set up a simple Python function that performs an elementwise add. We'll then observe the compilation process with the TORCH_LOGS Python API.

Note

You can also change the same logging settings from the command line with the TORCH_LOGS environment variable. Each example below shows its equivalent TORCH_LOGS value in a comment.
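The correspondence between set_logs keyword arguments and TORCH_LOGS values can be sketched in a few lines. This sketch assumes the conventions described in the torch._logging documentation:

- A "+" prefix selects logging.DEBUG.
- A "-" prefix selects logging.ERROR.
- A bare component name selects logging.INFO.
- A bare artifact name turns that artifact on.
- Entries are separated by commas.

The function to_torch_logs is hypothetical and is not a PyTorch API.

```python
import logging


def to_torch_logs(**settings) -> str:
    """Build a TORCH_LOGS string from set_logs-style keyword arguments.

    Hypothetical helper (not part of PyTorch). Assumed conventions:
      component=logging.DEBUG -> "+component"
      component=logging.INFO  -> "component"
      component=logging.ERROR -> "-component"
      artifact=True           -> "artifact"
    """
    prefixes = {logging.DEBUG: "+", logging.INFO: "", logging.ERROR: "-"}
    parts = []
    for name, value in settings.items():
        if value is True:
            parts.append(name)  # artifacts are simply switched on
        elif value is False or value is None:
            continue  # disabled or unset entries are left out
        elif value in prefixes:
            parts.append(prefixes[value] + name)
        else:
            raise ValueError(f"no TORCH_LOGS prefix for {name}={value!r}")
    return ",".join(parts)
```

For instance, to_torch_logs(dynamo=logging.DEBUG) returns "+dynamo", the value given for the first example. You would export that value before launching the script.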

import torch

# Skip cleanly unless a CUDA device with compute capability >= 7.0 is present:
# this example compiles for CUDA, and inductor's Triton GPU backend needs 7.0+.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2


    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


    # Print a separator and reset dynamo between examples
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()


    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
I0512 16:32:31.244000 635 torch/_dynamo/utils.py:1603] [0/0] ChromiumEventLogger initialized with id 33ede156-2b87-4558-93a2-97a2786ada8d
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     sys.exit(main())
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_main(argv)
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_mode.run_make_mode(argv[1:])
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     return make.run_generic_build(args[0])
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     return build_main(args + opts)
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     self._init_builder()
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     self.events.emit('builder-inited')
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     results.append(listener.handler(self.app, *args))
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     ) = generate_dir_rst(
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     intro, title, cost = generate_file_rst(
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks.append(execute_code_block(
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     mem_max, _ = gallery_conf['call_memory'](
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     return 0., func()
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     exec(self.code, self.fake_main.__dict__)
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]     fn(*inputs)
V0512 16:32:31.245000 635 torch/_dynamo/convert_frame.py:1003] [0/0]
I0512 16:32:31.248000 635 torch/_dynamo/symbolic_convert.py:3324] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0512 16:32:31.248000 635 torch/fx/experimental/symbolic_shapes.py:3334] [0/0] create_env
V0512 16:32:31.251000 635 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0512 16:32:31.251000 635 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             z = x + y
V0512 16:32:31.252000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0512 16:32:31.252000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0512 16:32:31.253000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0512 16:32:31.254000 635 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0512 16:32:31.255000 635 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0512 16:32:31.256000 635 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0512 16:32:31.257000 635 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0512 16:32:31.260000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0512 16:32:31.261000 635 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0512 16:32:31.261000 635 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             return z + 2
V0512 16:32:31.261000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0512 16:32:31.261000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0512 16:32:31.262000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0512 16:32:31.263000 635 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0512 16:32:31.263000 635 torch/_dynamo/symbolic_convert.py:3681] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0512 16:32:31.263000 635 torch/_dynamo/symbolic_convert.py:3685] [0/0] RETURN_VALUE triggered compile
V0512 16:32:31.264000 635 torch/_dynamo/output_graph.py:1008] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code] TRACED GRAPH
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]     def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_x_ = L_x_
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_y_ = L_y_
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_;  l_x_ = l_y_ = None
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         add_1: "f32[2, 2][2, 1]cuda:0" = z + 2;  z = None
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         return (add_1,)
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0512 16:32:31.266000 635 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
I0512 16:32:31.268000 635 torch/_dynamo/output_graph.py:1515] [0/0] Step 2: calling compiler function inductor
I0512 16:32:32.392000 635 torch/_dynamo/output_graph.py:1520] [0/0] Step 2: done compiler function inductor
I0512 16:32:32.394000 635 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
V0512 16:32:32.394000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[0] 2 None
V0512 16:32:32.395000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[1] 2 None
V0512 16:32:32.395000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[0] 2 None
V0512 16:32:32.395000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[1] 1 None
V0512 16:32:32.395000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].storage_offset() 0 None
V0512 16:32:32.396000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[0] 2 None
V0512 16:32:32.396000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[1] 2 None
V0512 16:32:32.396000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[0] 2 None
V0512 16:32:32.396000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[1] 1 None
V0512 16:32:32.397000 635 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].storage_offset() 0 None
V0512 16:32:32.397000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[0] == 2
V0512 16:32:32.397000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[1] == 2
V0512 16:32:32.398000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[0] == 2
V0512 16:32:32.398000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[1] == 1
V0512 16:32:32.398000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].storage_offset() == 0
V0512 16:32:32.398000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[0] == 2
V0512 16:32:32.399000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[1] == 2
V0512 16:32:32.399000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[0] == 2
V0512 16:32:32.399000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[1] == 1
V0512 16:32:32.400000 635 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].storage_offset() == 0
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2557] [0/0] [__guards] GUARDS:
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] TREE_GUARD_MANAGER:
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] +- RootGuardManager
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:520 in init_ambient_guards
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0512 16:32:32.400000 635 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0512 16:32:32.403000 635 torch/_dynamo/guards.py:2524] [0/0] [__guards] Guard eval latency = 1.47 us
I0512 16:32:32.404000 635 torch/_dynamo/pgo.py:660] [0/0] put_code_state: no cache key, skipping
I0512 16:32:32.404000 635 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0512 16:32:32.406000 635 torch/_dynamo/convert_frame.py:1395] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
===================Traced Graph=========================
I0512 16:32:32.407000 635 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0512 16:32:32.407000 635 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
V0512 16:32:32.778000 635 torch/_inductor/scheduler.py:2562] [0/0] [__fusion] ===== attempting fusion (1/10): 1 nodes =====
V0512 16:32:32.779000 635 torch/_inductor/scheduler.py:2996] [0/0] [__fusion] fuse_nodes_once, candidates:
V0512 16:32:32.779000 635 torch/_inductor/scheduler.py:2998] [0/0] [__fusion]   SchedulerNode(name='op0'), Pointwise(['[2, 2]', 'origins=OrderedSet([add_1, add])'])
V0512 16:32:32.779000 635 torch/_inductor/scheduler.py:3189] [0/0] [__fusion] found 0 possible fusions
V0512 16:32:32.780000 635 torch/_inductor/scheduler.py:2569] [0/0] [__fusion] completed fusion round (1/10): fused 1 nodes into 1 nodes
V0512 16:32:32.780000 635 torch/_inductor/scheduler.py:2569] [0/0] [__fusion]
V0512 16:32:32.780000 635 torch/_inductor/scheduler.py:2576] [0/0] [__fusion] ===== fusion complete (1 iterations) =====
===================Output Code=========================
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] Output code:
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] # AOT ID: ['3_inference']
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import torch
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import math
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import random
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import os
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import tempfile
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from math import inf, nan
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from cmath import nanj
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch import device, empty_strided
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import triton
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import triton.language as tl
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] aten = torch.ops.aten
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] _quantized = torch.ops._quantized
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] async_compile = AsyncCompile()
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/tmp9o397tqm/sb/csb3bivgknyynhmbxo2aqc2crrn2xmizszz32j7qapmcm7znpgm3.py
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] # Source node to ATen node mapping:
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] #   add_1 => add_1
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] #   z => add
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] # Graph fragment:
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] #   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import triton
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] import triton.language as tl
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] @triton_heuristics.pointwise(
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     size_hints={'x': 4},
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     filename=__file__,
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '1E2C16421D4C3DBA4AD92BFC4278A3CB24C43DEDA6EE7FF9E3FBB1DBB80802DB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': True, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     min_elem_per_thread=0
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] )
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] @triton.jit
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     xnumel = 4
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     xmask = xindex < xnumel
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     x0 = xindex
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tmp3 = 2.0
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] ''', device_str='cuda')
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] async_compile.wait(globals())
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] del async_compile
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] def call(args):
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     args.clear()
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         torch.cuda.set_device(0)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, stream=stream0)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         del arg0_1
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]         del arg1_1
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     return (buf0, )
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code] if __name__ == "__main__":
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0512 16:32:33.116000 635 torch/_inductor/graph.py:2104] [0/0] [__output_code]
V0512 16:32:33.123000 635 torch/_inductor/graph.py:2115] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/tmp9o397tqm/qy/cqyfjdrcpofkcolsarqskk6fwphyrdujnpbu3b6z5thoqrzmxcdq.py
I0512 16:32:33.582000 635 torch/_inductor/graph.py:2149] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/tmp9o397tqm/qy/cqyfjdrcpofkcolsarqskk6fwphyrdujnpbu3b6z5thoqrzmxcdq.py
============================================
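Every log line in the output above starts with the same glog-style prefix, made up of:

- a level letter joined to the MMDD date (here V for verbose/DEBUG records and I for INFO);
- the time;
- the process ID;
- the emitting source file and line.

Records produced while compiling a frame also carry a compile ID such as [0/0], optionally followed by an artifact tag such as [__fusion].

If you save this output to a file, a minimal sketch like the following can split lines into those fields and keep only one artifact. The regular expression is inferred from the sample output on this page, not from a documented format, so it may need adjusting for other PyTorch versions. The names parse_log_line and lines_for_artifact are hypothetical.

```python
import re
from typing import Iterable, Iterator, Optional

# Assumed layout, inferred from the sample output (glog style):
# <level><MMDD> <HH:MM:SS.micro> <pid> <file>:<line>] [<compile id>] [<artifact>] <message>
# The compile id and the artifact tag are both optional.
LOG_LINE = re.compile(
    r"^(?P<level>[A-Z])(?P<date>\d{4}) "
    r"(?P<time>\d{2}:\d{2}:\d{2}\.\d+) "
    r"(?P<pid>\d+) "
    r"(?P<path>\S+):(?P<lineno>\d+)\] ?"
    r"(?:\[(?P<compile_id>(?!__)[^\]]+)\] ?)?"  # e.g. [0/0], but not [__fusion]
    r"(?:\[(?P<artifact>__\w+)\] ?)?"            # e.g. [__fusion], [__output_code]
    r"(?P<message>.*)$"
)


def parse_log_line(line: str) -> Optional[dict]:
    """Split one log line into its fields, or return None if it does not match."""
    match = LOG_LINE.match(line)
    return match.groupdict() if match else None


def lines_for_artifact(lines: Iterable[str], artifact: str) -> Iterator[str]:
    """Yield the message of every line tagged with `artifact`, e.g. "__fusion"."""
    for line in lines:
        fields = parse_log_line(line)
        if fields is not None and fields["artifact"] == artifact:
            yield fields["message"]
```

For example, passing the lines of a saved log together with "__output_code" to lines_for_artifact recovers the generated inductor code without the log prefixes. Lines that do not match the assumed layout, such as the separator lines, are skipped.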

Conclusion

In this tutorial, we introduced the TORCH_LOGS environment variable and Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, set TORCH_LOGS to “help” and run any Python script that imports torch.
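To do this from Python without changing the current process's logging settings, you can set the variable only for a child interpreter. Below is a minimal sketch that assumes only the standard library and a Python environment with torch installed. It launches a fresh interpreter that imports torch with TORCH_LOGS set to “help” and returns whatever that process printed.

The helper names are hypothetical. This page does not say which stream the help text is written to, or what exit status the child returns. For that reason, the sketch captures both streams and does not check the return code.

```python
import os
import subprocess
import sys


def torch_logs_help_command():
    """Return (argv, env) for a child interpreter that imports torch with TORCH_LOGS=help."""
    env = dict(os.environ)  # copy, so the parent's environment is left untouched
    env["TORCH_LOGS"] = "help"
    argv = [sys.executable, "-c", "import torch"]
    return argv, env


def show_torch_logs_help() -> str:
    """Run the child interpreter and return everything it printed (stdout and stderr)."""
    argv, env = torch_logs_help_command()
    result = subprocess.run(argv, env=env, capture_output=True, text=True)
    return result.stdout + result.stderr
```

Calling print(show_torch_logs_help()) displays the option list in the current session.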

Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: ( 0 minutes 3.217 seconds)
