pytorch_pfn_extras.training.extension.Extension

class pytorch_pfn_extras.training.extension.Extension

Bases: object

Base class of extensions.

An extension is a callable object that takes the manager object as its argument. It also provides some default configurations as attributes, e.g., the default trigger and the default priority. This class provides a set of typical default values for these attributes.

There are three ways to define your own extensions: inheriting this class, decorating a closure with make_extension(), or passing any callable, including a lambda function, as an extension. The decorator can slightly reduce overhead and is much easier to use, while this class provides more flexibility (for example, it can have methods that configure its behavior). A lambda function allows one-line definitions for simple purposes, but its configuration must then be passed as arguments to ExtensionsManager.extend(). For a callable that does not inherit this class, the default configurations of this class are used unless they are explicitly specified in ExtensionsManager.extend().

trigger

Default trigger for this extension. It is set to (1, 'iteration') by default, i.e., the extension is invoked every iteration.

Type:

TriggerLike

priority

Default priority of the extension. It is set to PRIORITY_READER (100) by default.

Type:

int

name

Name of the extension. It is None by default and is overwritten when the extension is registered with a manager. See pytorch_pfn_extras.ExtensionsManager.extend() for details.

Methods

__init__()

finalize(manager)

Finalizes the extension.

initialize(manager)

Initializes the manager state.

load_state_dict(to_load)

on_error(manager, exc, tb)

Handles the error raised during training before finalization.

state_dict()

Serializes the extension state.

Attributes

default_name

Default name of the extension.

is_async

name

needs_model_state

priority

trigger

__call__(manager)

Invokes the extension.

Implementations should override this operator. It is called at the iterations that the corresponding trigger accepts.

Parameters:

manager (ExtensionsManager) – Manager object to call this operator.

Return type:

Any

property default_name: str

Default name of the extension.

It is the name of the class by default. Implementations can override this property or provide a class attribute that shadows it.

finalize(manager)

Finalizes the extension.

This method is called at the end of the training loop.

Parameters:

manager (ExtensionsManagerProtocol) – Manager object to call this extension.

Return type:

None

initialize(manager)

Initializes the manager state.

This method is called before entering the training loop. An extension modifying the state of ExtensionsManager can override this method to initialize it.

When the manager has been restored from a snapshot, this method must restore the appropriate part of the manager's state.

Parameters:

manager (ExtensionsManager) – Manager object to call this extension.

Return type:

None

is_async = False

load_state_dict(to_load)

Parameters:

to_load (Dict[str, Any]) – State to restore, as returned by state_dict().

Return type:

None

name: Optional[str] = None

needs_model_state = False

on_error(manager, exc, tb)

Handles the error raised during training before finalization.

This method is called when an exception is thrown during the training loop, before finalize(). An extension that needs error handling different from finalize() can override this method.

Parameters:
  • manager (ExtensionsManager) – Manager object to call this extension.

  • exc (Exception) – Arbitrary exception thrown during the update loop.

  • tb (traceback) – Traceback object of the exception.

Return type:

None

priority: int = 100

state_dict()

Serializes the extension state.

It is called when a manager that owns this extension is serialized. It serializes nothing by default.

Return type:

Dict[str, Any]

trigger: TriggerLike = (1, 'iteration')