pytorch_pfn_extras.training.extensions.DistributedEvaluator#

class pytorch_pfn_extras.training.extensions.DistributedEvaluator(self, iterator, target, eval_func=None, *, progress_bar=False)#

Bases: Evaluator

An extension to evaluate models on a validation set in a distributed training setup.

When torch.distributed is used to parallelize training, it is efficient to also run evaluation in parallel: the validation set is split across the worker processes, each worker evaluates its shard separately, and the per-worker results are then aggregated. This is what :class:`~DistributedEvaluator` provides.

This extension behaves like Evaluator, but adds an aggregation step in Evaluator.evaluate(). The evaluation summary (DictSummary) of each worker process is collected in an all-gather manner and then accumulated. Therefore all worker processes must participate in the evaluation, i.e., make sure every process has an Evaluator extension object registered in its ExtensionManager with the same trigger. All worker processes then receive the identical evaluation result returned by Evaluator.evaluate(), which is reported to an observation.
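The aggregation step can be illustrated with a simplified sketch (the helper below is hypothetical, not the pytorch_pfn_extras API): each worker holds per-metric (sum, count) pairs, the pairs from all workers are gathered, and accumulating them yields a mean that is identical on every rank.

```python
# Simplified sketch of the all-gather aggregation step (hypothetical
# helper, not the pytorch_pfn_extras API). Each worker contributes a
# dict mapping a metric name to a (sum, count) pair.
def aggregate_summaries(gathered):
    totals = {}
    for summary in gathered:  # one entry per worker, as from an all-gather
        for key, (value_sum, count) in summary.items():
            acc_sum, acc_count = totals.get(key, (0.0, 0))
            totals[key] = (acc_sum + value_sum, acc_count + count)
    # Accumulated totals give the global mean, identical on every rank.
    return {key: s / n for key, (s, n) in totals.items()}

# Two workers evaluated disjoint shards of the validation set:
worker0 = {"validation/loss": (2.0, 4)}  # sum=2.0 over 4 samples
worker1 = {"validation/loss": (4.0, 4)}  # sum=4.0 over 4 samples
result = aggregate_summaries([worker0, worker1])
# Every rank computes the same mean: 6.0 / 8 = 0.75
```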

It is necessary to pass a DataLoader with an appropriate sampler that splits the validation dataset across the worker processes. PyTorch's DistributedSampler implements this, but it repeats samples so that every process is assigned the same number of them. For evaluation this repetition distorts the result, so it is recommended to use DistributedValidationSampler instead.
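The distinction matters because DistributedSampler pads the index list so every rank receives the same count, duplicating some samples when the dataset size is not divisible by the world size. A non-repeating split, in the spirit of DistributedValidationSampler (a minimal sketch under that assumption, not the actual implementation), covers each sample exactly once:

```python
def nonrepeating_split(num_samples, world_size, rank):
    # Strided, non-repeating shard: ranks may receive different counts,
    # but each index appears on exactly one rank.
    return list(range(num_samples))[rank::world_size]

world_size = 3
shards = [nonrepeating_split(10, world_size, r) for r in range(world_size)]
# shards -> [[0, 3, 6, 9], [1, 4, 7], [2, 5, 8]]
# The union covers all 10 samples with no duplicates, so the aggregated
# evaluation result is unbiased; rank sizes differ (4, 3, 3), which is
# acceptable for evaluation though not for synchronized training steps.
```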

Methods

__init__(iterator, target[, eval_hook, ...])

add_metric(metric_fn)

Adds a custom metric to the evaluator.

eval_func(*args, **kwargs)

evaluate()

Evaluates the model and returns a result dictionary.

finalize(manager)

Finalizes the extension.

get_all_iterators()

Returns a dictionary of all iterators.

get_all_targets()

Returns a dictionary of all target links.

get_iterator(name)

Returns the iterator of the given name.

get_target(name)

Returns the target link of the given name.

initialize(manager)

Initializes the manager state.

load_state_dict(to_load)

on_error(manager, exc, tb)

Handles the error raised during training before finalization.

state_dict()

Serializes the extension state.

Attributes

default_name

is_async

name

needs_model_state

priority

trigger

Parameters:
  • iterator (Union[DataLoader[Any], Dict[str, DataLoader[Any]]]) –

  • target (Union[Module, Dict[str, Module]]) –

  • eval_hook (Optional[Callable[[Evaluator], None]]) –

  • eval_func (Optional[Callable[..., Any]]) –

  • kwargs (Any) –

__init__(iterator, target, eval_hook=None, eval_func=None, **kwargs)#
Parameters:
  • iterator (Union[DataLoader[Any], Dict[str, DataLoader[Any]]]) –

  • target (Union[Module, Dict[str, Module]]) –

  • eval_hook (Optional[Callable[[Evaluator], None]]) –

  • eval_func (Optional[Callable[..., Any]]) –

  • kwargs (Any) –

Return type:

None