pytorch_pfn_extras.training.manager.ExtensionsManager
- class pytorch_pfn_extras.training.manager.ExtensionsManager(models, optimizers, max_epochs, *, iters_per_epoch, extensions=None, out_dir='result', stop_trigger=None, writer=None, transform_model=<function ExtensionsManager.<lambda>>, enable_profile=False, enable_trace=False, state_objects={})
Bases: _BaseExtensionsManager
Manages the extensions and the current status.
- Parameters:
models (dict or torch.nn.Module) – Map of string to Module or an actual Module.
optimizers (dict or torch.optim.Optimizer) – Map of string to Optimizer or an actual Optimizer.
max_epochs (int) – Number of epochs in the whole training loop. Ignored if stop_trigger is passed as a kwarg.
iters_per_epoch (int) – Number of iterations in one epoch.
extensions (list or None) – List of Extensions to be used.
out_dir (str) – Output directory (default: result).
stop_trigger (trigger object, optional) – Trigger that determines whether training has concluded. The default is an interval trigger set to max_epochs.
writer (writing.Writer object) – Writer that can be used by extensions to write data to custom filesystems.
enable_profile (bool) – Flag to enable/disable profiling of iterations. Default is False.
enable_trace (bool) – Flag to enable/disable tracing of iterations. Default is False.
transform_model (Callable[[str, Module], Module]) –
state_objects (Dict[str, StateObjectProtocol]) –
Methods
__init__(models, optimizers, max_epochs, *, ...)
extend(extension[, name, trigger, priority, ...]) – Registers an extension to the manager.
finalize()
get_extension(name) – Returns the extension of a given name.
load_state_dict(to_load)
needs_model_state([iteration])
needs_state_this_iteration()
run_extensions()
run_iteration(*[, step_optimizers]) – Context manager to run an iteration.
start_extensions()
state_dict()
Attributes
elapsed_time
epoch
epoch_detail
is_before_training
iteration
models
optimizers
out
raw_models
stop_trigger
updater
- __init__(models, optimizers, max_epochs, *, iters_per_epoch, extensions=None, out_dir='result', stop_trigger=None, writer=None, transform_model=<function ExtensionsManager.<lambda>>, enable_profile=False, enable_trace=False, state_objects={})
- Parameters:
models (Union[Module, Dict[str, Module]]) –
optimizers (Union[Optimizer, Dict[str, Optimizer]]) –
max_epochs (int) –
iters_per_epoch (int) –
extensions (Optional[Sequence[extension_module.ExtensionLike]]) –
out_dir (str) –
stop_trigger (trigger_module.TriggerLike) –
writer (Optional[Writer]) –
transform_model (Callable[[str, Module], Module]) –
enable_profile (bool) –
enable_trace (bool) –
state_objects (Dict[str, StateObjectProtocol]) –
- Return type:
None
- finalize()
- Return type:
None
- run_iteration(*, step_optimizers=None)
Context manager to run an iteration.
The context manager can additionally call zero_grad and step on the optimizers named in step_optimizers.
- Parameters:
step_optimizers (list or None) – Names of the optimizers on which to call zero_grad and step.
- Return type:
Generator[None, None, None]
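To make the role of step_optimizers concrete, the following dependency-free sketch models the behavior described above. ToyOptimizer and toy_run_iteration are hypothetical stand-ins, not part of pytorch_pfn_extras. The ordering shown (zero_grad on entry, step on normal exit) is an assumption based on the parameter description, not a copy of the library's implementation.

```python
import contextlib


class ToyOptimizer:
    """Hypothetical stand-in that only records the calls made on it."""

    def __init__(self):
        self.calls = []

    def zero_grad(self):
        self.calls.append("zero_grad")

    def step(self):
        self.calls.append("step")


@contextlib.contextmanager
def toy_run_iteration(optimizers, step_optimizers=None):
    """Toy model: zero_grad the named optimizers, run the body, then step."""
    names = step_optimizers or []
    for name in names:
        optimizers[name].zero_grad()
    yield
    # Reached only when the body finishes without raising.
    for name in names:
        optimizers[name].step()


optimizers = {"main": ToyOptimizer(), "aux": ToyOptimizer()}
with toy_run_iteration(optimizers, step_optimizers=["main"]):
    # Placeholder for the forward pass and loss.backward().
    optimizers["main"].calls.append("backward")

print(optimizers["main"].calls)
print(optimizers["aux"].calls)
```

With the real manager, the equivalent call would be a with block around manager.run_iteration(step_optimizers=[...]), where each name is a key of the optimizers map passed to the constructor. The loop body then only needs to compute the loss and call backward.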