ignite.handlers
Complete list of handlers
- class ignite.handlers.Checkpoint(to_save, save_handler, filename_prefix='', score_function=None, score_name=None, n_saved=1, global_step_transform=None, archived=False, filename_pattern=None, include_self=False)
- Checkpoint handler can be used to periodically save and load objects which have the attributes state_dict and load_state_dict. This class can use specific save handlers to store checkpoints on disk, in cloud storage, etc. The Checkpoint handler (if used with DiskSaver) also automatically moves data from TPU to CPU before writing the checkpoint.
- Parameters
- to_save (Mapping) – Dictionary with the objects to save. Objects should implement the state_dict and load_state_dict methods. If it contains objects of type torch DistributedDataParallel or DataParallel, their internal wrapped model is automatically saved (to avoid the additional key prefix module. in the state dictionary).
- save_handler (callable or BaseSaveHandler) – Method or callable class used to save the engine and other provided objects. The function receives two objects: the checkpoint as a dictionary, and the filename. If save_handler is a callable class, it can inherit from BaseSaveHandler and optionally implement a remove method to keep a fixed number of saved checkpoints. To save the engine's checkpoint on disk, save_handler can be defined with DiskSaver.
- filename_prefix (str, optional) – Prefix for the file name to which objects will be saved. See Note for details.
- score_function (callable, optional) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.
- score_name (str, optional) – If score_function is not None, it is possible to store its value using score_name. See Notes for more details.
- n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
- global_step_transform (callable, optional) – Global step transform function to output a desired global step. The input of the function is (engine, event_name). The output of the function should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
- archived (bool, optional) – Deprecated argument, as models saved by torch.save are already compressed.
- filename_pattern (str, optional) – If filename_pattern is provided, this pattern will be used to render checkpoint filenames. If the pattern is not defined, the default pattern will be used. See Note for details.
- include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.
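How n_saved and score_function interact can be illustrated with a small stand-alone sketch. It is not ignite's implementation: RetentionSketch and its offer method are hypothetical names, and the rule that a new checkpoint is kept only when its priority exceeds the lowest one already stored is an assumption made for this illustration.

```python
from typing import List, Optional, Tuple


class RetentionSketch:
    """Hypothetical model of the n_saved policy; not part of ignite."""

    def __init__(self, n_saved: Optional[int] = 1) -> None:
        self.n_saved = n_saved
        # (priority, filename) pairs, kept sorted with the lowest priority first.
        self.saved: List[Tuple[float, str]] = []

    def offer(self, priority: float, filename: str) -> bool:
        """Try to store a checkpoint; return True if it was kept."""
        full = self.n_saved is not None and len(self.saved) >= self.n_saved
        if full and priority <= self.saved[0][0]:
            return False  # assumed rule: not better than the worst stored one
        if full:
            self.saved.pop(0)  # drop the lowest-priority checkpoint
        self.saved.append((priority, filename))
        self.saved.sort(key=lambda item: item[0])
        return True


# Without score_function the priority is the step, so the oldest files go first.
by_step = RetentionSketch(n_saved=2)
for step in (1000, 2000, 3000):
    by_step.offer(step, f"checkpoint_{step}.pt")
kept_by_step = [name for _, name in by_step.saved]

# With score_function the priority is the score, so the best scores stay.
by_score = RetentionSketch(n_saved=2)
for epoch, acc in enumerate((0.70, 0.78, 0.75, 0.77), start=1):
    by_score.offer(acc, f"best_model_{epoch}_val_acc={acc}.pt")
kept_by_score = [name for _, name in by_score.saved]
```

In this sketch, setting n_saved=None disables removal entirely, which mirrors the documented "all objects are kept" behaviour.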
 
Note: This class stores a single file as a dictionary of the provided objects to save. The filename is defined by filename_pattern and by default has the following structure: {filename_prefix}_{name}_{suffix}.{ext} where
- filename_prefix is the argument passed to the constructor,
- name is the key in to_save if a single object is to be stored, otherwise name is "checkpoint",
- suffix is composed as follows: {global_step}_{score_name}={score}.
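As a rough illustration of how these pieces combine, the default naming rules of this Note can be written as two small functions. render_suffix and render_filename are hypothetical helpers, not ignite functions; the sketch ignores how ignite formats the score number, and dropping an empty prefix is inferred from example filenames such as checkpoint_7000.pt in this documentation.

```python
from typing import Optional


def render_suffix(iteration: int, score: Optional[float] = None,
                  score_name: Optional[str] = None,
                  global_step: Optional[int] = None) -> str:
    """Hypothetical helper mirroring the suffix rules of this Note."""
    if score is None:
        # No score_function: use global_step if given, else the engine iteration.
        return str(iteration if global_step is None else global_step)
    score_part = f"{score}" if score_name is None else f"{score_name}={score}"
    return score_part if global_step is None else f"{global_step}_{score_part}"


def render_filename(prefix: str, name: str, suffix: str, ext: str = "pt") -> str:
    """Join the non-empty parts as {filename_prefix}_{name}_{suffix}.{ext}."""
    stem = "_".join(part for part in (prefix, name, suffix) if part)
    return f"{stem}.{ext}"


# No score and no global step transform: the engine iteration is the suffix.
default_name = render_filename("", "checkpoint", render_suffix(iteration=7000))

# Score, score name and global step all provided.
best_name = render_filename(
    "best", "model",
    render_suffix(iteration=0, score=0.77, score_name="val_acc", global_step=9),
)
```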
The exact suffix depends on which of score_function, score_name and global_step_transform are set (X means the argument is provided):

score_function  score_name  global_step_transform  suffix
None            None        None                   {engine.state.iteration}
X               None        None                   {score}
X               None        X                      {global_step}_{score}
X               X           X                      {global_step}_{score_name}={score}
None            None        X                      {global_step}
X               X           None                   {score_name}={score}

Above, global_step is defined by the output of global_step_transform and score is defined by the output of score_function.

By default, none of score_function, score_name and global_step_transform is defined, so the suffix is set from the attached engine's current iteration. The filename will be {filename_prefix}_{name}_{engine.state.iteration}.{ext}.

For example, with score_name="neg_val_loss" and a score_function that returns -loss (as objects with the highest scores will be retained), the saved filename will be {filename_prefix}_{name}_neg_val_loss=-0.1234.pt.

Note: If filename_pattern is given, it will be used to render the filenames. filename_pattern is a string that can contain {filename_prefix}, {name}, {score}, {score_name} and {global_step} as templates.

For example, with filename_pattern="{global_step}-{name}-{score}.pt" the saved filename will be 30000-checkpoint-94.pt.

Warning: Please keep in mind that if a filename collides with one already used to save a checkpoint, the new checkpoint will not be stored. This means that a filename like checkpoint.pt will be saved only once and will not be overwritten by newer checkpoints.

Note: To get the last stored filename, the handler exposes the attribute last_checkpoint:

    handler = Checkpoint(...)
    ...
    print(handler.last_checkpoint)
    > checkpoint_12345.pt

Note: This class is distributed configuration-friendly: it is not required to instantiate the class in the rank 0 process only.
This class automatically supports distributed configurations and, if used with DiskSaver, the checkpoint is stored by the rank 0 process.

Warning: When running on XLA devices, it should be run in all processes, otherwise the application can get stuck on saving the checkpoint.

    # Wrong:
    # if idist.get_rank() == 0:
    #     handler = Checkpoint(...)
    #     trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    # Correct:
    handler = Checkpoint(...)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

Examples:

Attach the handler to make checkpoints during training:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver

    trainer = ...
    model = ...
    optimizer = ...
    lr_scheduler = ...

    to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
    handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2)
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)
    trainer.run(data_loader, max_epochs=6)
    > ["checkpoint_7000.pt", "checkpoint_8000.pt", ]

Attach the handler to an evaluator to save the best model during training according to the computed validation metric:

    from ignite.engine import Engine, Events
    from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine

    trainer = ...
    evaluator = ...
    model = ...
    # Setup Accuracy metric computation on evaluator
    # Run evaluation on epoch completed event
    # ...

    def score_function(engine):
        return engine.state.metrics['accuracy']

    to_save = {'model': model}
    handler = Checkpoint(to_save, DiskSaver('/tmp/models', create_dir=True), n_saved=2,
                         filename_prefix='best', score_function=score_function, score_name="val_acc",
                         global_step_transform=global_step_from_engine(trainer))

    evaluator.add_event_handler(Events.COMPLETED, handler)

    trainer.run(data_loader, max_epochs=10)
    > ["best_model_9_val_acc=0.77.pt", "best_model_10_val_acc=0.78.pt", ]

- static load_objects(to_load, checkpoint, **kwargs)
- Helper method to apply load_state_dict on the objects from to_load using states from checkpoint.

Examples:

    import torch
    from ignite.engine import Engine, Events
    from ignite.handlers import ModelCheckpoint, Checkpoint

    trainer = Engine(lambda engine, batch: None)
    handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=None, create_dir=True)
    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    to_save = {"weights": model, "optimizer": optimizer}
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, to_save)
    trainer.run(torch.randn(10, 1), 5)

    to_load = to_save
    # 10 iterations per epoch: the checkpoint written after epoch 4 carries iteration 40
    checkpoint_fp = "/tmp/models/myprefix_checkpoint_40.pt"
    checkpoint = torch.load(checkpoint_fp)
    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

Note: If to_load contains objects of type torch DistributedDataParallel or DataParallel, the method load_state_dict will be applied to their internal wrapped model (obj.module).

- Parameters
- to_load (Mapping) – a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
- checkpoint (Mapping) – a dictionary with state_dicts to load, e.g. {"model": model_state_dict, "optimizer": opt_state_dict}. If to_load contains a single key, then checkpoint can directly contain the corresponding state_dict.
- **kwargs – Keyword arguments accepted by nn.Module.load_state_dict(). Passing strict=False enables the user to load part of a pretrained model (useful, for example, in transfer learning).
 
- Return type
- None 
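The loading behaviour described above, including the single-key shortcut, can be sketched without torch. This is a simplified illustration and not the library code: Counter and load_objects_sketch are hypothetical names, and the sketch leaves out the unwrapping of DataParallel models and the handling of **kwargs.

```python
from typing import Any, Dict, Mapping


class Counter:
    """Hypothetical object implementing the state_dict / load_state_dict protocol."""

    def __init__(self, value: int = 0) -> None:
        self.value = value

    def state_dict(self) -> Dict[str, Any]:
        return {"value": self.value}

    def load_state_dict(self, state: Mapping[str, Any]) -> None:
        self.value = state["value"]


def load_objects_sketch(to_load: Mapping[str, Any], checkpoint: Mapping[str, Any]) -> None:
    """Restore every object in to_load from the matching entry of checkpoint."""
    if len(to_load) == 1:
        key, obj = next(iter(to_load.items()))
        if key not in checkpoint:
            # Single object: the checkpoint may directly be its state_dict.
            obj.load_state_dict(checkpoint)
            return
    for key, obj in to_load.items():
        if key not in checkpoint:
            raise ValueError(f"Object labeled '{key}' is not in the checkpoint")
        obj.load_state_dict(checkpoint[key])


# Several objects: the checkpoint maps each key to a state_dict.
a, b = Counter(), Counter()
load_objects_sketch({"a": a, "b": b}, {"a": {"value": 3}, "b": {"value": 5}})

# One object: the checkpoint is the state_dict itself.
single = Counter()
load_objects_sketch({"model": single}, {"value": 7})
```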
 
 
- class ignite.handlers.checkpoint.BaseSaveHandler
- Base class for save handlers.

Methods to override: __call__() and remove().

Note: In a derived class, please make sure that in a distributed configuration the overridden methods are called by a single process. Distributed configuration on XLA devices should be treated slightly differently: to save a checkpoint with xm.save(), all processes should call the function. Otherwise, the application gets stuck.

- abstract __call__(checkpoint, filename, metadata=None)
- Method to save checkpoint with filename. Additionally, a metadata dictionary is provided.

Metadata contains:
- basename: file prefix (if provided) with checkpoint name, e.g. epoch_checkpoint.
- score_name: score name if provided, e.g. val_acc.
- priority: checkpoint priority value (higher is better), e.g. 12 or 0.6554435
 - Parameters
- checkpoint (Mapping) – checkpoint dictionary to save. 
- filename (str) – filename associated with checkpoint. 
- metadata (Mapping, optional) – metadata on checkpoint to save. 
 
- Return type
- None 
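To make the interface concrete, here is a minimal sketch of a save handler that keeps checkpoints in memory. InMemorySaver is a hypothetical name and the class is written without importing ignite so that it runs on its own; in real use it would presumably inherit from BaseSaveHandler. The remove(filename) signature is assumed from the description of save handlers above.

```python
from typing import Any, Dict, Mapping, Optional


class InMemorySaver:
    """Hypothetical save handler keeping checkpoints in a dict instead of on disk."""

    def __init__(self) -> None:
        self.storage: Dict[str, Mapping[str, Any]] = {}
        self.metadata: Dict[str, Mapping[str, Any]] = {}

    def __call__(self, checkpoint: Mapping[str, Any], filename: str,
                 metadata: Optional[Mapping[str, Any]] = None) -> None:
        # Store a shallow copy so that replacing keys in the caller's dict
        # does not alter the saved entry.
        self.storage[filename] = dict(checkpoint)
        self.metadata[filename] = dict(metadata or {})

    def remove(self, filename: str) -> None:
        # Called when an old checkpoint should no longer be kept.
        self.storage.pop(filename, None)
        self.metadata.pop(filename, None)


saver = InMemorySaver()
saver({"model": {"weight": [1.0, 2.0]}}, "epoch_checkpoint_1.pt",
      {"basename": "epoch_checkpoint", "score_name": None, "priority": 1})
saver({"model": {"weight": [1.5, 2.5]}}, "epoch_checkpoint_2.pt",
      {"basename": "epoch_checkpoint", "score_name": None, "priority": 2})
saver.remove("epoch_checkpoint_1.pt")
```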
 
 
- class ignite.handlers.DiskSaver(dirname, atomic=True, create_dir=True, require_empty=True, **kwargs)
- Handler that saves the input checkpoint on disk.
- Parameters
- dirname (str) – Directory path where the checkpoint will be saved.
- atomic (bool, optional) – If True, the checkpoint is serialized to a temporary file and then moved to its final destination, so that files are guaranteed not to be damaged (for example if an exception occurs during saving).
- create_dir (bool, optional) – If True, will create the directory dirname if it does not exist.
- require_empty (bool, optional) – If True, will raise an exception if there are any files in the directory dirname.
- **kwargs – Accepted keyword arguments for torch.save or xm.save. 
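The idea behind atomic=True (serialize to a temporary file, then move it into place) can be sketched with the standard library alone. atomic_save is a hypothetical helper, pickle stands in for torch.save, and nothing here is taken from DiskSaver's actual code.

```python
import os
import pickle
import tempfile
from typing import Any


def atomic_save(obj: Any, dirname: str, filename: str) -> str:
    """Hypothetical helper: write to a temporary file, then move it into place."""
    os.makedirs(dirname, exist_ok=True)
    final_path = os.path.join(dirname, filename)
    # Create the temporary file in the target directory so the final move
    # stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=dirname, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as tmp_file:
            pickle.dump(obj, tmp_file)  # stand-in for torch.save
        os.replace(tmp_path, final_path)
    except BaseException:
        # Saving failed: discard the partial file and leave final_path untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return final_path


with tempfile.TemporaryDirectory() as tmp_dir:
    path = atomic_save({"step": 1000}, tmp_dir, "checkpoint_1000.pt")
    with open(path, "rb") as saved_file:
        restored = pickle.load(saved_file)
    leftovers = [name for name in os.listdir(tmp_dir) if name.endswith(".tmp")]
```

Because the final name only appears once serialization has finished, a reader never sees a half-written checkpoint under that name.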
 
 
- class ignite.handlers.ModelCheckpoint(dirname, filename_prefix, save_interval=None, score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, save_as_state_dict=True, global_step_transform=None, archived=False, include_self=False, **kwargs)
- ModelCheckpoint handler can be used to periodically save objects to disk only. If checkpoints need to be stored in another storage type, please consider Checkpoint.

This handler expects two arguments:
- an Engine object
- a dict mapping names (str) to objects that should be saved to disk.

See Examples for further details.

Warning: The behaviour of this class has changed since v0.3.0.
- Argument save_as_state_dict is deprecated and should not be used. It is considered as True.
- Argument save_interval is deprecated and should not be used. Please use events filtering instead, e.g. ITERATION_STARTED(every=1000).
- There is no longer an internal counter indicating the number of save actions. Users could previously see its value, step_number, in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. step_number is now replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.
- A single pt file is created instead of multiple files.

- Parameters
- dirname (str) – Directory path where objects will be saved.
- filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.
- score_function (callable, optional) – If not None, it should be a function taking a single argument, an Engine object, and returning a score (float). Objects with the highest scores will be retained.
- score_name (str, optional) – If score_function is not None, it is possible to store its value using score_name. See Notes for more details.
- n_saved (int, optional) – Number of objects that should be kept on disk. Older files will be removed. If set to None, all objects are kept.
- atomic (bool, optional) – If True, objects are serialized to a temporary file and then moved to their final destination, so that files are guaranteed not to be damaged (for example if an exception occurs during saving).
- require_empty (bool, optional) – If True, will raise an exception if there are any files starting with filename_prefix in the directory dirname.
- create_dir (bool, optional) – If True, will create the directory dirname if it does not exist.
- global_step_transform (callable, optional) – Global step transform function to output a desired global step. The input of the function is (engine, event_name). The output of the function should be an integer. Default is None, in which case global_step is based on the attached engine. If provided, the function output is used as global_step. To set up the global step from another engine, please use global_step_from_engine().
- archived (bool, optional) – Deprecated argument, as models saved by torch.save are already compressed.
- include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, then there must not be another object in to_save with key checkpointer.
- **kwargs – Accepted keyword arguments for torch.save or xm.save in DiskSaver.
- save_as_state_dict (bool) – Deprecated and should not be used. It is considered as True.
 
Examples:

    >>> import os
    >>> from ignite.engine import Engine, Events
    >>> from ignite.handlers import ModelCheckpoint
    >>> from torch import nn
    >>> trainer = Engine(lambda engine, batch: None)
    >>> handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True)
    >>> model = nn.Linear(3, 3)
    >>> trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
    >>> trainer.run([0], max_epochs=6)
    >>> os.listdir('/tmp/models')
    ['myprefix_mymodel_4.pt', 'myprefix_mymodel_6.pt']
    >>> handler.last_checkpoint
    ['/tmp/models/myprefix_mymodel_6.pt']

- class ignite.handlers.EarlyStopping(patience, score_function, trainer, min_delta=0.0, cumulative_delta=False)
- EarlyStopping handler can be used to stop the training if there is no improvement after a given number of events.
- Parameters
- patience (int) – Number of events to wait without improvement before stopping the training.
- score_function (callable) – It should be a function taking a single argument, an Engine object, and returning a score (float). An improvement is considered to have occurred if the score is higher.
- trainer (Engine) – Trainer engine whose run is stopped if there is no improvement.
- min_delta (float, optional) – A minimum increase in the score to qualify as an improvement, i.e. an increase of less than or equal to min_delta will count as no improvement.
- cumulative_delta (bool, optional) – If True, min_delta defines an increase since the last patience reset; otherwise, it defines an increase after the last event. Default value is False.
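How patience, min_delta and cumulative_delta interact is easier to follow on a concrete score sequence. The class below is a stand-alone sketch written from the parameter descriptions above, not the handler itself; EarlyStoppingSketch and update are hypothetical names, and update returns a flag where the real handler would stop the trainer.

```python
from typing import Optional


class EarlyStoppingSketch:
    """Hypothetical stand-alone model of the patience logic; not part of ignite."""

    def __init__(self, patience: int, min_delta: float = 0.0,
                 cumulative_delta: bool = False) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.cumulative_delta = cumulative_delta
        self.best_score: Optional[float] = None
        self.counter = 0

    def update(self, score: float) -> bool:
        """Record one score; return True when training should stop."""
        if self.best_score is None:
            self.best_score = score
        elif score <= self.best_score + self.min_delta:
            # Not enough of an increase: one more event without improvement.
            if not self.cumulative_delta and score > self.best_score:
                self.best_score = score  # compare the next score against this one
            self.counter += 1
        else:
            self.best_score = score
            self.counter = 0  # improvement: reset patience
        return self.counter >= self.patience


# Two events in a row without improvement trigger the stop.
stopper = EarlyStoppingSketch(patience=2)
decisions = [stopper.update(s) for s in (0.50, 0.60, 0.60, 0.55)]

# Slow but steady gains: each step is below min_delta, yet the total gain is not.
slow = (0.50, 0.55, 0.58, 0.65)
per_event = EarlyStoppingSketch(patience=3, min_delta=0.1)
cumulative = EarlyStoppingSketch(patience=3, min_delta=0.1, cumulative_delta=True)
per_event_stop = [per_event.update(s) for s in slow][-1]
cumulative_stop = [cumulative.update(s) for s in slow][-1]
```

With these inputs the per-event variant stops on the last score, while the cumulative variant treats the total gain since the last reset as an improvement and continues.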
 
Examples:

    from ignite.engine import Engine, Events
    from ignite.handlers import EarlyStopping

    trainer = ...
    evaluator = ...

    def score_function(engine):
        val_loss = engine.state.metrics['nll']
        return -val_loss

    handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer)
    # Note: the handler is attached to an *Evaluator* (runs one epoch on validation dataset).
    evaluator.add_event_handler(Events.COMPLETED, handler)

- class ignite.handlers.Timer(average=False)
- Timer object can be used to measure (average) time between events.
- Parameters
- average (bool, optional) – If True, then when the .value() method is called, the returned value will be equal to the total time measured, divided by the value of the internal counter.
 - step_count
- Internal counter, useful to measure average time, e.g. of processing a single batch. Incremented with the .step() method.
- Type
 
Note: When using Timer(average=True), do not forget to call timer.step() every time an event occurs. See the examples below.

Examples:

Measuring total time of the epoch:

    >>> from ignite.handlers import Timer
    >>> import time
    >>> work = lambda : time.sleep(0.1)
    >>> idle = lambda : time.sleep(0.1)
    >>> t = Timer(average=False)
    >>> for _ in range(10):
    ...     work()
    ...     idle()
    ...
    >>> t.value()
    2.003073937026784

Measuring average time of the epoch:

    >>> t = Timer(average=True)
    >>> for _ in range(10):
    ...     work()
    ...     idle()
    ...     t.step()
    ...
    >>> t.value()
    0.2003182829997968

Measuring average time it takes to execute a single work() call:

    >>> t = Timer(average=True)
    >>> for _ in range(10):
    ...     t.resume()
    ...     work()
    ...     t.pause()
    ...     idle()
    ...     t.step()
    ...
    >>> t.value()
    0.10016545779653825

Using the Timer to measure average time it takes to process a single batch of examples:

    >>> from ignite.engine import Engine, Events
    >>> from ignite.handlers import Timer
    >>> trainer = Engine(training_update_function)
    >>> timer = Timer(average=True)
    >>> timer.attach(trainer,
    ...              start=Events.EPOCH_STARTED,
    ...              resume=Events.ITERATION_STARTED,
    ...              pause=Events.ITERATION_COMPLETED,
    ...              step=Events.ITERATION_COMPLETED)

- attach(engine, start=Events.STARTED, pause=Events.COMPLETED, resume=None, step=None)
- Register callbacks to control the timer.
- Parameters
- engine (Engine) – Engine that this timer will be attached to. 
- start (Events) – Event which should start (reset) the timer. 
- pause (Events) – Event which should pause the timer. 
- resume (Events, optional) – Event which should resume the timer. 
- step (Events, optional) – Event which should call the step method of the counter. 
 
- Returns
- self (Timer) 
 
 
- class ignite.handlers.TerminateOnNan(output_transform=<function TerminateOnNan.<lambda>>)
- TerminateOnNan handler can be used to stop the training if the process_function's output contains a NaN or infinite number or torch.tensor. The output can be of type: number, tensor or a collection of them. The training is stopped if at least one number or tensor has a NaN or infinite value. For example, if the output is [1.23, torch.tensor(...), torch.tensor(float('nan'))] the handler will stop the training.
- Parameters
- output_transform (callable, optional) – A callable that is used to transform the Engine's process_function's output into a number or torch.tensor or a collection of them. This can be useful if, for example, you have a multi-output model and you want to check one or multiple values of the output.

Examples:

    trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
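The check applied to the output can be approximated in plain Python. The sketch below only handles Python floats inside lists, tuples and dicts (no tensors); has_non_finite and select_loss are hypothetical names used for illustration, with select_loss playing the role of an output_transform.

```python
import math
from typing import Any


def has_non_finite(output: Any) -> bool:
    """Hypothetical check: True if any float nested in output is NaN or infinite."""
    if isinstance(output, float):
        return not math.isfinite(output)
    if isinstance(output, dict):
        return any(has_non_finite(value) for value in output.values())
    if isinstance(output, (list, tuple)):
        return any(has_non_finite(item) for item in output)
    return False  # ints and other types are treated as finite in this sketch


def select_loss(output: dict) -> float:
    """An output_transform-style callable selecting only the value to check."""
    return output["loss"]


clean = {"loss": 0.42, "preds": [0.1, 0.9]}
broken = {"loss": float("nan"), "preds": [0.1, 0.9]}

stop_on_clean = has_non_finite(select_loss(clean))
stop_on_broken = has_non_finite(select_loss(broken))
stop_on_list = has_non_finite([1.23, (2.0, float("inf"))])
```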