from collections import OrderedDict
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Timer
class BasicTimeProfiler:
    """
    BasicTimeProfiler can be used to profile the handlers,
    events, data loading and data processing times.
    Examples:
    .. code-block:: python
        from ignite.contrib.handlers import BasicTimeProfiler
        trainer = Engine(train_updater)
        # Create an object of the profiler and attach an engine to it
        profiler = BasicTimeProfiler()
        profiler.attach(trainer)
        @trainer.on(Events.EPOCH_COMPLETED)
        def log_intermediate_results():
            profiler.print_results(profiler.get_results())
        trainer.run(dataloader, max_epochs=3)
        profiler.write_results('path_to_dir/time_profiling.csv')
    """
    events_to_ignore = [
        Events.EXCEPTION_RAISED,
        Events.TERMINATE,
        Events.TERMINATE_SINGLE_EPOCH,
        Events.DATALOADER_STOP_ITERATION,
    ]
    def __init__(self):
        self._dataflow_timer = Timer()
        self._processing_timer = Timer()
        self._event_handlers_timer = Timer()
        self.dataflow_times = None
        self.processing_times = None
        self.event_handlers_times = None
        self._events = [
            Events.EPOCH_STARTED,
            Events.EPOCH_COMPLETED,
            Events.ITERATION_STARTED,
            Events.ITERATION_COMPLETED,
            Events.GET_BATCH_STARTED,
            Events.GET_BATCH_COMPLETED,
            Events.COMPLETED,
        ]
        self._fmethods = [
            self._as_first_epoch_started,
            self._as_first_epoch_completed,
            self._as_first_iter_started,
            self._as_first_iter_completed,
            self._as_first_get_batch_started,
            self._as_first_get_batch_completed,
            self._as_first_completed,
        ]
        self._lmethods = [
            self._as_last_epoch_started,
            self._as_last_epoch_completed,
            self._as_last_iter_started,
            self._as_last_iter_completed,
            self._as_last_get_batch_started,
            self._as_last_get_batch_completed,
            self._as_last_completed,
        ]
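    # ``_reset`` below allocates one slot per run for STARTED/COMPLETED, one
    # slot per epoch for the EPOCH_* events and one slot per iteration for the
    # ITERATION_* and GET_BATCH_* events, all initialised to zero.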
    def _reset(self, num_epochs, total_num_iters):
        self.dataflow_times = torch.zeros(total_num_iters)
        self.processing_times = torch.zeros(total_num_iters)
        self.event_handlers_times = {
            Events.STARTED: torch.zeros(1),
            Events.COMPLETED: torch.zeros(1),
            Events.EPOCH_STARTED: torch.zeros(num_epochs),
            Events.EPOCH_COMPLETED: torch.zeros(num_epochs),
            Events.ITERATION_STARTED: torch.zeros(total_num_iters),
            Events.ITERATION_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_COMPLETED: torch.zeros(total_num_iters),
            Events.GET_BATCH_STARTED: torch.zeros(total_num_iters),
        }
    def _as_first_started(self, engine):
        if hasattr(engine.state.dataloader, "__len__"):
            num_iters_per_epoch = len(engine.state.dataloader)
        else:
            num_iters_per_epoch = engine.state.epoch_length
        self.max_epochs = engine.state.max_epochs
        self.total_num_iters = self.max_epochs * num_iters_per_epoch
        self._reset(self.max_epochs, self.total_num_iters)
        self.event_handlers_names = {
            e: [
                h.__qualname__ if hasattr(h, "__qualname__") else h.__class__.__name__
                for (h, _, _) in engine._event_handlers[e]
                if "BasicTimeProfiler." not in repr(h)  # avoid adding internal handlers into output
            ]
            for e in Events
            if e not in self.events_to_ignore
        }
        # Setup all other handlers:
        engine._event_handlers[Events.STARTED].append((self._as_last_started, (engine,), {}))
        for e, m in zip(self._events, self._fmethods):
            engine._event_handlers[e].insert(0, (m, (engine,), {}))
        for e, m in zip(self._events, self._lmethods):
            engine._event_handlers[e].append((m, (engine,), {}))
        # Let's go
        self._event_handlers_timer.reset()
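    # Each ``_as_first_*`` handler below resets the event-handlers timer just
    # before the user's handlers for that event run; the matching
    # ``_as_last_*`` handler reads the elapsed time afterwards and stores it
    # per run, per epoch or per iteration.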
    def _as_last_started(self, engine):
        self.event_handlers_times[Events.STARTED][0] = self._event_handlers_timer.value()
    def _as_first_epoch_started(self, engine):
        self._event_handlers_timer.reset()
    def _as_last_epoch_started(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_STARTED][e] = t
    def _as_first_get_batch_started(self, engine):
        self._event_handlers_timer.reset()
        self._dataflow_timer.reset()
    def _as_last_get_batch_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_STARTED][i] = t
    def _as_first_get_batch_completed(self, engine):
        self._event_handlers_timer.reset()
    def _as_last_get_batch_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.GET_BATCH_COMPLETED][i] = t
        d = self._dataflow_timer.value()
        self.dataflow_times[i] = d
        self._dataflow_timer.reset()
    def _as_first_iter_started(self, engine):
        self._event_handlers_timer.reset()
    def _as_last_iter_started(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_STARTED][i] = t
        self._processing_timer.reset()
    def _as_first_iter_completed(self, engine):
        t = self._processing_timer.value()
        i = engine.state.iteration - 1
        self.processing_times[i] = t
        self._event_handlers_timer.reset()
    def _as_last_iter_completed(self, engine):
        t = self._event_handlers_timer.value()
        i = engine.state.iteration - 1
        self.event_handlers_times[Events.ITERATION_COMPLETED][i] = t
    def _as_first_epoch_completed(self, engine):
        self._event_handlers_timer.reset()
    def _as_last_epoch_completed(self, engine):
        t = self._event_handlers_timer.value()
        e = engine.state.epoch - 1
        self.event_handlers_times[Events.EPOCH_COMPLETED][e] = t
    def _as_first_completed(self, engine):
        self._event_handlers_timer.reset()
    def _as_last_completed(self, engine):
        self.event_handlers_times[Events.COMPLETED][0] = self._event_handlers_timer.value()
        # Remove added handlers:
        engine.remove_event_handler(self._as_last_started, Events.STARTED)
        for e, m in zip(self._events, self._fmethods):
            engine.remove_event_handler(m, e)
        for e, m in zip(self._events, self._lmethods):
            engine.remove_event_handler(m, e)
    def attach(self, engine):
        if not isinstance(engine, Engine):
            raise TypeError("Argument engine should be ignite.engine.Engine, " "but given {}".format(type(engine)))
        if not engine.has_event_handler(self._as_first_started):
            engine._event_handlers[Events.STARTED].insert(0, (self._as_first_started, (engine,), {}))
    @staticmethod
    def _compute_basic_stats(data):
        # compute on non-zero data:
        data = data[data > 0]
        out = [("total", torch.sum(data).item() if len(data) > 0 else "not yet triggered")]
        if len(data) > 1:
            out += [
                ("min/index", (torch.min(data).item(), torch.argmin(data).item())),
                ("max/index", (torch.max(data).item(), torch.argmax(data).item())),
                ("mean", torch.mean(data).item()),
                ("std", torch.std(data).item()),
            ]
        return OrderedDict(out)
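    # Rough illustration on a hypothetical tensor of three timings (values are
    # approximate; ``torch.std`` uses the unbiased estimator):
    #
    #   _compute_basic_stats(torch.tensor([0.1, 0.2, 0.3]))
    #   # -> OrderedDict([('total', 0.6), ('min/index', (0.1, 0)),
    #   #                 ('max/index', (0.3, 2)), ('mean', 0.2), ('std', 0.1)])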
    def get_results(self):
        """
        Method to fetch the aggregated profiler results after the engine is run
        .. code-block:: python
            results = profiler.get_results()
        """
        total_eh_time = sum([(self.event_handlers_times[e]).sum() for e in Events if e not in self.events_to_ignore])
        return OrderedDict(
            [
                ("processing_stats", self._compute_basic_stats(self.processing_times)),
                ("dataflow_stats", self._compute_basic_stats(self.dataflow_times)),
                (
                    "event_handlers_stats",
                    dict(
                        [
                            (str(e.name).replace(".", "_"), self._compute_basic_stats(self.event_handlers_times[e]))
                            for e in Events
                            if e not in self.events_to_ignore
                        ]
                        + [("total_time", total_eh_time)]
                    ),
                ),
                (
                    "event_handlers_names",
                    {str(e.name).replace(".", "_") + "_names": v for e, v in self.event_handlers_names.items()},
                ),
            ]
        )
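    # Rough usage sketch (keys mirror the dictionary built above):
    #
    #   results = profiler.get_results()
    #   results["processing_stats"]["total"]            # total processing time
    #   results["event_handlers_stats"]["total_time"]   # time spent in all handlers
    #   results["event_handlers_names"]["ITERATION_COMPLETED_names"]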
    def write_results(self, output_path):
        """
        Method to store the unaggregated profiling results to a csv file
        .. code-block:: python
            profiler.write_results('path_to_dir/awesome_filename.csv')
        Example output:
        .. code-block:: text
            -----------------------------------------------------------------
            epoch iteration processing_stats dataflow_stats Event_STARTED ...
            1.0     1.0        0.00003         0.252387        0.125676
            1.0     2.0        0.00029         0.252342        0.125123
        """
        try:
            import pandas as pd
        except ImportError:
            print("Need pandas to write results as files")
            return
        iters_per_epoch = self.total_num_iters // self.max_epochs
        epochs = torch.arange(self.max_epochs, dtype=torch.float32).repeat_interleave(iters_per_epoch) + 1
        iterations = torch.arange(self.total_num_iters, dtype=torch.float32) + 1
        processing_stats = self.processing_times
        dataflow_stats = self.dataflow_times
        event_started = self.event_handlers_times[Events.STARTED].repeat_interleave(self.total_num_iters)
        event_completed = self.event_handlers_times[Events.COMPLETED].repeat_interleave(self.total_num_iters)
        event_epoch_started = self.event_handlers_times[Events.EPOCH_STARTED].repeat_interleave(iters_per_epoch)
        event_epoch_completed = self.event_handlers_times[Events.EPOCH_COMPLETED].repeat_interleave(iters_per_epoch)
        event_iter_started = self.event_handlers_times[Events.ITERATION_STARTED]
        event_iter_completed = self.event_handlers_times[Events.ITERATION_COMPLETED]
        event_batch_started = self.event_handlers_times[Events.GET_BATCH_STARTED]
        event_batch_completed = self.event_handlers_times[Events.GET_BATCH_COMPLETED]
        results_dump = torch.stack(
            [
                epochs,
                iterations,
                processing_stats,
                dataflow_stats,
                event_started,
                event_completed,
                event_epoch_started,
                event_epoch_completed,
                event_iter_started,
                event_iter_completed,
                event_batch_started,
                event_batch_completed,
            ],
            dim=1,
        ).numpy()
        results_df = pd.DataFrame(
            data=results_dump,
            columns=[
                "epoch",
                "iteration",
                "processing_stats",
                "dataflow_stats",
                "Event_STARTED",
                "Event_COMPLETED",
                "Event_EPOCH_STARTED",
                "Event_EPOCH_COMPLETED",
                "Event_ITERATION_STARTED",
                "Event_ITERATION_COMPLETED",
                "Event_GET_BATCH_STARTED",
                "Event_GET_BATCH_COMPLETED",
            ],
        )
        results_df.to_csv(output_path, index=False)
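    # Note: the CSV has one row per iteration; run-level (STARTED/COMPLETED)
    # and epoch-level handler times are repeated across the rows they cover via
    # ``repeat_interleave`` so that every column has ``total_num_iters`` entries.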
    @staticmethod
    def print_results(results):
        """
        Method to print the aggregated results from the profiler
        .. code-block:: python
            profiler.print_results(results)
        Example output:
        .. code-block:: text
             ----------------------------------------------------
            | Time profiling stats (in seconds):                 |
             ----------------------------------------------------
            total  |  min/index  |  max/index  |  mean  |  std
            Processing function:
            157.46292 | 0.01452/1501 | 0.26905/0 | 0.07730 | 0.01258
            Dataflow:
            6.11384 | 0.00008/1935 | 0.28461/1551 | 0.00300 | 0.02693
            Event handlers:
            2.82721
            - Events.STARTED: []
            0.00000
            - Events.EPOCH_STARTED: []
            0.00006 | 0.00000/0 | 0.00000/17 | 0.00000 | 0.00000
            - Events.ITERATION_STARTED: ['PiecewiseLinear']
            0.03482 | 0.00001/188 | 0.00018/679 | 0.00002 | 0.00001
            - Events.ITERATION_COMPLETED: ['TerminateOnNan']
            0.20037 | 0.00006/866 | 0.00089/1943 | 0.00010 | 0.00003
            - Events.EPOCH_COMPLETED: ['empty_cuda_cache', 'training.<locals>.log_elapsed_time', ]
            2.57860 | 0.11529/0 | 0.14977/13 | 0.12893 | 0.00790
            - Events.COMPLETED: []
            not yet triggered
        """
        def to_str(v):
            if isinstance(v, str):
                return v
            elif isinstance(v, tuple):
                return "{:.5f}/{}".format(v[0], v[1])
            return "{:.5f}".format(v)
        def odict_to_str(d):
            out = " | ".join([to_str(v) for v in d.values()])
            return out
        others = {
            k: odict_to_str(v) if isinstance(v, OrderedDict) else v for k, v in results["event_handlers_stats"].items()
        }
        others.update(results["event_handlers_names"])
        output_message = """
 ----------------------------------------------------
| Time profiling stats (in seconds):                 |
 ----------------------------------------------------
total  |  min/index  |  max/index  |  mean  |  std
Processing function:
{processing_stats}
Dataflow:
{dataflow_stats}
Event handlers:
{total_time:.5f}
- Events.STARTED: {STARTED_names}
{STARTED}
- Events.EPOCH_STARTED: {EPOCH_STARTED_names}
{EPOCH_STARTED}
- Events.ITERATION_STARTED: {ITERATION_STARTED_names}
{ITERATION_STARTED}
- Events.ITERATION_COMPLETED: {ITERATION_COMPLETED_names}
{ITERATION_COMPLETED}
- Events.EPOCH_COMPLETED: {EPOCH_COMPLETED_names}
{EPOCH_COMPLETED}
- Events.COMPLETED: {COMPLETED_names}
{COMPLETED}
""".format(
            processing_stats=odict_to_str(results["processing_stats"]),
            dataflow_stats=odict_to_str(results["dataflow_stats"]),
            **others,
        )
        print(output_message)
        return output_message
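

# A minimal, self-contained usage sketch (not part of the original module): it
# assumes a toy update function and an in-memory "dataloader" (a plain list of
# tensors), attaches the profiler, runs two epochs and prints the aggregated
# results.
if __name__ == "__main__":

    def train_updater(engine, batch):
        # Stand-in for a real training step: just reduce the batch.
        return batch.sum().item()

    dataloader = [torch.rand(4) for _ in range(8)]

    trainer = Engine(train_updater)
    profiler = BasicTimeProfiler()
    profiler.attach(trainer)

    trainer.run(dataloader, max_epochs=2)
    BasicTimeProfiler.print_results(profiler.get_results())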