pytorch_pfn_extras.training.IgniteExtensionsManager
- class pytorch_pfn_extras.training.IgniteExtensionsManager(engine, models, optimizers, max_epochs, *, extensions=None, out_dir='result', writer=None, enable_profile=False, enable_trace=False, state_objects={})
Bases: _BaseExtensionsManager
Manages extensions and the current status in Ignite training loop.
- Parameters:
engine (ignite.engine.Engine) – Ignite trainer engine
models (dict or torch.nn.Module) – Map of string to Module or an actual Module
optimizers (dict or torch.Optimizer) – Map of string to Optimizer or an actual Optimizer.
max_epochs (int) – Number of epochs in the whole training loop.
extensions (list or None) – List of extensions to be used.
out_dir (str) – Output directory (default: result).
writer (writing.Writer object) – Writer that can be used by extensions to write data to custom filesystems.
enable_profile (bool) – Flag to enable/disable profiling of iterations. Default is False.
enable_trace (bool) – Flag to enable/disable tracing of iterations. Default is False.
state_objects (Dict[str, StateObjectProtocol]) – Map of string to objects implementing StateObjectProtocol.
Methods
__init__(engine, models, optimizers, ...[, ...])
extend(extension[, name, trigger, priority, ...]) – Registers an extension to the manager.
get_extension(name) – Returns the extension of a given name.
load_state_dict(to_load)
needs_model_state([iteration])
needs_state_this_iteration()
run_extensions()
start_extensions()
Attributes
elapsed_time
epoch
epoch_detail
is_before_training
iteration
models
optimizers
out
raw_models
stop_trigger
updater
- __init__(engine, models, optimizers, max_epochs, *, extensions=None, out_dir='result', writer=None, enable_profile=False, enable_trace=False, state_objects={})
- Parameters:
engine (ignite.engine.Engine) –
models (Union[Module, Mapping[str, Module]]) –
optimizers (Union[Optimizer, Mapping[str, Optimizer]]) –
max_epochs (int) –
extensions (Optional[Sequence[extension_module.ExtensionLike]]) –
out_dir (str) –
writer (Optional[Writer]) –
enable_profile (bool) –
enable_trace (bool) –
state_objects (Dict[str, StateObjectProtocol]) –
- Return type:
None
- load_state_dict(to_load)
- Parameters:
to_load (Dict[str, Any]) –
- Return type:
None
- set_ignite_handlers()
- Return type:
None
- state_dict()
- Return type:
Dict[str, Any]