(beta) Using TORCH_LOGS python API with torch.compile

Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

import logging

This tutorial introduces the TORCH_LOGS environment variable and the equivalent Python API, torch._logging.set_logs, and demonstrates how to use them to observe the phases of torch.compile.

Note

This tutorial requires PyTorch 2.2.0 or later.

Setup

In this example, we’ll set up a simple Python function that performs an elementwise add, and observe the compilation process with the TORCH_LOGS Python API.

Note

There is also an environment variable, TORCH_LOGS, which can be used to change logging settings at the command line. The equivalent environment variable setting is shown in a comment above each example.
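To illustrate how the two forms correspond, the sketch below builds a TORCH_LOGS string from the same keyword arguments that the examples pass to torch._logging.set_logs. The helper name to_torch_logs is hypothetical and not part of PyTorch. It covers only the two cases used in this tutorial: logging.DEBUG for a component maps to a "+" prefix, and True enables an artifact by name. Multiple settings are joined with commas.

```python
import logging


def to_torch_logs(**settings):
    """Build a TORCH_LOGS-style string from set_logs-style keyword arguments.

    Hypothetical helper for illustration only; it handles just the two
    kinds of settings that appear in this tutorial.
    """
    parts = []
    for name, value in settings.items():
        if value is True:
            # Artifacts (graph, fusion, output_code) are enabled by name.
            parts.append(name)
        elif value == logging.DEBUG:
            # A leading "+" requests DEBUG-level logs for a component.
            parts.append(f"+{name}")
        else:
            raise ValueError(f"unsupported setting for {name!r}: {value!r}")
    return ",".join(parts)


print(to_torch_logs(dynamo=logging.DEBUG))
print(to_torch_logs(graph=True))
print(to_torch_logs(dynamo=logging.DEBUG, output_code=True))
```

The resulting string would be placed in the environment before launching a script, for example TORCH_LOGS="+dynamo,output_code" python my_script.py, where my_script.py is a placeholder name.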

import torch

# exit cleanly if CUDA is unavailable or the device doesn't support torch.compile
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2


    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


    # print separator and reset dynamo
    # between each example
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()


    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
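The log lines in the output share a common prefix: a level letter (V on DEBUG-level lines, I on INFO-level lines) followed by the month and day, a timestamp, the process ID, the source file and line number, and, on lines emitted during compilation, a bracketed tag such as [0/0]. Artifact lines carry a further tag such as [__trace_bytecode]. As a reading aid, the sketch below splits one such line into its parts. The name parse_log_line and the regular expression are illustrative assumptions inferred from the lines shown in this tutorial, not a PyTorch API, and the prefix format may differ between PyTorch versions.

```python
import re

# Pattern inferred from the log lines in this tutorial (an assumption,
# not an official format specification).
LOG_LINE = re.compile(
    r"^(?P<level>[A-Z])(?P<month>\d{2})(?P<day>\d{2}) "
    r"(?P<time>\d{2}:\d{2}:\d{2}\.\d+) "
    r"(?P<pid>\d+) "
    r"(?P<file>[^:\s]+):(?P<line>\d+)\]"
    r"(?: \[(?P<compile_id>\d+/\d+)\])?"
    r"(?: \[(?P<artifact>__\w+)\])?"
    r" ?(?P<message>.*)$"
)

# Only the two level letters that appear in this tutorial's output.
LEVELS = {"V": "DEBUG", "I": "INFO"}


def parse_log_line(line):
    """Return a dict of prefix fields, or None if the line has no prefix."""
    match = LOG_LINE.match(line)
    if match is None:
        return None
    fields = match.groupdict()
    # Translate known level letters; leave any other letter unchanged.
    fields["level"] = LEVELS.get(fields["level"], fields["level"])
    return fields


sample = (
    "V0807 18:34:51.558000 22302 torch/_dynamo/symbolic_convert.py:1260] "
    "[0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]"
)
for key, value in parse_log_line(sample).items():
    print(f"{key}: {value}")
```

Lines without the prefix, such as the separator banners printed by the script, return None, which makes it straightforward to filter a captured log by level or artifact.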
===================Dynamo Tracing=========================
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/bin/sphinx-build", line 7, in <module>
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     sys.exit(main())
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 339, in main
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return make_main(argv)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 213, in make_main
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return make_mode.run_make_mode(argv[1:])
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 181, in run_make_mode
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return make.run_generic_build(args[0])
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 169, in run_generic_build
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return build_main(args + opts)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 293, in build_main
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 272, in __init__
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self._init_builder()
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 343, in _init_builder
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self.events.emit('builder-inited')
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 97, in emit
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     results.append(listener.handler(self.app, *args))
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 757, in generate_gallery_rst
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     ) = generate_dir_rst(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 606, in generate_dir_rst
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     results = parallel(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 607, in <genexpr>
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/var/lib/workspace/conf.py", line 85, in wrapper
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     p.start()
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self._popen = self._Popen(self)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return _default_context.get_context().Process._Popen(process_obj)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return Popen(process_obj)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self._launch(process_obj)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     code = process_obj._bootstrap(parent_sentinel=child_r)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self.run()
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     self._target(*self._args, **self._kwargs)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/var/lib/workspace/conf.py", line 73, in call_fn
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     result = func(*args, **kwargs)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1374, in generate_file_rst
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     output_blocks, time_elapsed = execute_script(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1192, in execute_script
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     execute_code_block(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1048, in execute_code_block
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 876, in _exec_and_get_memory
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     mem_max, _ = call_memory(
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1725, in _sg_call_memory_noop
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     return 0.0, func()
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 794, in __call__
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     exec(self.code, self.fake_main.__dict__)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]   File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]     fn(*inputs)
V0807 18:34:51.540000 22302 torch/_dynamo/convert_frame.py:1055] [0/0]
I0807 18:34:51.544000 22302 torch/_dynamo/symbolic_convert.py:3320] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0807 18:34:51.545000 22302 torch/fx/experimental/symbolic_shapes.py:3767] [0/0] create_env
V0807 18:34:51.548000 22302 torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0807 18:34:51.548000 22302 torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source]             z = x + y
V0807 18:34:51.550000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0807 18:34:51.550000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0807 18:34:51.550000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0807 18:34:51.552000 22302 torch/_dynamo/variables/builder.py:3373] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0807 18:34:51.553000 22302 torch/_dynamo/output_graph.py:2614] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0807 18:34:51.554000 22302 torch/_dynamo/variables/builder.py:3373] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], specialize_on=[[], []], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0807 18:34:51.555000 22302 torch/_dynamo/output_graph.py:2614] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0807 18:34:51.558000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0807 18:34:51.558000 22302 torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0807 18:34:51.558000 22302 torch/_dynamo/symbolic_convert.py:1237] [0/0] [__trace_source]             return z + 2
V0807 18:34:51.559000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0807 18:34:51.559000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0807 18:34:51.559000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0807 18:34:51.561000 22302 torch/_dynamo/symbolic_convert.py:1260] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0807 18:34:51.561000 22302 torch/_dynamo/symbolic_convert.py:3648] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0807 18:34:51.561000 22302 torch/_dynamo/symbolic_convert.py:3652] [0/0] RETURN_VALUE triggered compile
V0807 18:34:51.561000 22302 torch/_dynamo/output_graph.py:1263] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code] TRACED GRAPH
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]  ===== __compiled_fn_1_35a22ccc_42e3_4562_8af1_8c0bee882506 =====
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]  /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]     def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         l_x_ = L_x_
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         l_y_ = L_y_
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_;  l_x_ = l_y_ = None
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         add_1: "f32[2, 2][2, 1]cuda:0" = z + 2;  z = None
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]         return (add_1,)
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]
V0807 18:34:51.565000 22302 torch/_dynamo/output_graph.py:1667] [0/0] [__graph_code]
I0807 18:34:51.567000 22302 torch/_dynamo/output_graph.py:1842] [0/0] Step 2: calling compiler function inductor
I0807 18:34:52.653000 22302 torch/fx/experimental/symbolic_shapes.py:5238] [0/0] produce_guards
I0807 18:34:52.655000 22302 torch/fx/experimental/symbolic_shapes.py:5238] [0/0] produce_guards
I0807 18:34:52.659000 22302 torch/_dynamo/output_graph.py:1847] [0/0] Step 2: done compiler function inductor
I0807 18:34:52.661000 22302 torch/fx/experimental/symbolic_shapes.py:5238] [0/0] produce_guards
V0807 18:34:52.662000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['x'].size()[0] 2 None
V0807 18:34:52.662000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['x'].size()[1] 2 None
V0807 18:34:52.663000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['x'].stride()[0] 2 None
V0807 18:34:52.663000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['x'].stride()[1] 1 None
V0807 18:34:52.663000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['x'].storage_offset() 0 None
V0807 18:34:52.664000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['y'].size()[0] 2 None
V0807 18:34:52.664000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['y'].size()[1] 2 None
V0807 18:34:52.664000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['y'].stride()[0] 2 None
V0807 18:34:52.665000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['y'].stride()[1] 1 None
V0807 18:34:52.665000 22302 torch/fx/experimental/symbolic_shapes.py:5458] [0/0] track_symint L['y'].storage_offset() 0 None
V0807 18:34:52.665000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['x'].size()[0] == 2
V0807 18:34:52.666000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['x'].size()[1] == 2
V0807 18:34:52.666000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['x'].stride()[0] == 2
V0807 18:34:52.666000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['x'].stride()[1] == 1
V0807 18:34:52.667000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['x'].storage_offset() == 0
V0807 18:34:52.667000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['y'].size()[0] == 2
V0807 18:34:52.667000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['y'].size()[1] == 2
V0807 18:34:52.667000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['y'].stride()[0] == 2
V0807 18:34:52.668000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['y'].stride()[1] == 1
V0807 18:34:52.668000 22302 torch/fx/experimental/symbolic_shapes.py:5679] [0/0] Skipping guard L['y'].storage_offset() == 0
V0807 18:34:52.668000 22302 torch/_dynamo/guards.py:3064] [0/0] [__guards] GUARDS:
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards]
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] TREE_GUARD_MANAGER:
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] +- RootGuardManager
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- LAMBDA_GUARD: torch._functorch.aot_autograd.utils.top_saved_tensors_hooks ids == None  # _dynamo/output_graph.py:633 in init_ambient_guards
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:621 in init_ambient_guards
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0807 18:34:52.669000 22302 torch/_dynamo/guards.py:2863] [0/0] [__guards]
V0807 18:34:52.689000 22302 torch/_dynamo/guards.py:2894] [0/0] [__guards] Guard eval latency = 35.65 us
I0807 18:34:52.689000 22302 torch/_dynamo/pgo.py:785] [0/0] put_code_state: no cache key, skipping
I0807 18:34:52.690000 22302 torch/_dynamo/convert_frame.py:1175] [0/0] run_gc_after_compile: running gc
V0807 18:34:52.693000 22302 torch/_dynamo/convert_frame.py:1458] skipping: inner (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_compile.py)
V0807 18:34:52.694000 22302 torch/_dynamo/convert_frame.py:1458] skipping: disable (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/decorators.py)
V0807 18:34:52.694000 22302 torch/_dynamo/convert_frame.py:1458] skipping: innermost_fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0807 18:34:52.695000 22302 torch/_dynamo/convert_frame.py:1458] skipping: __init__ (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0807 18:34:52.695000 22302 torch/_dynamo/convert_frame.py:1458] skipping: __init__ (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0807 18:34:52.695000 22302 torch/_dynamo/convert_frame.py:1458] skipping: nothing (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0807 18:34:52.696000 22302 torch/_dynamo/convert_frame.py:1458] skipping: __call__ (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
V0807 18:34:52.696000 22302 torch/_dynamo/convert_frame.py:1458] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
===================Traced Graph=========================
I0807 18:34:52.697000 22302 torch/_dynamo/__init__.py:118] torch._dynamo.reset
I0807 18:34:52.697000 22302 torch/_dynamo/__init__.py:151] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
===================Output Code=========================
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] Output code:
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] # AOT ID: ['0_inference']
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import torch
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import math
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import random
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import os
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import tempfile
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from math import inf, nan
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from cmath import nanj
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch import device, empty_strided
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import triton
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import triton.language as tl
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] aten = torch.ops.aten
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] _quantized = torch.ops._quantized
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] assert_alignment = torch._C._dynamo.guards.assert_alignment
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] async_compile = AsyncCompile()
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/vt/cvt77cvt5ebfadc7bf2hd62rk5ltbuirvztu6wl4zumberx3xacq.py
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] # Source node to ATen node mapping:
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] #   add_1 => add_1
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] #   z => add
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] # Graph fragment:
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] #   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import triton
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] import triton.language as tl
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] @triton_heuristics.pointwise(
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     size_hints={'x': 4},
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     filename=__file__,
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '5521EADCB2516098F638687B39B477AA524882055648F5AE9FFB68D065B487C6', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'tiling_scores': {'x': 32}},
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     min_elem_per_thread=0
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] )
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] @triton.jit
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     xnumel = 4
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     xmask = xindex < xnumel
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     x0 = xindex
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tmp3 = 2.0
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] ''', device_str='cuda')
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] async_compile.wait(globals())
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] del async_compile
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] def call(args):
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     args.clear()
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         torch.cuda.set_device(0)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, stream=stream0)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         del arg0_1
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]         del arg1_1
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     return (buf0, )
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code] if __name__ == "__main__":
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0807 18:34:52.820000 22302 torch/_inductor/codecache.py:1236] [0/0] [__output_code]
V0807 18:34:52.825000 22302 torch/_inductor/codecache.py:1237] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/oc/coccdv2xhccfuzx37vyl7mznhjudr66zpmjc2e2uxtmmrfs5rlkn.py
============================================
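The generated kernel in the output above loads both inputs, adds them, adds the constant 2.0, and stores the result under a bounds mask. As a rough aid for reading it, the same logic can be sketched in plain Python with no Triton or GPU involved. The function name `emulate_fused_add` and the default block size of 8 are assumptions made for illustration; the log does not show which `XBLOCK` value was actually chosen.

```python
def emulate_fused_add(in0, in1, xblock=8):
    """Plain-Python sketch of triton_poi_fused_add_0 over flattened inputs.

    `xblock` stands in for XBLOCK; its default here is a hypothetical value.
    """
    xnumel = 4  # the 2x2 contiguous tensors flatten to 4 elements
    out = [0.0] * xnumel
    num_programs = -(-xnumel // xblock)  # ceiling division: programs needed to cover xnumel
    for program_id in range(num_programs):
        xoffset = program_id * xblock
        for lane in range(xblock):  # plays the role of tl.arange(0, XBLOCK)
            xindex = xoffset + lane
            xmask = xindex < xnumel  # lanes past the end are masked out
            if xmask:
                tmp2 = in0[xindex] + in1[xindex]  # z = x + y
                out[xindex] = tmp2 + 2.0          # z + 2, fused into the same pass
    return out


# Same values as the tutorial inputs: ones(2, 2) and zeros(2, 2), flattened.
print(emulate_fused_add([1.0] * 4, [0.0] * 4))  # [3.0, 3.0, 3.0, 3.0]
```

Both additions from `fn` happen in a single pass over the data, which is why the kernel name contains `fused_add` and the `call` function launches only one kernel.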

Conclusion#

In this tutorial, we introduced the TORCH_LOGS environment variable and the Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, run any Python script that imports torch with TORCH_LOGS set to “help”.

Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.
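The “help” setting mentioned above can also be tried from Python by launching a child interpreter with the variable set. The following is a minimal sketch: the helper name `build_help_command` is hypothetical, and because the stream on which the option descriptions arrive is not stated here, the script prints whichever of stdout or stderr is non-empty.

```python
import os
import subprocess
import sys


def build_help_command():
    """Return (cmd, env) for a child interpreter that imports torch with TORCH_LOGS=help."""
    env = dict(os.environ, TORCH_LOGS="help")  # copy the current environment, add the setting
    cmd = [sys.executable, "-c", "import torch"]
    return cmd, env


if __name__ == "__main__":
    cmd, env = build_help_command()
    # No check=True: the child's exit status is not assumed to be zero.
    result = subprocess.run(cmd, env=env, capture_output=True, text=True)
    print(result.stdout or result.stderr)
```

Running the import in a child process keeps the “help” setting from affecting the logging configuration of the current session.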

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: (0 minutes 3.017 seconds)