(beta) Using TORCH_LOGS Python API with torch.compile

Created On: Jan 24, 2024 | Last Updated: Jan 31, 2024 | Last Verified: Nov 05, 2024

Author: Michael Lazos

import logging

This tutorial introduces the TORCH_LOGS environment variable and the equivalent Python API, and demonstrates how to use them to observe the phases of torch.compile.

Note

This tutorial requires PyTorch 2.2.0 or later.

Setup

In this example, we’ll set up a simple Python function that performs an elementwise add and observe its compilation with the TORCH_LOGS Python API.

Note

There is also an environment variable, TORCH_LOGS, which can be used to change logging settings from the command line. The equivalent environment variable setting is shown in a comment above each example.
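The correspondence between the two interfaces can be sketched with a small helper. This is an illustration only: `to_torch_logs` is a hypothetical name, not part of PyTorch. It assumes the prefix convention in which `+name` requests DEBUG level for a component, `-name` requests ERROR, and a bare name requests INFO for a component or enables an artifact such as `graph`.

```python
import logging

# Prefixes assumed for component log levels in TORCH_LOGS; artifacts such as
# "graph" or "output_code" are listed by bare name.
_LEVEL_PREFIX = {logging.DEBUG: "+", logging.INFO: "", logging.ERROR: "-"}


def to_torch_logs(**settings):
    """Build a TORCH_LOGS-style string from set_logs-style keyword arguments.

    Hypothetical helper for illustration; it is not part of PyTorch and
    only handles the three levels listed in _LEVEL_PREFIX.
    """
    parts = []
    for name, value in settings.items():
        if value is True:
            # Boolean flags such as graph=True enable an artifact by name.
            parts.append(name)
        elif value in _LEVEL_PREFIX:
            # Integer log levels become a prefixed component name.
            parts.append(_LEVEL_PREFIX[value] + name)
        elif value not in (False, None):
            raise ValueError(f"unsupported setting {name}={value!r}")
    return ",".join(parts)


print(to_torch_logs(dynamo=logging.DEBUG))          # +dynamo
print(to_torch_logs(graph=True, output_code=True))  # graph,output_code
```

The resulting string is what would be placed in the environment before launching a script, for example `TORCH_LOGS="+dynamo" python script.py`, where `script.py` stands for any program that calls torch.compile.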

import torch

# exit cleanly if CUDA is unavailable or the device doesn't support torch.compile
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (7, 0):
    print("Skipping because torch.compile is not supported on this device.")
else:
    @torch.compile()
    def fn(x, y):
        z = x + y
        return z + 2


    inputs = (torch.ones(2, 2, device="cuda"), torch.zeros(2, 2, device="cuda"))


    # print separator and reset dynamo
    # between each example
    def separator(name):
        print(f"==================={name}=========================")
        torch._dynamo.reset()


    separator("Dynamo Tracing")
    # View dynamo tracing
    # TORCH_LOGS="+dynamo"
    torch._logging.set_logs(dynamo=logging.DEBUG)
    fn(*inputs)

    separator("Traced Graph")
    # View traced graph
    # TORCH_LOGS="graph"
    torch._logging.set_logs(graph=True)
    fn(*inputs)

    separator("Fusion Decisions")
    # View fusion decisions
    # TORCH_LOGS="fusion"
    torch._logging.set_logs(fusion=True)
    fn(*inputs)

    separator("Output Code")
    # View output code generated by inductor
    # TORCH_LOGS="output_code"
    torch._logging.set_logs(output_code=True)
    fn(*inputs)

    separator("")
===================Dynamo Tracing=========================
I0707 22:34:51.769000 23610 torch/_dynamo/utils.py:1603] [0/0] ChromiumEventLogger initialized with id 50896f60-e51e-458a-960a-7586257ea430
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0] torchdynamo start compiling fn /var/lib/workspace/recipes_source/torch_logs.py:39, stack (elided 5 frames):
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/bin/sphinx-build", line 8, in <module>
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     sys.exit(main())
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 313, in main
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_main(argv)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 195, in make_main
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return make_mode.run_make_mode(argv[1:])
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return make.run_generic_build(args[0])
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return build_main(args + opts)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 276, in build_main
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     app = Sphinx(args.sourcedir, args.confdir, args.outputdir,
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 262, in __init__
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self._init_builder()
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 335, in _init_builder
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self.events.emit('builder-inited')
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     results.append(listener.handler(self.app, *args))
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 491, in generate_gallery_rst
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     ) = generate_dir_rst(
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 431, in generate_dir_rst
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     intro, title, cost = generate_file_rst(
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 79, in wrapper
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     p.start()
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self._popen = self._Popen(self)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return _default_context.get_context().Process._Popen(process_obj)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return Popen(process_obj)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self._launch(process_obj)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     code = process_obj._bootstrap(parent_sentinel=child_r)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self.run()
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     self._target(*self._args, **self._kwargs)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/conf.py", line 67, in call_fn
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     result = func(*args, **kwargs)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1027, in generate_file_rst
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks, time_elapsed = execute_script(script_blocks,
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 945, in execute_script
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     output_blocks.append(execute_code_block(
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 810, in execute_code_block
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     is_last_expr, mem_max = _exec_and_get_memory(
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 676, in _exec_and_get_memory
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     mem_max, _ = gallery_conf['call_memory'](
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 223, in call_memory
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     return 0., func()
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 600, in __call__
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     exec(self.code, self.fake_main.__dict__)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]   File "/var/lib/workspace/recipes_source/torch_logs.py", line 59, in <module>
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]     fn(*inputs)
V0707 22:34:51.770000 23610 torch/_dynamo/convert_frame.py:1003] [0/0]
I0707 22:34:51.775000 23610 torch/_dynamo/symbolic_convert.py:3324] [0/0] Step 1: torchdynamo start tracing fn /var/lib/workspace/recipes_source/torch_logs.py:39
I0707 22:34:51.775000 23610 torch/fx/experimental/symbolic_shapes.py:3334] [0/0] create_env
V0707 22:34:51.778000 23610 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:41 in fn (fn)
V0707 22:34:51.778000 23610 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             z = x + y
V0707 22:34:51.780000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST x []
V0707 22:34:51.780000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST y [LazyVariableTracker()]
V0707 22:34:51.780000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [LazyVariableTracker(), LazyVariableTracker()]
V0707 22:34:51.782000 23610 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['x'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='x', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0707 22:34:51.783000 23610 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_x_ L['x'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0707 22:34:51.784000 23610 torch/_dynamo/variables/builder.py:3025] [0/0] wrap_to_fake L['y'] (2, 2) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None, None], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='y', is_input=True, dynamism=None, is_derefed_cell_contents=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0707 22:34:51.785000 23610 torch/_dynamo/output_graph.py:2271] [0/0] create_graph_input L_y_ L['y'] FakeTensor(..., device='cuda:0', size=(2, 2)) at debug_level 0 before=False
V0707 22:34:51.788000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE STORE_FAST z [TensorVariable()]
V0707 22:34:51.788000 23610 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source] TRACE starts_line /var/lib/workspace/recipes_source/torch_logs.py:42 in fn (fn)
V0707 22:34:51.788000 23610 torch/_dynamo/symbolic_convert.py:1216] [0/0] [__trace_source]             return z + 2
V0707 22:34:51.788000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_FAST z []
V0707 22:34:51.789000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE LOAD_CONST 2 [TensorVariable()]
V0707 22:34:51.789000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE BINARY_ADD None [TensorVariable(), ConstantVariable(int: 2)]
V0707 22:34:51.790000 23610 torch/_dynamo/symbolic_convert.py:1239] [0/0] [__trace_bytecode] TRACE RETURN_VALUE None [TensorVariable()]
I0707 22:34:51.791000 23610 torch/_dynamo/symbolic_convert.py:3681] [0/0] Step 1: torchdynamo done tracing fn (RETURN_VALUE)
V0707 22:34:51.791000 23610 torch/_dynamo/symbolic_convert.py:3685] [0/0] RETURN_VALUE triggered compile
V0707 22:34:51.791000 23610 torch/_dynamo/output_graph.py:1008] [0/0] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /var/lib/workspace/recipes_source/torch_logs.py, line 42 in fn>], graph_break=False)
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code] TRACED GRAPH
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  ===== __compiled_fn_1 =====
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]  /usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]     def forward(self, L_x_: "f32[2, 2][2, 1]cuda:0", L_y_: "f32[2, 2][2, 1]cuda:0"):
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_x_ = L_x_
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         l_y_ = L_y_
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:41 in fn, code: z = x + y
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         z: "f32[2, 2][2, 1]cuda:0" = l_x_ + l_y_;  l_x_ = l_y_ = None
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]          # File: /var/lib/workspace/recipes_source/torch_logs.py:42 in fn, code: return z + 2
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         add_1: "f32[2, 2][2, 1]cuda:0" = z + 2;  z = None
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]         return (add_1,)
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
V0707 22:34:51.793000 23610 torch/_dynamo/output_graph.py:1408] [0/0] [__graph_code]
I0707 22:34:51.795000 23610 torch/_dynamo/output_graph.py:1515] [0/0] Step 2: calling compiler function inductor
I0707 22:34:53.204000 23610 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
I0707 22:34:53.209000 23610 torch/_dynamo/output_graph.py:1520] [0/0] Step 2: done compiler function inductor
I0707 22:34:53.211000 23610 torch/fx/experimental/symbolic_shapes.py:4734] [0/0] produce_guards
V0707 22:34:53.211000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[0] 2 None
V0707 22:34:53.212000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].size()[1] 2 None
V0707 22:34:53.212000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[0] 2 None
V0707 22:34:53.212000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].stride()[1] 1 None
V0707 22:34:53.212000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['x'].storage_offset() 0 None
V0707 22:34:53.213000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[0] 2 None
V0707 22:34:53.213000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].size()[1] 2 None
V0707 22:34:53.213000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[0] 2 None
V0707 22:34:53.213000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].stride()[1] 1 None
V0707 22:34:53.214000 23610 torch/fx/experimental/symbolic_shapes.py:4954] [0/0] track_symint L['y'].storage_offset() 0 None
V0707 22:34:53.214000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[0] == 2
V0707 22:34:53.214000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].size()[1] == 2
V0707 22:34:53.215000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[0] == 2
V0707 22:34:53.215000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].stride()[1] == 1
V0707 22:34:53.215000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['x'].storage_offset() == 0
V0707 22:34:53.215000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[0] == 2
V0707 22:34:53.216000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].size()[1] == 2
V0707 22:34:53.216000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[0] == 2
V0707 22:34:53.216000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].stride()[1] == 1
V0707 22:34:53.216000 23610 torch/fx/experimental/symbolic_shapes.py:5156] [0/0] Skipping guard L['y'].storage_offset() == 0
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2557] [0/0] [__guards] GUARDS:
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] TREE_GUARD_MANAGER:
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] +- RootGuardManager
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- DEFAULT_DEVICE: utils_device.CURRENT_DEVICE == None                           # _dynamo/output_graph.py:520 in init_ambient_guards
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GLOBAL_STATE: ___check_global_state()
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- TORCH_FUNCTION_MODE_STACK: ___check_torch_function_mode_stack()
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['x'], accessed_by=FrameLocalsGuardAccessor(key='x', framelocals_idx=0)
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['x'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['x'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING: check_no_aliasing(L['x'], L['y'])
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | +- GuardManager: source=L['y'], accessed_by=FrameLocalsGuardAccessor(key='y', framelocals_idx=1)
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- TENSOR_MATCH: check_tensor(L['y'], Tensor, DispatchKeySet(CUDA, BackendSelect, ADInplaceOrView, AutogradCUDA), torch.float32, device=0, requires_grad=False, size=[2, 2], stride=[2, 1])  # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_HASATTR: hasattr(L['y'], '_dynamo_dynamic_indices') == False           # z = x + y  # ar/lib/workspace/recipes_source/torch_logs.py:41 in fn
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards] | | +- NO_TENSOR_ALIASING
V0707 22:34:53.217000 23610 torch/_dynamo/guards.py:2495] [0/0] [__guards]
V0707 22:34:53.219000 23610 torch/_dynamo/guards.py:2524] [0/0] [__guards] Guard eval latency = 1.08 us
I0707 22:34:53.219000 23610 torch/_dynamo/pgo.py:660] [0/0] put_code_state: no cache key, skipping
I0707 22:34:53.219000 23610 torch/_dynamo/convert_frame.py:1121] [0/0] run_gc_after_compile: running gc
V0707 22:34:53.222000 23610 torch/_dynamo/convert_frame.py:1395] skipping: _fn (reason: in skipfiles, file: /usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py)
===================Traced Graph=========================
I0707 22:34:53.223000 23610 torch/_dynamo/__init__.py:112] torch._dynamo.reset
I0707 22:34:53.223000 23610 torch/_dynamo/__init__.py:145] torch._dynamo.reset_code_caches
===================Fusion Decisions=========================
===================Output Code=========================
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] Output code:
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # AOT ID: ['0_inference']
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from ctypes import c_void_p, c_long, c_int
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import torch
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import math
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import random
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import os
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import tempfile
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from math import inf, nan
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from cmath import nanj
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.hooks import run_intermediate_hooks
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.utils import maybe_profile
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.memory_planning import _align as align
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch import device, empty_strided
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.async_compile import AsyncCompile
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.select_algorithm import extern_kernels
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.codegen.multi_kernel import MultiKernelCall
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._C import _cuda_getCurrentRawStream as get_raw_stream
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] aten = torch.ops.aten
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] inductor_ops = torch.ops.inductor
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] _quantized = torch.ops._quantized
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] assert_size_stride = torch._C._dynamo.guards.assert_size_stride
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] alloc_from_pool = torch.ops.inductor._alloc_from_pool
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile = AsyncCompile()
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # kernel path: /tmp/torchinductor_ci-user/ld/cld7tar7n7kytdxqq7n73fjc5nptwpbw7wqmdbp24zf62axk3q3a.py
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Source node to ATen node mapping:
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   add_1 => add_1
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   z => add
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] # Graph fragment:
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg1_1), kwargs = {})
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] #   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add, 2), kwargs = {})
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_poi_fused_add_0 = async_compile.triton('triton_poi_fused_add_0', '''
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] import triton.language as tl
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime import triton_helpers, triton_heuristics
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] triton_helpers.set_driver_to_gpu()
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton_heuristics.pointwise(
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     size_hints={'x': 4},
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     filename=__file__,
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='cuda', index=0, multi_processor_count=80, cc=86, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, warp_size=32), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}]},
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_0', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'num_load': 2, 'num_reduction': 0, 'backend_hash': '1E2C16421D4C3DBA4AD92BFC4278A3CB24C43DEDA6EE7FF9E3FBB1DBB80802DB', 'are_deterministic_algorithms_enabled': False, 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False},
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     min_elem_per_thread=0
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] )
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] @triton.jit
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def triton_poi_fused_add_0(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xnumel = 4
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xoffset = tl.program_id(0) * XBLOCK
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xindex = xoffset + tl.arange(0, XBLOCK)[:]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     xmask = xindex < xnumel
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     x0 = xindex
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp0 = tl.load(in_ptr0 + (x0), xmask)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp1 = tl.load(in_ptr1 + (x0), xmask)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp2 = tmp0 + tmp1
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp3 = 2.0
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tmp4 = tmp2 + tmp3
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     tl.store(out_ptr0 + (x0), tmp4, xmask)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] ''', device_str='cuda')
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] async_compile.wait(globals())
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] del async_compile
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def call(args):
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1, arg1_1 = args
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     args.clear()
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg0_1, (2, 2), (2, 1))
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     assert_size_stride(arg1_1, (2, 2), (2, 1))
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     with torch.cuda._DeviceGuard(0):
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         torch.cuda.set_device(0)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         buf0 = empty_strided_cuda((2, 2), (2, 1), torch.float32)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         # Topologically Sorted Source Nodes: [z, add_1], Original ATen: [aten.add]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         stream0 = get_raw_stream(0)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         triton_poi_fused_add_0.run(arg0_1, arg1_1, buf0, 4, stream=stream0)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg0_1
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]         del arg1_1
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return (buf0, )
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] def benchmark_compiled_module(times=10, repeat=10):
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._dynamo.testing import rand_strided
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.utils import print_performance
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg0_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     arg1_1 = rand_strided((2, 2), (2, 1), device='cuda:0', dtype=torch.float32)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     fn = lambda: call([arg0_1, arg1_1])
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     return print_performance(fn, times=times, repeat=repeat)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code] if __name__ == "__main__":
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     from torch._inductor.wrapper_benchmark import compiled_module_main
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]     compiled_module_main('None', benchmark_compiled_module)
V0707 22:34:53.315000 23610 torch/_inductor/codecache.py:1093] [0/0] [__output_code]
V0707 22:34:53.321000 23610 torch/_inductor/codecache.py:1094] [0/0] [__output_code] Output code written to: /tmp/torchinductor_ci-user/nk/cnk55csixpane7aredk4kvfxz3fx2bb7zgzf4vpzqkzufdznzojb.py
============================================

Conclusion

In this tutorial, we introduced the TORCH_LOGS environment variable and the Python API by experimenting with a small number of the available logging options. To view descriptions of all available options, run any Python script that imports torch with TORCH_LOGS set to "help".
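As a recap of the pairs shown in this tutorial, the sketch below builds a TORCH_LOGS string from the same keyword arguments that were passed to torch._logging.set_logs. The helper name to_torch_logs is hypothetical and not part of PyTorch. It covers only the two forms used above: logging.DEBUG maps to a "+" prefix, and True maps to the bare option name. It assumes that several options can be combined in one comma-separated TORCH_LOGS value.

```python
import logging


def to_torch_logs(**settings):
    """Hypothetical helper (not a PyTorch API): turn set_logs-style keyword
    arguments into the equivalent TORCH_LOGS environment variable value.

    Only the two forms used in this tutorial are handled:
      name=logging.DEBUG -> "+name"
      name=True          -> "name"
    """
    parts = []
    for name, value in settings.items():
        if value is True:
            # Artifact-style options such as graph=True use the bare name.
            parts.append(name)
        elif value == logging.DEBUG:
            # DEBUG verbosity is written with a leading "+".
            parts.append(f"+{name}")
        else:
            raise ValueError(f"unsupported setting for {name!r}: {value!r}")
    return ",".join(parts)


# The four settings used in this tutorial, combined into one value.
env_value = to_torch_logs(
    dynamo=logging.DEBUG, graph=True, fusion=True, output_code=True
)
print(f'TORCH_LOGS="{env_value}"')
# prints: TORCH_LOGS="+dynamo,graph,fusion,output_code"
```

The printed value can be placed in front of a command, for example TORCH_LOGS="+dynamo,graph,fusion,output_code" python script.py, where script.py stands for your own script.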

Alternatively, you can view the torch._logging documentation to see descriptions of all available logging options.

For more information on torch.compile, see the torch.compile tutorial.

Total running time of the script: ( 0 minutes 2.874 seconds)
