pytorch_pfn_extras.training.extensions.snapshot

pytorch_pfn_extras.training.extensions.snapshot(savefun=None, filename='snapshot_iter_{.iteration}', *, target=None, condition=None, writer=None, snapshot_on_error=False, n_retains=-1, autoload=False, saver_rank=None, snapshot_mode=SnapshotMode.DEFAULT)

Returns a trainer extension to take snapshots of the trainer.

This extension serializes the manager object and saves it to the output directory. It is used to support resuming the training loop from the saved state.

This extension is called once per epoch by default. To take a snapshot at a different interval, a trigger object specifying the required interval can be passed along with this extension to the extend() method of the manager.

The default priority is -100, which is lower than that of most built-in extensions.

Parameters:
  • savefun (Optional[Any]) – Function to save the manager. It takes two arguments: the output file path and the manager object. It is torch.save() by default. If writer is specified, this argument must be None.

  • filename (str) – Name of the file into which the manager is serialized. It can be a format string, where the manager object is passed to the str.format() method.

  • target (Optional[Any]) – Object to serialize. If it is not specified, it will be the manager object.

  • condition (Optional[Any]) – Condition object. It must be a callable that takes no arguments and returns a boolean. If it returns True, the snapshot is taken; otherwise it is skipped. The default is a function that always returns True.

  • writer (Optional[Writer]) – Writer object. It must be a callable object. See below for the list of built-in writers. If savefun is not None, this argument must be None; in that case, a SimpleWriter object instantiated with the specified savefun is used.

  • snapshot_on_error (bool) – Whether to take a snapshot when the training loop fails.

  • n_retains (int) – Number of snapshot files to retain through the cleanup. Must be a positive integer for any cleanup to take place. Automatic deletion of old snapshots only works when the filename is a string.

  • autoload (bool) – With this enabled, the extension automatically finds the latest snapshot and loads the data into the target. Automatic loading only works when the filename is a string. It is assumed that snapshots are generated by torch.save().

  • saver_rank (int) – If defined, the snapshot will be taken by only one rank when running in distributed mode and restored by all.

  • snapshot_mode (SnapshotMode) –

    If SnapshotMode.DEFAULT is specified, it provides a snapshot feature that operates in single-process mode. However, if saver_rank is specified, it provides a snapshot feature that operates in a distributed execution environment.

    If SnapshotMode.DISTRIBUTED is specified, it provides a snapshot feature that operates in a distributed execution environment. saver_rank must also be specified. In this mode, only the specified saver_rank creates a snapshot.

    If SnapshotMode.SHARDED is specified, it provides a snapshot feature that operates in a distributed execution environment. saver_rank must also be specified. In this mode, all ranks create a snapshot. It creates an appropriate snapshot when the state_dict holds a sharded value (e.g. FullyShardedDataParallel).

Returns:

Snapshot extension object.

Return type:

_Snapshot
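
Two of the parameters above, filename and condition, accept plain Python values, so their behavior can be tried without a training loop. The following is a minimal sketch in pure Python: FakeManager and EveryNthCall are hypothetical names introduced only for illustration and are not part of pytorch_pfn_extras. The first part shows how a format string such as the default 'snapshot_iter_{.iteration}' resolves when an object is passed to str.format(); the second shows one possible shape for a condition callable.

```python
class FakeManager:
    """Hypothetical stand-in exposing only the attribute the filename reads."""

    def __init__(self, iteration):
        self.iteration = iteration


# '{.iteration}' is an auto-numbered field with attribute access, so
# str.format() reads `iteration` from its first positional argument.
filename = 'snapshot_iter_{.iteration}'
formatted = filename.format(FakeManager(iteration=1200))
print(formatted)  # snapshot_iter_1200


class EveryNthCall:
    """Hypothetical condition: a no-argument callable returning a bool.

    It returns True on every n-th call and False otherwise.
    """

    def __init__(self, n):
        self.n = n
        self.calls = 0

    def __call__(self):
        self.calls += 1
        return self.calls % self.n == 0


condition = EveryNthCall(3)
decisions = [condition() for _ in range(6)]
print(decisions)  # [False, False, True, False, False, True]
```

An object like condition above would be passed as extensions.snapshot(condition=condition); the snapshot is then skipped whenever the call returns False.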

Using asynchronous writers

By specifying the writer argument, writing operations can be made asynchronous, hiding the I/O overhead of snapshots.

>>> from pytorch_pfn_extras.training import extensions
>>> from pytorch_pfn_extras import writing
>>> writer = writing.ProcessWriter()
>>> manager.extend(extensions.snapshot(writer=writer), trigger=(1, 'epoch'))

To change the format, you can pass a saving function as the savefun argument of the writer.

>>> import torch
>>> from pytorch_pfn_extras.training import extensions
>>> from pytorch_pfn_extras import writing
>>> writer = writing.ProcessWriter(
...     savefun=torch.save)
>>> manager.extend(extensions.snapshot(writer=writer), trigger=(1, 'epoch'))
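
As a rough illustration of what a different saving function might look like, the sketch below defines a pickle-based one. It assumes the writer calls savefun the same way torch.save() is called, that is, with the object to serialize first and a writable destination second; that calling convention is an assumption here and should be checked against the writer in use. The name save_with_pickle is hypothetical, and the function is exercised against an in-memory buffer rather than a real snapshot file.

```python
import io
import pickle


def save_with_pickle(target, file_obj):
    """Hypothetical savefun.

    Assumes the same (object, destination) argument order as torch.save().
    """
    pickle.dump(target, file_obj, protocol=pickle.HIGHEST_PROTOCOL)


# Exercise the function with an in-memory buffer instead of a real file.
state = {'iteration': 1200, 'lr': 0.01}
buffer = io.BytesIO()
save_with_pickle(state, buffer)

# Read the bytes back to confirm the round trip preserves the state.
buffer.seek(0)
restored = pickle.load(buffer)
print(restored == state)  # True
```

Under that assumption, the function could be passed as writing.ProcessWriter(savefun=save_with_pickle). Since autoload assumes that snapshots are generated by torch.save(), snapshots written in another format may not be loadable automatically.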

This is the list of built-in snapshot writers.