Source code for ignite.contrib.handlers.custom_events
import warnings
from ignite.engine import EventEnum, Events, State
class CustomPeriodicEvent:
    """DEPRECATED. Use filtered events instead.

    Handler to define custom periodic events as a number of elapsed iterations/epochs
    for an engine.

    When a custom periodic event is created and attached to an engine, the following events are fired:

    1) K iterations is specified:

    - `Events.ITERATIONS_<K>_STARTED`
    - `Events.ITERATIONS_<K>_COMPLETED`

    2) K epochs is specified:

    - `Events.EPOCHS_<K>_STARTED`
    - `Events.EPOCHS_<K>_COMPLETED`

    The `*_STARTED` event fires on iterations/epochs 1, K + 1, 2K + 1, ... and increments the counter
    ``engine.state.iterations_<K>`` (resp. ``engine.state.epochs_<K>``). The `*_COMPLETED` event fires on
    iterations/epochs K, 2K, 3K, ...

    Examples:

    .. code-block:: python

        from ignite.engine import Engine, Events
        from ignite.contrib.handlers import CustomPeriodicEvent

        # Let's define an event every 1000 iterations
        cpe1 = CustomPeriodicEvent(n_iterations=1000)
        cpe1.attach(trainer)

        # Let's define an event every 10 epochs
        cpe2 = CustomPeriodicEvent(n_epochs=10)
        cpe2.attach(trainer)

        @trainer.on(cpe1.Events.ITERATIONS_1000_COMPLETED)
        def on_every_1000_iterations(engine):
            # run a computation after 1000 iterations
            # ...
            print(engine.state.iterations_1000)

        @trainer.on(cpe2.Events.EPOCHS_10_STARTED)
        def on_every_10_epochs(engine):
            # run a computation every 10 epochs
            # ...
            print(engine.state.epochs_10)

    Args:
        n_iterations (int, optional): number of iterations of the custom periodic event.
        n_epochs (int, optional): number of epochs of the custom periodic event. Both arguments are optional,
            but exactly one of n_iterations or n_epochs should be defined.

    Raises:
        TypeError: if n_iterations or n_epochs is given but is not an integer.
        ValueError: if n_iterations or n_epochs is given but is lower than 1, or if neither or both
            of them are given.
    """

    def __init__(self, n_iterations=None, n_epochs=None):
        # stacklevel=2 attributes the warning to the code constructing the handler, not to this module.
        warnings.warn(
            "CustomPeriodicEvent is deprecated since 0.4.0 and will be removed in 0.5.0. Use filtered events instead.",
            DeprecationWarning,
            stacklevel=2,
        )

        if n_iterations is not None:
            if not isinstance(n_iterations, int):
                raise TypeError("Argument n_iterations should be an integer")
            if n_iterations < 1:
                raise ValueError("Argument n_iterations should be positive")

        if n_epochs is not None:
            if not isinstance(n_epochs, int):
                raise TypeError("Argument n_epochs should be an integer")
            if n_epochs < 1:
                raise ValueError("Argument n_epochs should be positive")

        # Exactly one of the two periods must be provided.
        if (n_iterations is None) == (n_epochs is None):
            raise ValueError("Either n_iterations or n_epochs should be defined")

        if n_iterations is not None:
            prefix = "iterations"
            # Name of the engine.state attribute holding the built-in counter we watch.
            self.state_attr = "iteration"
            self.period = n_iterations
        else:
            prefix = "epochs"
            self.state_attr = "epoch"
            self.period = n_epochs

        # Counter attribute set on engine.state, e.g. "iterations_1000" or "epochs_10".
        self.custom_state_attr = "{}_{}".format(prefix, self.period)
        event_name = "{}_{}".format(prefix.upper(), self.period)
        started_name = "{}_STARTED".format(event_name)
        completed_name = "{}_COMPLETED".format(event_name)

        # Per-instance enum holding the two custom events.
        self.Events = EventEnum("Events", " ".join([started_name, completed_name]))

        # Register the counter attribute for each custom event in State.event_to_attr.
        for e in self.Events:
            State.event_to_attr[e] = self.custom_state_attr

        # Aliases used by the handlers below.
        self._periodic_event_started = getattr(self.Events, started_name)
        self._periodic_event_completed = getattr(self.Events, completed_name)

    def _on_started(self, engine):
        """Reset the custom counter to 0 when the engine fires ``Events.STARTED``."""
        setattr(engine.state, self.custom_state_attr, 0)

    def _on_periodic_event_started(self, engine):
        """Fire the custom STARTED event on the first iteration/epoch of each period.

        This increments the custom counter before firing the event.
        """
        value = getattr(engine.state, self.state_attr)
        # Matches 1, K + 1, 2K + 1, ...; written as (value - 1) % K == 0 rather than value % K == 1
        # because the latter is never true for K == 1, which would suppress the event entirely.
        if (value - 1) % self.period == 0:
            setattr(engine.state, self.custom_state_attr, getattr(engine.state, self.custom_state_attr) + 1)
            engine.fire_event(self._periodic_event_started)

    def _on_periodic_event_completed(self, engine):
        """Fire the custom COMPLETED event on the last iteration/epoch of each period (K, 2K, ...)."""
        if getattr(engine.state, self.state_attr) % self.period == 0:
            engine.fire_event(self._periodic_event_completed)

    def attach(self, engine):
        """Register the custom events on ``engine`` and hook the handlers that fire them.

        Args:
            engine: the engine to attach to. It must provide ``register_events``,
                ``add_event_handler`` and ``fire_event``.
        """
        engine.register_events(*self.Events)
        engine.add_event_handler(Events.STARTED, self._on_started)
        # Builds ITERATION_STARTED / EPOCH_STARTED (and *_COMPLETED) from the watched attribute name.
        engine.add_event_handler(
            getattr(Events, "{}_STARTED".format(self.state_attr.upper())), self._on_periodic_event_started
        )
        engine.add_event_handler(
            getattr(Events, "{}_COMPLETED".format(self.state_attr.upper())), self._on_periodic_event_completed
        )