
(beta) Using TORCH_LOGS python API with torch.compile

Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

import logging

This tutorial introduces the TORCH_LOGS environment variable and its Python API counterpart, and demonstrates how to use them to observe the phases of torch.compile.

Note

This tutorial requires PyTorch 2.2.0 or later.
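If you want a script to fail early on an older install, you can compare the version string against this minimum. The helper below is a minimal, standard-library-only sketch; the name `version_at_least` is hypothetical. It compares only the leading numeric release segment, so local and pre-release suffixes (for example `+cu121` or `a0`) are ignored. That is a simplification of full PEP 440 ordering.

```python
import re


def version_at_least(version: str, minimum: tuple[int, ...]) -> bool:
    """Return True if the numeric release part of `version` is >= `minimum`.

    Only the leading "X.Y.Z" segment is compared: "2.2.0+cu121" and
    "2.2.0a0" are both treated as 2.2.0 (a simplification of PEP 440).
    """
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    release = tuple(int(part) for part in match.group(1).split("."))
    # Pad both tuples with zeros so that (2, 2) compares equal to (2, 2, 0).
    width = max(len(release), len(minimum))
    release += (0,) * (width - len(release))
    padded_min = tuple(minimum) + (0,) * (width - len(minimum))
    return release >= padded_min
```

In a script you would call it as `version_at_least(torch.__version__, (2, 2, 0))` after `import torch`.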

Setup

In this example, we’ll set up a simple Python function that performs an elementwise add and observe the compilation process with the TORCH_LOGS Python API.
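With DEBUG-level logging enabled, even this small function produces a long stream of records, as the output later in this tutorial shows. The following is a minimal sketch, using only the standard library, that splits such a captured log into fields so you can filter it, for example by artifact tag. The line layout it assumes (level letter, MMDD date, time, process id, source file and line, an optional bracketed compile id such as `[0/0]`, and an optional artifact tag such as `[__graph_code]`) is inferred from the sample output rather than taken from a documented format, so treat the regular expression as an assumption. `LogRecord`, `parse_line`, and `messages_for` are hypothetical names.

```python
import re
from dataclasses import dataclass
from typing import Iterable, Iterator, Optional

# Assumed record layout, inferred from sample output (not a documented format):
#   V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code] TRACED GRAPH
_RECORD = re.compile(
    r"^(?P<level>[A-Z])(?P<date>\d{4}) "
    r"(?P<time>\d{2}:\d{2}:\d{2}\.\d+) "
    r"(?P<pid>\d+) "
    r"(?P<path>\S+?):(?P<lineno>\d+)\]"
    r"(?: \[(?P<compile_id>(?!__)[^\]]+)\])?"  # optional compile id, e.g. [0/0]
    r"(?: \[(?P<artifact>__\w+)\])?"  # optional artifact tag, e.g. [__guards]
    r" ?(?P<message>.*)$"
)


@dataclass
class LogRecord:
    level: str  # single letter, e.g. "I" or "V" in the sample output
    path: str
    lineno: int
    compile_id: Optional[str]
    artifact: Optional[str]
    message: str


def parse_line(line: str) -> Optional[LogRecord]:
    """Split one log line into fields; return None if it doesn't match."""
    m = _RECORD.match(line.rstrip("\n"))
    if m is None:
        return None  # e.g. the "=====" separators printed by the script itself
    return LogRecord(
        level=m["level"],
        path=m["path"],
        lineno=int(m["lineno"]),
        compile_id=m["compile_id"],
        artifact=m["artifact"],
        message=m["message"],
    )


def messages_for(lines: Iterable[str], artifact: str) -> Iterator[str]:
    """Yield the message text of every record tagged with `artifact`."""
    for line in lines:
        record = parse_line(line)
        if record is not None and record.artifact == artifact:
            yield record.message
```

If the log has been saved to a text file (the file name is up to you), iterating the open file through `messages_for(f, "__output_code")` would recover just the generated-code lines.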

Note

The same settings can also be changed from the command line with the TORCH_LOGS environment variable. The equivalent TORCH_LOGS setting is shown in a comment above each example.
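As a rough guide to how a TORCH_LOGS string corresponds to `torch._logging.set_logs` keyword arguments, the sketch below translates one into a dict. Following the PyTorch logging documentation, a `+` prefix on a component lowers its level to DEBUG, a `-` prefix raises it to ERROR, and a bare component name uses INFO, while artifact names such as `graph` or `output_code` simply switch that artifact on. The function name `torch_logs_to_kwargs` and the three-entry `ASSUMED_COMPONENTS` set are hypothetical, and rejecting a `+`/`-` prefix on an artifact is this sketch's own choice; the real parser inside `torch._logging` consults its own registry of components and artifacts.

```python
import logging

# Assumption: a small, hand-picked subset of component names. The registry
# inside torch._logging is larger and is the actual source of truth.
ASSUMED_COMPONENTS = {"dynamo", "aot", "inductor"}

_PREFIX_LEVELS = {"+": logging.DEBUG, "-": logging.ERROR, "": logging.INFO}


def torch_logs_to_kwargs(spec: str) -> dict:
    """Translate a TORCH_LOGS-style string into set_logs-style kwargs.

    "+name" -> logging.DEBUG, "-name" -> logging.ERROR, and a bare
    component -> logging.INFO; any other bare name is treated as an
    artifact and mapped to True.
    """
    kwargs = {}
    for raw in spec.split(","):
        item = raw.strip()
        if not item:
            continue  # tolerate empty entries such as a trailing comma
        if item[0] in "+-":
            prefix, name = item[0], item[1:]
        else:
            prefix, name = "", item
        if name in ASSUMED_COMPONENTS:
            kwargs[name] = _PREFIX_LEVELS[prefix]
        elif prefix:
            raise ValueError(f"artifact {name!r} does not take a {prefix!r} prefix")
        else:
            kwargs[name] = True
    return kwargs
```

Under these assumptions, `torch._logging.set_logs(**torch_logs_to_kwargs("+dynamo,graph"))` should behave roughly like launching the script with `TORCH_LOGS="+dynamo,graph"`.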

import logging

import torch

# Skip cleanly if there is no CUDA device, or if its compute capability is
# below 7.0 (the Triton kernels Inductor generates for GPUs need 7.0 or newer).
# Checking is_available() first avoids an error from get_device_capability()
# on machines without CUDA.
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:

    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2

    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))

    # Print a separator and reset Dynamo between examples,
    # so that each example triggers a fresh compilation.
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()

    separator("Dynamo Tracing")
    # View Dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View the traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View the output code generated by Inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
I0718 17:42:37.951000 22423 torch/_dynamo/utils.py:1603] [0/0] ChromiumEventLogger initialized with id 212c50af-efd1-4db3-bcaf-5d9e0a993c6e
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     sys.exit(main())
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 313, in main
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_main(argv)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 195, in make_main
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_mode.run_make_mode(argv[1:])
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return make.run_generic_build(args[0])
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return build_main(args + opts)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 276, in build_main
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 262, in __init__
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self._init_builder()
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 335, in _init_builder
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self.events.emit('builder-inited')
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     results.append(listener.handler(self.app, *args))
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 743, in generate_gallery_rst
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     ) = generate_dir_rst(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 598, in generate_dir_rst
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     results = parallel(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 599, in <genexpr>
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 79, in wrapper
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     p.start()
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self._popen = self._Popen(self)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return _default_context.get_context().Process._Popen(process_obj)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return Popen(process_obj)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self._launch(process_obj)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     code = process_obj._bootstrap(parent_sentinel=child_r)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self.run()
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     self._target(*self._args, **self._kwargs)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 67, in call_fn
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     result = func(*args, **kwargs)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1346, in generate_file_rst
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks, time_elapsed = execute_script(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1164, in execute_script
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     execute_code_block(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1020, in execute_code_block
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 865, in _exec_and_get_memory
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     mem_max, _ = call_memory(
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1700, in _sg_call_memory_noop
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     return 0.0, func()
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 783, in __call__
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     exec(self.code, self.fake_main.__dict__)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]     fn(*inputs)
V0718 17:42:37.952000 22423 torch/_dynamo/convert_frame.py:1003] [0/0]
I0718 17:42:37.957000 22423 torch/_dynamo/symbolic_convert.py:3324] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0718 17:42:37.957000 22423 torch/fx/experimental/symbolic_shapes.py:3334] [0/0] create_env
V0718 17:42:37.960000 22423 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0718 17:42:37.960000 22423 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             z = x + y
V0718 17:42:37.962000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0718 17:42:37.962000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0718 17:42:37.962000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0718 17:42:37.963000 22423 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0718 17:42:37.965000 22423 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0718 17:42:37.966000 22423 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0718 17:42:37.967000 22423 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0718 17:42:37.969000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0718 17:42:37.970000 22423 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0718 17:42:37.970000 22423 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             return z + 2
V0718 17:42:37.970000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0718 17:42:37.970000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0718 17:42:37.971000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0718 17:42:37.972000 22423 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0718 17:42:37.972000 22423 torch/_dynamo/symbolic_convert.py:3681] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0718 17:42:37.972000 22423 torch/_dynamo/symbolic_convert.py:3685] [0/0] RETURN_VALUE triggered compile
V0718 17:42:37.973000 22423 torch/_dynamo/output_graph.py:1008] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code] TRACED GRAPH
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]     def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_x_ = L_x_
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_y_ = L_y_
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_;  l_x_ = l_y_ = None
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         add_1: "f32[2, 2][2, 1]cuda:0" = z + 2;  z = None
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         return (add_1,)
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0718 17:42:37.974000 22423 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
I0718 17:42:37.976000 22423 torch/_dynamo/output_graph.py:1515] [0/0] Step 2: calling compiler function inductor
I0718 17:42:39.329000 22423 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
I0718 17:42:39.334000 22423 torch/_dynamo/output_graph.py:1520] [0/0] Step 2: done compiler function inductor
I0718 17:42:39.336000 22423 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
V0718 17:42:39.336000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[0] 2 None
V0718 17:42:39.336000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[1] 2 None
V0718 17:42:39.337000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[0] 2 None
V0718 17:42:39.337000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[1] 1 None
V0718 17:42:39.337000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].storage_offset() 0 None
V0718 17:42:39.338000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[0] 2 None
V0718 17:42:39.338000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[1] 2 None
V0718 17:42:39.338000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[0] 2 None
V0718 17:42:39.338000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[1] 1 None
V0718 17:42:39.339000 22423 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].storage_offset() 0 None
V0718 17:42:39.339000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[0] == 2
V0718 17:42:39.340000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[1] == 2
V0718 17:42:39.340000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[0] == 2
V0718 17:42:39.340000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[1] == 1
V0718 17:42:39.340000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].storage_offset() == 0
V0718 17:42:39.341000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[0] == 2
V0718 17:42:39.341000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[1] == 2
V0718 17:42:39.341000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[0] == 2
V0718 17:42:39.342000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[1] == 1
V0718 17:42:39.342000 22423 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].storage_offset() == 0
V0718 17:42:39.342000 22423 torch/_dynamo/guards.py:2557] [0/0] [__guards] GUARDS:
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] TREE_GUARD_MANAGER:
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] +- RootGuardManager
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:520 in init_ambient_guards
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0718 17:42:39.343000 22423 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0718 17:42:39.344000 22423 torch/_dynamo/guards.py:2524] [0/0] [__guards] Guard eval latency = 1.15 us
I0718 17:42:39.345000 22423 torch/_dynamo/pgo.py:660] [0/0] put_code_state: no cache key, skipping
I0718 17:42:39.345000 22423 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0718 17:42:39.348000 22423 torch/_dynamo/convert_frame.py:1395] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
===================Traced Graph=========================
I0718 17:42:39.349000 22423 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0718 17:42:39.349000 22423 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
===================Output Code=========================
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] Output code:
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # AOT ID: ['0_inference']
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import torch
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import math
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import random
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import os
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import tempfile
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from math import inf, nan
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from cmath import nanj
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch import device, empty_strided
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] aten = torch.ops.aten
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] _quantized = torch.ops._quantized
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile = AsyncCompile()
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/ld/cld7tar7n7kytdxqq7n73fjc5nptwpbw7wqmdbp24zf62axk3q3a.py
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Source node to ATen node mapping:
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   add_1 => add_1
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   z => add
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Graph fragment:
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton_heuristics.pointwise(
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     size_hints={'x': 4},
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     filename=__file__,
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '1E2C16421D4C3DBA4AD92BFC4278A3CB24C43DEDA6EE7FF9E3FBB1DBB80802DB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     min_elem_per_thread=0
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] )
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton.jit
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xnumel = 4
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xmask = xindex < xnumel
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     x0 = xindex
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp3 = 2.0
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] ''', device_str='cuda')
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile.wait(globals())
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] del async_compile
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def call(args):
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     args.clear()
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         torch.cuda.set_device(0)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, stream=stream0)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg0_1
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg1_1
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return (buf0, )
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code] if __name__ == "__main__":
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0718 17:42:39.439000 22423 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0718 17:42:39.446000 22423 torch/_inductor/codecache.py:1094] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/nk/cnk55csixpane7aredk4kvfxz3fx2bb7zgzf4vpzqkzufdznzojb.py
============================================

Conclusion

In this tutorial, we introduced the TORCH_LOGS environment variable and the Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, set TORCH_LOGS to “help” and run any Python script that imports torch.
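One way to try the “help” setting without changing the logging state of your current session is to run a child Python process with TORCH_LOGS set only in that child's environment. The sketch below assumes a one-line script, `import torch`, is enough to stand in for “any Python script that imports torch”. Where the help text shows up (standard output, or standard error as part of an error message) may differ between PyTorch versions, so both streams are printed.

```python
import os
import subprocess
import sys

# Copy the current environment and set TORCH_LOGS only for the child process.
env = dict(os.environ, TORCH_LOGS="help")

# Any script that imports torch works; a one-liner is used here (assumption).
result = subprocess.run(
    [sys.executable, "-c", "import torch"],
    env=env,
    capture_output=True,
    text=True,
)

# Print both streams, since the help text may be reported on either one.
print(result.stdout)
print(result.stderr)
```

From a shell, the equivalent is to prefix the command with the variable, for example `TORCH_LOGS="help" python your_script.py`, where `your_script.py` is a placeholder for any script that imports torch.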

Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: (0 minutes 2.808 seconds)
