ModelCheckpoint
class ignite.handlers.checkpoint.ModelCheckpoint(dirname, filename_prefix='', score_function=None, score_name=None, n_saved=1, atomic=True, require_empty=True, create_dir=True, global_step_transform=None, filename_pattern=None, include_self=False, greater_or_equal=False, save_on_rank=0, **kwargs)
ModelCheckpoint handler, which inherits from Checkpoint, can be used to periodically save objects to disk only. To store checkpoints in another storage type, consider using Checkpoint instead. It also provides a last_checkpoint attribute that holds the path of the last saved checkpoint.

This handler expects two arguments:
- an Engine object
- a dict mapping names (str) to objects that should be saved to disk.
See Examples for further details.

Warning

Behaviour of this class has changed since v0.3.0:
- There is no longer an internal counter indicating the number of save actions. Users previously saw this counter as step_number in the filename, e.g. {filename_prefix}_{name}_{step_number}.pt. Now, step_number is replaced by the current engine's epoch if score_function is specified, and by the current iteration otherwise.
- A single .pt file is created instead of multiple files.

Parameters
- dirname (Union[str, Path]) – Directory path where objects will be saved. 
- filename_prefix (str) – Prefix for the file names to which objects will be saved. See Notes of Checkpoint for more details.
- score_function (Optional[Callable]) – If not None, it should be a function that takes a single argument, an Engine object, and returns a score (float). Objects with the highest scores are retained.
- score_name (Optional[str]) – If score_function is not None, its value can be stored using score_name. See Examples of Checkpoint for more details.
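- n_saved (Optional[int]) – Number of objects that should be kept on disk. When this number is exceeded, the file with the lowest priority is removed: the oldest one, or the lowest-scored one if score_function is set. If set to None, all objects are kept.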
- atomic (bool) – If True, objects are serialized to a temporary file and then moved to the final destination, so that files are guaranteed not to be damaged (for example, if an exception occurs during saving).
- require_empty (bool) – If True, raises an exception if there are any files starting with filename_prefix in the directory dirname.
- create_dir (bool) – If True, creates the directory dirname if it does not exist.
- global_step_transform (Optional[Callable]) – Global step transform function that outputs the desired global step. The function takes (engine, event_name) as input and should return an integer. By default (None), the global step is based on the attached engine; if a function is provided, its output is used as global_step. To set up the global step from another engine, use global_step_from_engine().
- filename_pattern (Optional[str]) – If filename_pattern is provided, this pattern is used to render checkpoint filenames. Otherwise, the default pattern is used. See Checkpoint for details.
- include_self (bool) – Whether to include the state_dict of this object in the checkpoint. If True, there must not be another object in to_save with the key checkpointer.
- greater_or_equal (bool) – If True, the latest equally scored model is stored. Otherwise, the first model with that score is kept. Default: False.
- save_on_rank (int) – Which rank to save the objects on in a distributed configuration. Used to instantiate a DiskSaver and is also passed to the parent class.
- kwargs (Any) – Accepted keyword arguments for torch.save or xm.save in DiskSaver. 
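The naming behaviour described above (filename_prefix, the object name, the global step and an optional score_name) can be illustrated with a small model. The sketch below is hypothetical and simplified, not ignite's code: the helper render_filename, the rule of dropping pieces that are absent, and the four-decimal score formatting are assumptions made for illustration. The Checkpoint documentation describes the real default pattern and how filename_pattern overrides it.

```python
from typing import Optional


def render_filename(
    filename_prefix: str,
    name: str,
    global_step: Optional[int] = None,
    score: Optional[float] = None,
    score_name: Optional[str] = None,
    ext: str = "pt",
) -> str:
    """Hypothetical, simplified model of a checkpoint filename.

    Assumed layout: {filename_prefix}_{name}_{global_step}_{score_name}={score}.{ext},
    where each optional piece is dropped when it is not available.
    """
    parts = [name]
    if filename_prefix:
        # Assumption: the prefix is only prepended when it is non-empty.
        parts.insert(0, filename_prefix)
    if global_step is not None:
        parts.append(str(global_step))
    if score is not None:
        # Assumption: scores are rendered with four decimal places.
        score_str = f"{score:.4f}"
        parts.append(f"{score_name}={score_str}" if score_name else score_str)
    return "_".join(parts) + "." + ext


# Unscored checkpoint: the global step is the iteration count at save time.
print(render_filename("myprefix", "mymodel", global_step=20))  # myprefix_mymodel_20.pt
# Scored checkpoint with a score_name (illustrative values).
print(render_filename("best", "mymodel", score=0.91234, score_name="accuracy"))  # best_mymodel_accuracy=0.9123.pt
```

The first call reproduces a file name from the Examples, where the model is saved at iteration 20.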
 
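How n_saved, score_function and greater_or_equal interact can be sketched as a sorted buffer of (priority, filename) pairs, where the priority is the score when score_function is set and the global step otherwise. The class KeepBest below is a hypothetical model written for illustration, not part of ignite: once n_saved files exist, a new checkpoint is kept only if its priority beats the lowest one (or ties it when greater_or_equal is True), and the lowest-priority file is then evicted.

```python
from typing import List, Optional, Tuple


class KeepBest:
    """Hypothetical model of checkpoint retention; not ignite's implementation."""

    def __init__(self, n_saved: Optional[int] = 1, greater_or_equal: bool = False) -> None:
        self.n_saved = n_saved  # None means "keep everything"
        self.greater_or_equal = greater_or_equal
        self.saved: List[Tuple[float, str]] = []  # ascending by priority

    def _accepts(self, priority: float) -> bool:
        if self.n_saved is None or len(self.saved) < self.n_saved:
            return True  # still room, always keep
        lowest = self.saved[0][0]
        # On a tie, the newcomer only wins when greater_or_equal is True.
        return priority >= lowest if self.greater_or_equal else priority > lowest

    def offer(self, priority: float, filename: str) -> Optional[str]:
        """Try to keep a new checkpoint; return the evicted filename, if any."""
        if not self._accepts(priority):
            return None  # the new checkpoint is simply not kept
        self.saved.append((priority, filename))
        # list.sort is stable, so an equal-priority newcomer sorts after the older file.
        self.saved.sort(key=lambda item: item[0])
        if self.n_saved is not None and len(self.saved) > self.n_saved:
            return self.saved.pop(0)[1]  # evict the lowest priority (oldest on ties)
        return None

    def filenames(self) -> List[str]:
        return [name for _, name in self.saved]


# No score_function: priority is the global step, so the newest files survive.
by_step = KeepBest(n_saved=2)
for step in (10, 20, 30):
    by_step.offer(step, f"myprefix_mymodel_{step}.pt")
print(by_step.filenames())  # ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']

# Equal scores: greater_or_equal decides whether the first or the latest file is kept.
for flag in (False, True):
    best = KeepBest(n_saved=1, greater_or_equal=flag)
    best.offer(0.9, "epoch1.pt")
    best.offer(0.9, "epoch2.pt")
    print(flag, best.filenames())  # False ['epoch1.pt'], then True ['epoch2.pt']
```

Relying on the stability of list.sort is what makes an equally scored newcomer evict the older file rather than itself; a real implementation could instead break ties explicitly with a save counter.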
Changed in version 0.4.2: Accept kwargs for torch.save or xm.save.
Changed in version 0.4.9: Accept filename_pattern and greater_or_equal for parity with Checkpoint.
Changed in version 0.4.10: Added the save_on_rank arg to save objects on this rank in a distributed configuration.

Examples

    import os

    from ignite.engine import Engine, Events
    from ignite.handlers import ModelCheckpoint
    from torch import nn

    # A trainer whose process function does nothing; only the events matter here.
    trainer = Engine(lambda engine, batch: None)
    # Keep the 2 most recent checkpoints and allow an already non-empty /tmp/models.
    handler = ModelCheckpoint('/tmp/models', 'myprefix', n_saved=2, create_dir=True,
                              require_empty=False)
    model = nn.Linear(3, 3)
    # Save every 2 epochs; with 5 iterations per epoch, saves happen at iterations 10, 20 and 30.
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=2), handler, {'mymodel': model})
    trainer.run([0, 1, 2, 3, 4], max_epochs=6)
    print(sorted(os.listdir('/tmp/models')))
    print(handler.last_checkpoint)

Output:

    ['myprefix_mymodel_20.pt', 'myprefix_mymodel_30.pt']
    /tmp/models/myprefix_mymodel_30.pt

Methods