pytorch_pfn_extras.training.ExtensionsManager

class pytorch_pfn_extras.training.ExtensionsManager(models: Union[torch.nn.modules.module.Module, Dict[str, torch.nn.modules.module.Module]], optimizers: Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]], max_epochs: int, *, iters_per_epoch: Optional[int], extensions: Optional[List[extension_module.ExtensionLike]] = None, out_dir: str = 'result', stop_trigger: trigger_module.TriggerLike = None, writer: Optional[pytorch_pfn_extras.writing.Writer] = None)

Manages the registered extensions and the current training status (iteration, epoch, and elapsed time).

Parameters
  • models (dict or torch.nn.Module) – A dict mapping names to modules, or a single module.

  • optimizers (dict or torch.optim.Optimizer) – A dict mapping names to optimizers, or a single optimizer.

  • max_epochs (int) – Number of epochs in the whole training loop. Ignored if stop_trigger is passed as a kwarg.

  • iters_per_epoch (int) – Number of iterations in one epoch.

  • extensions (list or None) – List of extensions to be used.

  • out_dir (str) – Output directory (default: 'result').

  • stop_trigger (trigger object, optional) – Trigger that determines whether training has concluded. The default is an interval trigger that fires after max_epochs epochs.

  • writer (writing.Writer object, optional) – Writer that extensions can use to write data to custom filesystems.

__init__(models: Union[torch.nn.modules.module.Module, Dict[str, torch.nn.modules.module.Module]], optimizers: Union[torch.optim.optimizer.Optimizer, Dict[str, torch.optim.optimizer.Optimizer]], max_epochs: int, *, iters_per_epoch: Optional[int], extensions: Optional[List[extension_module.ExtensionLike]] = None, out_dir: str = 'result', stop_trigger: trigger_module.TriggerLike = None, writer: Optional[pytorch_pfn_extras.writing.Writer] = None) → None

Methods

__init__(models, optimizers, max_epochs, *, …)

extend(extension[, name, trigger, priority, …])

Registers an extension to the manager.

get_extension(name)

Returns the extension of a given name.

load_state_dict(to_load, *[, transform_models])

transform_models is a function that applies a transformation to a model before its state is loaded.

run_extensions()

run_iteration(*[, step_optimizers])

Context manager to run an iteration.

start_extensions()

state_dict(*[, transform_models])

transform_models is a function that applies a transformation to a model.

Attributes

elapsed_time

epoch

epoch_detail

is_before_training

iteration

models

optimizers

out

stop_trigger

updater