pytorch_pfn_extras.training.extensions.MicroAverage#

class pytorch_pfn_extras.training.extensions.MicroAverage(numerator_key, denominator_key, result_key, trigger=(1, 'epoch'))#

Bases: Extension

Calculates micro-average ratio.

Given \(N\) batches with values \(\{n_1, \dots, n_N\}\) and \(\{d_1, \dots, d_N\}\), this extension calculates the micro-average of their ratio, defined as:

\[\frac{\sum_i^N n_i}{\sum_i^N d_i}.\]

A user typically uses the number of examples that the system predicts correctly as \(n_i\) and the total number of examples in the \(i\)-th batch as \(d_i\). This value is called the micro-average of precision.

Note that the macro-average is defined as:

\[\frac{1}{N}\sum_i^N (n_i / d_i).\]

It equals the micro-average when every mini-batch has the same \(d_i\).
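
A quick plain-Python illustration (not part of the API) of how the two averages diverge when batch sizes differ:

>>> n = [3, 50]   # correct predictions per batch
>>> d = [4, 100]  # examples per batch
>>> round(sum(n) / sum(d), 3)                            # micro-average
0.51
>>> round(sum(c / t for c, t in zip(n, d)) / len(n), 3)  # macro-average
0.625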

You need to report the numerator value (the number of correct examples) and the denominator value (the total number of examples) in your model:

>>> import torch
>>> import torch.nn.functional as F
>>> from pytorch_pfn_extras import reporting
>>>
>>> class MyModel(torch.nn.Module):
...     def forward(self, x, y):
...         loss = F.cross_entropy(x, y)
...         # Report the per-batch numerator and denominator.
...         correct = (x.argmax(dim=1) == y).sum().item()
...         total = len(y)
...         reporting.report({'correct': correct, 'total': total}, self)
...         return loss

Then create the extension with the corresponding reporting keys and register it:

>>> ext = extensions.MicroAverage(
...     'main/correct', 'main/total', 'main/accuracy')
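
To have the extension run during training, register it with an ExtensionsManager. A minimal sketch, assuming model, optimizer, and train_loader already exist:

>>> import pytorch_pfn_extras as ppe
>>> manager = ppe.training.ExtensionsManager(
...     model, optimizer, max_epochs=10,
...     iters_per_epoch=len(train_loader))
>>> manager.extend(ext)  # averaged ratio is reported as 'main/accuracy'
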
Parameters:
  • numerator_key (str) – Key string of the observation storing the numerator value.

  • denominator_key (str) – Key string of the observation storing the denominator value.

  • result_key (str) – Key string of the observation to store the result in.

  • trigger (Optional[Union[Trigger, Callable[[ExtensionsManagerProtocol], bool], Tuple[float, str]]]) – Trigger that decides when to calculate the average. This is distinct from the trigger of this extension itself. If it is a tuple in the form (<int>, 'epoch') or (<int>, 'iteration'), it is passed to IntervalTrigger; see the example below.
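For instance, the tuple form can be used to average every 500 iterations instead of every epoch (500 is an illustrative interval):

>>> ext = extensions.MicroAverage(
...     'main/correct', 'main/total', 'main/accuracy',
...     trigger=(500, 'iteration'))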

Methods

__init__(numerator_key, denominator_key, ...)

finalize(manager)

Finalizes the extension.

initialize(manager)

Initializes the manager state.

load_state_dict(to_load)

on_error(manager, exc, tb)

Handles the error raised during training before finalization.

state_dict()

Serializes the extension state.

Attributes

default_name

Default name of the extension.

is_async

name

needs_model_state

priority

trigger

__call__(manager)#

Invokes the extension.

Implementations should override this operator. This method is called at each iteration that the corresponding trigger accepts.

Parameters:

manager (ExtensionsManager) – Manager object that invokes this operator.

Return type:

None

__init__(numerator_key, denominator_key, result_key, trigger=(1, 'epoch'))#
Parameters:
  • numerator_key (str) –

  • denominator_key (str) –

  • result_key (str) –

  • trigger (Optional[Union[Trigger, Callable[[ExtensionsManagerProtocol], bool], Tuple[float, str]]]) –

Return type:

None

load_state_dict(to_load)#
Parameters:

to_load (Dict[str, Any]) –

Return type:

None

priority: int = 200#
state_dict()#

Serializes the extension state.

It is called when a manager that owns this extension is serialized. It serializes nothing by default.

Return type:

Dict[str, Any]
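
Together with load_state_dict() above, this allows the accumulated state to be carried over to a new instance; a minimal sketch (the contents of the returned dict are an implementation detail):

>>> state = ext.state_dict()
>>> ext2 = extensions.MicroAverage(
...     'main/correct', 'main/total', 'main/accuracy')
>>> ext2.load_state_dict(state)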