Shortcuts

Source code for torchtune.utils._profiler

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib

from typing import ContextManager, Optional

import torch
from torch.profiler import profile


def profiler(
    enabled: Optional[bool] = False,
    output_dir: Optional[str] = "./torchtune_perf_tracing.json",
) -> ContextManager:
    """
    Utility component that wraps around `torch.profiler` to profile model's operators.
    See https://pytorch.org/docs/stable/profiler.html for more details.
    The schedule for this profiler is wait 100 steps, warmup 5 steps, trace 5 steps
    (a single cycle). Note: Enabling pytorch profiler may have training speed reduction.

    CPU activity is always recorded. CUDA activity is recorded only when
    ``torch.cuda.is_available()`` is True, so enabling the profiler on a
    CPU-only machine does not request an unavailable activity.

    Args:
        enabled (Optional[bool]): Enable pytorch profiler. Default is False.
            Any falsy value (including ``None``) disables profiling.
        output_dir (Optional[str]): Path of the Chrome trace *file* to write
            (despite the name, this is a file path, not a directory). The
            file's parent directory is created if missing when the trace is
            exported. Default is "./torchtune_perf_tracing.json".

    Returns:
        ContextManager: a ``torch.profiler.profile`` instance when ``enabled``
        is truthy, otherwise a ``contextlib.nullcontext()``.
    """
    import os

    def trace_handler(prof) -> None:
        # Ensure the destination directory exists before exporting; otherwise a
        # missing directory would only surface as an error after the whole
        # wait/warmup/active schedule has already run.
        parent = os.path.dirname(output_dir) if output_dir else ""
        if parent:
            os.makedirs(parent, exist_ok=True)
        prof.export_chrome_trace(output_dir)

    if not enabled:
        return contextlib.nullcontext()

    # Only request CUDA profiling when a CUDA device is actually usable.
    activities = [torch.profiler.ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    return profile(
        activities=activities,
        schedule=torch.profiler.schedule(wait=100, warmup=5, active=5, repeat=1),
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=False,
        with_stack=False,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources