pytorch_pfn_extras.training.ExtensionsManager

class pytorch_pfn_extras.training.ExtensionsManager(models, optimizers, max_epochs, *, iters_per_epoch, extensions=None, out_dir='result', stop_trigger=None, writer=None, transform_model=<function ExtensionsManager.<lambda>>, enable_profile=False, enable_trace=False, state_objects={})

Bases: _BaseExtensionsManager

Manages the extensions and the current status.

Parameters:
  • models (dict or torch.nn.Module) – Map of string to Module or an actual Module.

  • optimizers (dict or torch.optim.Optimizer) – Map of string to Optimizer or an actual Optimizer.

  • max_epochs (int) – Number of epochs in the whole training loop. Ignored if stop_trigger is passed as a kwarg.

  • iters_per_epoch (int) – Number of iterations in one epoch.

  • extensions (list or None) – List of Extensions to be used.

  • out_dir (str) – Output directory (default: result).

  • stop_trigger (trigger object, optional) – Trigger to determine whether training has concluded. The default is an interval trigger set to max_epochs.

  • writer (writing.Writer object) – Writer that can be used by extensions to write data to custom filesystems.

  • enable_profile (bool) – Flag to enable/disable profiling of iterations. Default is False.

  • enable_trace (bool) – Flag to enable/disable tracing of iterations. Default is False.

  • transform_model (Callable[[str, Module], Module]) –

  • state_objects (Dict[str, StateObjectProtocol]) –
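
Example (a minimal, hedged sketch: the model, optimizer, iteration count, and the registered extensions below are placeholders, not requirements):

    import torch
    import pytorch_pfn_extras as ppe
    from pytorch_pfn_extras.training import extensions

    model = torch.nn.Linear(10, 1)                   # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # A single Module/Optimizer can be passed directly; dicts mapping
    # names to Modules/Optimizers are accepted as well.
    manager = ppe.training.ExtensionsManager(
        model,
        optimizer,
        max_epochs=10,
        iters_per_epoch=100,                         # e.g. len(train_loader)
        out_dir='result',
    )

    # Extensions can also be registered after construction with extend().
    manager.extend(extensions.LogReport())
    manager.extend(extensions.ProgressBar())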

Methods

__init__(models, optimizers, max_epochs, *, ...)

extend(extension[, name, trigger, priority, ...])

Registers an extension to the manager.

finalize()

get_extension(name)

Returns the extension of a given name.

load_state_dict(to_load)

needs_model_state([iteration])

needs_state_this_iteration()

run_extensions()

run_iteration(*[, step_optimizers])

Context manager to run an iteration.

start_extensions()

state_dict()
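
state_dict() and load_state_dict() allow manual checkpointing of the manager. A hedged sketch, reusing the manager from the construction example above (the file name is illustrative; the snapshot extension is the usual way to automate this):

    import torch

    # Serialize the manager state; model, optimizer, and extension state as
    # well as the iteration/epoch counters are expected to be included.
    torch.save(manager.state_dict(), 'manager_checkpoint.pt')

    # Restore into a freshly constructed, identically configured manager.
    manager.load_state_dict(torch.load('manager_checkpoint.pt'))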

Attributes

elapsed_time

epoch

epoch_detail

is_before_training

iteration

models

optimizers

out

raw_models

stop_trigger

updater

__init__(models, optimizers, max_epochs, *, iters_per_epoch, extensions=None, out_dir='result', stop_trigger=None, writer=None, transform_model=<function ExtensionsManager.<lambda>>, enable_profile=False, enable_trace=False, state_objects={})
Parameters:
  • models (Union[Module, Dict[str, Module]]) –

  • optimizers (Union[Optimizer, Dict[str, Optimizer]]) –

  • max_epochs (int) –

  • iters_per_epoch (int) –

  • extensions (Optional[Sequence[extension_module.ExtensionLike]]) –

  • out_dir (str) –

  • stop_trigger (trigger_module.TriggerLike) –

  • writer (Optional[Writer]) –

  • transform_model (Callable[[str, Module], Module]) –

  • enable_profile (bool) –

  • enable_trace (bool) –

  • state_objects (Dict[str, StateObjectProtocol]) –

Return type:

None

finalize()
Return type:

None

run_iteration(*, step_optimizers=None)

Context manager to run an iteration.

This context manager can additionally run a step of the optimizers with the specified names.

Parameters:
  • step_optimizers (list or None) – Names of the optimizers on which to call zero_grad and step.

Return type:

Generator[None, None, None]
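
A hedged sketch of a training loop built around this context manager (train_loader and the loss are placeholders; the name 'main' assumes a single optimizer was passed at construction and registered under that default name):

    import torch.nn.functional as F
    import pytorch_pfn_extras as ppe

    while not manager.stop_trigger:
        for x, t in train_loader:                    # placeholder DataLoader
            # run_iteration calls zero_grad and step on the optimizers named
            # in step_optimizers; only the forward/backward pass goes inside.
            with manager.run_iteration(step_optimizers=['main']):
                y = model(x)
                loss = F.mse_loss(y, t)
                ppe.reporting.report({'train/loss': loss.item()})
                loss.backward()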