Distributed Snapshot

To take snapshots when using torch.distributed, the only additional step is to provide the saver_rank keyword argument to the regular snapshot extension.

from pytorch_pfn_extras.training import extensions

# saver_rank is the MPI rank of the process that will write the actual snapshot.
snapshot = extensions.snapshot(saver_rank=saver_rank)
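
For context, a minimal sketch of how this might be wired into a torch.distributed run follows; the placeholder model and optimizer, the trigger interval, and the choice of rank 0 as the saver are illustrative assumptions, not part of the original example.

import torch
import torch.distributed as dist
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

dist.init_process_group('gloo')  # backend choice depends on the setup

model = torch.nn.Linear(10, 10)                           # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)   # placeholder optimizer

manager = ppe.training.ExtensionsManager(
    model, optimizer, max_epochs=10, iters_per_epoch=100)

# Every rank registers the extension, but only the process whose rank
# equals saver_rank writes the snapshot file to disk.
manager.extend(
    extensions.snapshot(saver_rank=0),
    trigger=(1, 'epoch'))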

To resume training, the snapshot is loaded in every worker, either by calling the ExtensionsManager.load_state_dict method or by passing the autoload keyword argument to extensions.snapshot.
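
Continuing the sketch above, both resume options might look as follows; the snapshot path and output directory are illustrative and assume the default filename and out_dir settings.

# Option 1: every worker loads the same snapshot state explicitly.
state = torch.load('result/snapshot_iter_100')
manager.load_state_dict(state)

# Option 2: autoload restores the latest snapshot found in the output
# directory on every rank when training starts.
manager.extend(
    extensions.snapshot(saver_rank=0, autoload=True),
    trigger=(1, 'epoch'))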