pytorch_pfn_extras.engine.Trainer

class pytorch_pfn_extras.engine.Trainer(handler, *, evaluator, models, profile=None, **kwargs)

Bases: object

Methods

__init__(handler, *, evaluator, models[, ...])

extend(extension[, name, trigger, priority, ...])

get_optimizer(name)

is_epoch_last_iter(idx)

load_state_dict(to_load)

run(train_loader[, val_loader, train_len, ...])

Executes the training loop.

set_optimizer(name, optimizer)

state_dict()

Attributes

epoch

epoch_detail

evaluator

is_before_training

iteration

manager

models

optimizers

stop_trigger

Parameters:
  • handler –

  • evaluator –

  • models –

  • profile –

  • kwargs –
__init__(handler, *, evaluator, models, profile=None, **kwargs)
Parameters:
  • handler –

  • evaluator –

  • models –

  • profile –

  • kwargs –
property epoch: int
property epoch_detail: float
property evaluator: Optional[Evaluator]
extend(extension, name=None, trigger=None, priority=None, *, call_before_training=False, **kwargs)
Parameters:
  • extension (Union[extension.ExtensionLike, ExtensionEntry]) –

  • name (Optional[str]) –

  • trigger (TriggerLike) –

  • priority (Optional[int]) –

  • call_before_training (bool) –

  • kwargs (Any) –

Return type:

None

get_optimizer(name)
Parameters:

name (str) –

Return type:

Optimizer

property is_before_training: bool
is_epoch_last_iter(idx)
Parameters:

idx (int) –

Return type:

bool

property iteration: int
load_state_dict(to_load)
Parameters:

to_load (Dict[str, Any]) –

Return type:

None

property manager: ExtensionsManager
property models: Mapping[str, Module]
property optimizers: Mapping[str, Optimizer]
run(train_loader, val_loader=None, *, train_len=None, eval_len=None)

Executes the training loop.

Parameters:
  • train_loader (torch.utils.data.DataLoader) – A data loader for training.

  • val_loader (torch.utils.data.DataLoader, optional) – A data loader passed to Evaluator.run().

  • train_len (int, optional) – The number of iterations per training epoch. If omitted, it is inferred from the size of the training data loader.

  • eval_len (int, optional) – The number of iterations per evaluation epoch; passed to Evaluator.run().

Return type:

None

set_optimizer(name, optimizer)
Parameters:
  • name (str) –

  • optimizer (Optimizer) –

Return type:

None

state_dict()
Return type:

Dict[str, Any]

property stop_trigger: Trigger