pytorch_pfn_extras.training.extensions.LRScheduler

class pytorch_pfn_extras.training.extensions.LRScheduler(scheduler, *, stepper=<function _default_stepper>, trigger=(1, 'epoch'), wait_for_first_optimizer_step=False, is_async=True)

Bases: Extension

Trainer extension to adjust the learning rate using PyTorch’s learning rate scheduler.

This extension calls the step() method of the given LR scheduler (any instance of torch.optim.lr_scheduler.*). When the scheduler is ReduceLROnPlateau, step() is called with the most recently reported val/loss value. This behavior can be customized by passing a custom stepper function.
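The dispatch described above can be modeled in plain Python. This is a simplified, hypothetical sketch rather than the library's _default_stepper: EpochScheduler and PlateauScheduler are stand-ins for schedulers such as StepLR and ReduceLROnPlateau, and the reported values are assumed to be a list of dictionaries, newest last.

```python
from typing import Any, Dict, List


class EpochScheduler:
    """Stand-in for a step-based scheduler such as StepLR (hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def step(self) -> None:
        self.calls.append(())


class PlateauScheduler:
    """Stand-in for ReduceLROnPlateau, whose step() takes a metric (hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def step(self, metrics: float) -> None:
        self.calls.append((metrics,))


def default_stepper_sketch(log: List[Dict[str, Any]], scheduler: Any) -> None:
    # Plateau-style schedulers need a metric: use the most recent "val/loss".
    if isinstance(scheduler, PlateauScheduler):
        scheduler.step(log[-1]["val/loss"])
    else:
        # Every other scheduler advances without arguments.
        scheduler.step()


log = [{"val/loss": 0.9}, {"val/loss": 0.7}]
plateau, epoch = PlateauScheduler(), EpochScheduler()
default_stepper_sketch(log, plateau)
default_stepper_sketch(log, epoch)
print(plateau.calls, epoch.calls)  # [(0.7,)] [()]
```

Only the newest report is consulted, so a plateau scheduler reacts to the latest validation result rather than to an average over the history.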

Parameters:
  • scheduler (_LRScheduler or ReduceLROnPlateau) – Any instance of torch.optim.lr_scheduler.*.

  • stepper (callable) – Function that performs the step on the scheduler.

  • trigger (Optional[Union[Trigger, Callable[[ExtensionsManagerProtocol], bool], Tuple[float, str]]]) – Frequency at which this extension is called. Defaults to (1, 'epoch'), i.e. once per epoch.

  • wait_for_first_optimizer_step (bool) – If True, wait until optimizer.step() has been called before invoking scheduler.step(). This addresses the case where optimizer.step() is skipped in the first iterations, which can happen when using GradScaler.

  • is_async (bool) –
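To illustrate what wait_for_first_optimizer_step guards against, the following hypothetical sketch gates scheduler steps on whether the optimizer has stepped at least once. GatedStepper, CountingScheduler, and the optimizer_stepped flag are illustrative names, not part of pytorch_pfn_extras. With torch.cuda.amp.GradScaler, optimizer.step() is skipped whenever the scaled gradients contain inf or NaN values, which commonly happens in the first iterations.

```python
from typing import List


class CountingScheduler:
    """Stand-in scheduler that records how many times it stepped (hypothetical)."""

    def __init__(self) -> None:
        self.num_steps = 0

    def step(self) -> None:
        self.num_steps += 1


class GatedStepper:
    """Skips scheduler.step() until the optimizer has stepped at least once."""

    def __init__(self, scheduler: CountingScheduler, wait_for_first_optimizer_step: bool) -> None:
        self.scheduler = scheduler
        self.wait = wait_for_first_optimizer_step
        self.optimizer_stepped = False

    def on_optimizer_step(self) -> None:
        self.optimizer_stepped = True

    def __call__(self) -> None:
        if self.wait and not self.optimizer_stepped:
            return  # optimizer has not stepped yet: leave the LR untouched
        self.scheduler.step()


# Simulate 4 iterations where the optimizer step is skipped in the first two,
# as GradScaler may do when it finds inf/NaN gradients.
optimizer_ran: List[bool] = [False, False, True, True]
sched = CountingScheduler()
gate = GatedStepper(sched, wait_for_first_optimizer_step=True)
for ran in optimizer_ran:
    if ran:
        gate.on_optimizer_step()
    gate()  # the scheduler is triggered every iteration in this toy example
print(sched.num_steps)  # 2
```

Without the gate, the scheduler would advance in iterations where no parameter update happened, shifting the learning-rate schedule ahead of the actual training progress.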

Methods

  • __init__(scheduler, *[, stepper, trigger, ...])

  • finalize(manager) – Finalizes the extension.

  • initialize(manager) – Initializes the manager state.

  • load_state_dict(state)

  • on_error(manager, exc, tb) – Handles the error raised during training before finalization.

  • state_dict() – Serializes the extension state.

  • step_by_value(key)

Attributes

  • default_name – Default name of the extension.

  • is_async

  • name

  • needs_model_state

  • priority

  • trigger

__call__(manager)

Invokes the extension.

Implementations should override this operator. This method is called at each iteration that the corresponding trigger accepts.

Parameters:

manager (ExtensionsManager) – Manager object to call this operator.

Return type:

None
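With the default trigger of (1, 'epoch'), the extension fires once per epoch. The simplified, hypothetical interval check below shows which iterations such a trigger would accept; interval_fires is an illustrative helper, not the library's IntervalTrigger, and it assumes period * iters_per_epoch is a whole number.

```python
from typing import List, Tuple


def interval_fires(iteration: int, iters_per_epoch: int, trigger: Tuple[float, str]) -> bool:
    """Simplified, hypothetical interval trigger: does `iteration` (1-based) end a period?"""
    period, unit = trigger
    if unit == "iteration":
        return iteration % int(period) == 0
    if unit == "epoch":
        # Fire when the iteration count reaches a multiple of period * iters_per_epoch.
        return iteration % int(period * iters_per_epoch) == 0
    raise ValueError(f"unknown unit: {unit}")


iters_per_epoch = 3
fired: List[int] = [
    it for it in range(1, 10) if interval_fires(it, iters_per_epoch, (1, "epoch"))
]
print(fired)  # [3, 6, 9]
```

Switching the trigger to an iteration-based interval would make the scheduler step more often, which suits warm-up style schedules defined in iterations rather than epochs.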

__init__(scheduler, *, stepper=<function _default_stepper>, trigger=(1, 'epoch'), wait_for_first_optimizer_step=False, is_async=True)
Parameters:
  • scheduler (Any) –

  • stepper (Any) –

  • trigger (Optional[Union[Trigger, Callable[[ExtensionsManagerProtocol], bool], Tuple[float, str]]]) –

  • wait_for_first_optimizer_step (bool) –

  • is_async (bool) –

Return type:

None

load_state_dict(state)
Parameters:

state (Dict[str, Any]) –

Return type:

None

state_dict()

Serializes the extension state.

It is called when the manager that owns this extension is serialized. The base Extension implementation serializes nothing.

Return type:

Dict[str, Any]
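An extension that wraps a stateful object can nest that object's own state under a key in state_dict() so that load_state_dict() can restore it after a checkpoint. The sketch below is hypothetical: SchedulerLike, SchedulerExtensionSketch, and the "scheduler" key are illustrative names, not confirmed library internals.

```python
from typing import Any, Dict


class SchedulerLike:
    """Stand-in for an LR scheduler exposing state_dict()/load_state_dict() (hypothetical)."""

    def __init__(self) -> None:
        self.last_epoch = 0

    def step(self) -> None:
        self.last_epoch += 1

    def state_dict(self) -> Dict[str, Any]:
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.last_epoch = state["last_epoch"]


class SchedulerExtensionSketch:
    def __init__(self, scheduler: SchedulerLike) -> None:
        self.scheduler = scheduler

    def state_dict(self) -> Dict[str, Any]:
        # Nest the scheduler's own state so it survives a checkpoint.
        return {"scheduler": self.scheduler.state_dict()}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.scheduler.load_state_dict(state["scheduler"])


original = SchedulerExtensionSketch(SchedulerLike())
for _ in range(5):
    original.scheduler.step()
snapshot = original.state_dict()

restored = SchedulerExtensionSketch(SchedulerLike())
restored.load_state_dict(snapshot)
print(restored.scheduler.last_epoch)  # 5
```

Restoring the scheduler position matters when resuming training: otherwise the learning rate would restart from the beginning of its schedule.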

static step_by_value(key)
Parameters:

key (Optional[str]) –

Return type:

Any
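Only the signature of step_by_value is documented above. Assuming it builds a stepper that passes the latest value reported under key to scheduler.step(), and calls step() without arguments when key is None, its shape can be sketched as a closure factory. step_by_value_sketch, RecordingScheduler, and the log argument are hypothetical; the actual stepper's arguments may differ.

```python
from typing import Any, Callable, Dict, List, Optional

Stepper = Callable[[List[Dict[str, Any]], Any], None]


def step_by_value_sketch(key: Optional[str]) -> Stepper:
    """Return a stepper that feeds the latest value logged under `key` to step()."""

    def _stepper(log: List[Dict[str, Any]], scheduler: Any) -> None:
        if key is None:
            scheduler.step()  # no metric requested: plain step
        else:
            scheduler.step(log[-1][key])  # most recent report wins

    return _stepper


class RecordingScheduler:
    """Stand-in scheduler that records the arguments of each step() call (hypothetical)."""

    def __init__(self) -> None:
        self.calls: List[tuple] = []

    def step(self, *args: Any) -> None:
        self.calls.append(args)


log = [{"val/accuracy": 0.81}, {"val/accuracy": 0.84}]
sched = RecordingScheduler()
step_by_value_sketch("val/accuracy")(log, sched)
step_by_value_sketch(None)(log, sched)
print(sched.calls)  # [(0.84,), ()]
```

Under the same assumption, passing such a stepper lets ReduceLROnPlateau track a metric other than val/loss; for a metric where higher is better, such as accuracy, the scheduler would need mode='max'.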