(beta) Using the TORCH_LOGS Python API with torch.compile

Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

import logging

This tutorial introduces the TORCH_LOGS environment variable and the corresponding Python API, and demonstrates how to use them to observe the phases of torch.compile.

Note

This tutorial requires PyTorch 2.2.0 or later.

Setup

In this example, we’ll set up a simple Python function that performs an elementwise add, and observe its compilation with the TORCH_LOGS Python API.

Note

There is also an environment variable TORCH_LOGS, which can be used to change logging settings at the command line. The equivalent environment variable setting is shown for each example.
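The correspondence between the two forms can be sketched with a small helper. This is only an illustration: `to_torch_logs` is a hypothetical function, not part of PyTorch. It assumes the prefix convention for the environment variable in which `+name` means `logging.DEBUG`, a bare `name` means `logging.INFO` or an enabled artifact, and `-name` means `logging.ERROR`.

```python
import logging


def to_torch_logs(**settings):
    """Build a TORCH_LOGS value from set_logs-style keyword arguments.

    Hypothetical helper for illustration only; it is not a PyTorch API.
    """
    parts = []
    for name, value in settings.items():
        if isinstance(value, bool):
            # Artifacts such as graph, fusion, or output_code are switched
            # on by listing their name; a False value is simply omitted.
            if value:
                parts.append(name)
        elif value == logging.DEBUG:
            parts.append("+" + name)
        elif value == logging.INFO:
            parts.append(name)
        elif value == logging.ERROR:
            parts.append("-" + name)
        else:
            raise ValueError(f"no TORCH_LOGS spelling assumed for {name}={value!r}")
    return ",".join(parts)


print(to_torch_logs(dynamo=logging.DEBUG))   # +dynamo
print(to_torch_logs(graph=True))             # graph
print(to_torch_logs(dynamo=logging.DEBUG, output_code=True))  # +dynamo,output_code
```

The resulting string would be set before launching the script, for example `TORCH_LOGS="+dynamo,output_code" python script.py`, where `script.py` stands in for your own file.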

import torch

# exit cleanly if we are on a device that doesn't support torch.compile
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2


    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


    # print separator and reset dynamo
    # between each example
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()


    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
I0618 22:35:12.599000 24001 torch/_dynamo/utils.py:1603] [0/0] ChromiumEventLogger initialized with id b79bef2e-9c21-46e0-a38b-4e454658defb
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     sys.exit(main())
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 288, in main
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_main(argv)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 193, in make_main
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_mode.run_make_mode(argv[1:])
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return make.run_generic_build(args[0])
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return build_main(args + opts)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 272, in build_main
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 256, in __init__
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self._init_builder()
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 314, in _init_builder
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self.events.emit('builder-inited')
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     results.append(listener.handler(self.app, *args))
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     ) = generate_dir_rst(
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     intro, title, cost = generate_file_rst(
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 79, in wrapper
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     p.start()
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self._popen = self._Popen(self)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return _default_context.get_context().Process._Popen(process_obj)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return Popen(process_obj)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self._launch(process_obj)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     code = process_obj._bootstrap(parent_sentinel=child_r)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self.run()
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     self._target(*self._args, **self._kwargs)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 67, in call_fn
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     result = func(*args, **kwargs)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks.append(execute_code_block(
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     mem_max, _ = gallery_conf['call_memory'](
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     return 0., func()
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     exec(self.code, self.fake_main.__dict__)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]     fn(*inputs)
V0618 22:35:12.601000 24001 torch/_dynamo/convert_frame.py:1003] [0/0]
I0618 22:35:12.605000 24001 torch/_dynamo/symbolic_convert.py:3324] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0618 22:35:12.605000 24001 torch/fx/experimental/symbolic_shapes.py:3334] [0/0] create_env
V0618 22:35:12.608000 24001 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0618 22:35:12.608000 24001 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             z = x + y
V0618 22:35:12.610000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0618 22:35:12.610000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0618 22:35:12.610000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0618 22:35:12.612000 24001 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0618 22:35:12.613000 24001 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0618 22:35:12.614000 24001 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0618 22:35:12.615000 24001 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0618 22:35:12.618000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0618 22:35:12.618000 24001 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0618 22:35:12.618000 24001 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             return z + 2
V0618 22:35:12.618000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0618 22:35:12.619000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0618 22:35:12.619000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0618 22:35:12.620000 24001 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0618 22:35:12.621000 24001 torch/_dynamo/symbolic_convert.py:3681] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0618 22:35:12.621000 24001 torch/_dynamo/symbolic_convert.py:3685] [0/0] RETURN_VALUE triggered compile
V0618 22:35:12.621000 24001 torch/_dynamo/output_graph.py:1008] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code] TRACED GRAPH
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]     def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_x_ = L_x_
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_y_ = L_y_
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_;  l_x_ = l_y_ = None
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         add_1: "f32[2, 2][2, 1]cuda:0" = z + 2;  z = None
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         return (add_1,)
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0618 22:35:12.623000 24001 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
I0618 22:35:12.625000 24001 torch/_dynamo/output_graph.py:1515] [0/0] Step 2: calling compiler function inductor
I0618 22:35:13.747000 24001 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
I0618 22:35:13.752000 24001 torch/_dynamo/output_graph.py:1520] [0/0] Step 2: done compiler function inductor
I0618 22:35:13.754000 24001 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
V0618 22:35:13.754000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[0] 2 None
V0618 22:35:13.754000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[1] 2 None
V0618 22:35:13.755000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[0] 2 None
V0618 22:35:13.755000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[1] 1 None
V0618 22:35:13.755000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].storage_offset() 0 None
V0618 22:35:13.755000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[0] 2 None
V0618 22:35:13.756000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[1] 2 None
V0618 22:35:13.756000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[0] 2 None
V0618 22:35:13.756000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[1] 1 None
V0618 22:35:13.756000 24001 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].storage_offset() 0 None
V0618 22:35:13.757000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[0] == 2
V0618 22:35:13.757000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[1] == 2
V0618 22:35:13.757000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[0] == 2
V0618 22:35:13.758000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[1] == 1
V0618 22:35:13.758000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].storage_offset() == 0
V0618 22:35:13.758000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[0] == 2
V0618 22:35:13.758000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[1] == 2
V0618 22:35:13.759000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[0] == 2
V0618 22:35:13.759000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[1] == 1
V0618 22:35:13.759000 24001 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].storage_offset() == 0
V0618 22:35:13.759000 24001 torch/_dynamo/guards.py:2557] [0/0] [__guards] GUARDS:
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] TREE_GUARD_MANAGER:
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] +- RootGuardManager
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:520 in init_ambient_guards
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0618 22:35:13.760000 24001 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0618 22:35:13.761000 24001 torch/_dynamo/guards.py:2524] [0/0] [__guards] Guard eval latency = 1.00 us
I0618 22:35:13.762000 24001 torch/_dynamo/pgo.py:660] [0/0] put_code_state: no cache key, skipping
I0618 22:35:13.762000 24001 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0618 22:35:13.764000 24001 torch/_dynamo/convert_frame.py:1395] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
===================Traced Graph=========================
I0618 22:35:13.765000 24001 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0618 22:35:13.765000 24001 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
===================Output Code=========================
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] Output code:
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # AOT ID: ['0_inference']
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import torch
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import math
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import random
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import os
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import tempfile
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from math import inf, nan
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from cmath import nanj
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch import device, empty_strided
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] aten = torch.ops.aten
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] _quantized = torch.ops._quantized
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile = AsyncCompile()
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/ld/cld7tar7n7kytdxqq7n73fjc5nptwpbw7wqmdbp24zf62axk3q3a.py
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Source node to ATen node mapping:
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   add_1 => add_1
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   z => add
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Graph fragment:
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton_heuristics.pointwise(
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     size_hints={'x': 4},
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     filename=__file__,
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '1E2C16421D4C3DBA4AD92BFC4278A3CB24C43DEDA6EE7FF9E3FBB1DBB80802DB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     min_elem_per_thread=0
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] )
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton.jit
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xnumel = 4
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xmask = xindex < xnumel
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     x0 = xindex
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp3 = 2.0
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] ''', device_str='cuda')
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile.wait(globals())
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] del async_compile
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def call(args):
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     args.clear()
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         torch.cuda.set_device(0)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, stream=stream0)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg0_1
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg1_1
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return (buf0, )
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code] if __name__ == "__main__":
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0618 22:35:13.854000 24001 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0618 22:35:13.861000 24001 torch/_inductor/codecache.py:1094] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/nk/cnk55csixpane7aredk4kvfxz3fx2bb7zgzf4vpzqkzufdznzojb.py
============================================
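
The Triton kernel in the output code above fuses both additions into a single pass: each program handles one block of indices, masks off any index past the end of the tensor, loads both inputs, adds them, adds 2.0, and stores the result. As a rough illustration only, the same indexing and masking logic can be emulated in plain Python. The function name `emulate_fused_add` and the default `xblock=4` are assumptions made for this sketch; the real kernel runs on the GPU, and its block size is chosen by Inductor.

```python
# Pure-Python emulation of the logic in triton_poi_fused_add_0.
# Illustrative only: the real kernel runs its programs in parallel on the GPU.

def emulate_fused_add(in0, in1, xblock=4):
    """Compute (in0 + in1) + 2.0 block by block, guarded by a bounds mask."""
    xnumel = len(in0)
    out = [0.0] * xnumel
    num_programs = (xnumel + xblock - 1) // xblock  # grid size, rounded up
    for program_id in range(num_programs):
        xoffset = program_id * xblock
        for lane in range(xblock):
            xindex = xoffset + lane
            xmask = xindex < xnumel  # lanes past the end of the data do nothing
            if xmask:
                tmp2 = in0[xindex] + in1[xindex]
                out[xindex] = tmp2 + 2.0
    return out


# Flattened stand-ins for torch.ones(2, 2) and torch.zeros(2, 2)
result = emulate_fused_add([1.0] * 4, [0.0] * 4)
print(result)  # [3.0, 3.0, 3.0, 3.0]
```

The mask matters when the element count is not a multiple of the block size, for example three elements with `xblock=2`.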

Conclusion

In this tutorial, we introduced the TORCH_LOGS environment variable and the Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, run any Python script that imports torch with TORCH_LOGS set to "help".
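
One way to try the "help" setting without touching your shell configuration is to launch a child Python process with the variable set. The sketch below is a minimal example under stated assumptions: the helper name `build_help_command` is hypothetical, and whether the help text arrives on stdout or stderr may vary between PyTorch versions, so the snippet prints whichever is non-empty.

```python
import importlib.util
import os
import subprocess
import sys


def build_help_command():
    """Return (argv, env) for a child process that imports torch with TORCH_LOGS=help."""
    env = dict(os.environ)          # copy, so the parent environment is untouched
    env["TORCH_LOGS"] = "help"
    argv = [sys.executable, "-c", "import torch"]
    return argv, env


argv, env = build_help_command()

if importlib.util.find_spec("torch") is not None:
    # No check=True: the child process may exit with a non-zero status.
    completed = subprocess.run(argv, env=env, capture_output=True, text=True)
    print(completed.stdout or completed.stderr)
else:
    print("torch is not installed; skipping the help listing")
```

From a shell, the equivalent is to set TORCH_LOGS to "help" on the command line when running your script.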

Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.
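
Several options can be enabled together, and each Python API setting used in this tutorial has a TORCH_LOGS counterpart: `dynamo=logging.DEBUG` corresponds to `+dynamo`, and an artifact such as `graph=True` corresponds to `graph`. The sketch below translates a settings dictionary into the comma-separated environment-variable form. The helper `to_torch_logs` is hypothetical and written for this illustration; it covers only the two forms shown in this tutorial.

```python
import logging


def to_torch_logs(settings):
    """Translate set_logs-style settings into a TORCH_LOGS string.

    Handles only the forms used in this tutorial:
    name=True -> "name", and name=logging.DEBUG -> "+name".
    """
    parts = []
    for name, value in settings.items():
        if value is True:
            parts.append(name)
        elif value == logging.DEBUG:
            parts.append("+" + name)
        else:
            raise ValueError(f"unsupported setting in this sketch: {name}={value!r}")
    return ",".join(parts)


settings = {"dynamo": logging.DEBUG, "graph": True, "output_code": True}
torch_logs_value = to_torch_logs(settings)
print(torch_logs_value)  # +dynamo,graph,output_code
```

The same dictionary can be passed to the Python API as `torch._logging.set_logs(**settings)`.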

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: ( 0 minutes 2.571 seconds)
