pytorch_pfn_extras.training.extensions.MicroAverage¶
- class pytorch_pfn_extras.training.extensions.MicroAverage(numerator_key, denominator_key, result_key, trigger=(1, 'epoch'))¶
Calculates micro-average ratio.
Given \(N\) batches with values \(\{n_1, \dots, n_N\}\) and \(\{d_1, \dots, d_N\}\), this extension calculates the micro-average of these ratios, defined as:
\[\frac{\sum_i^N n_i}{\sum_i^N d_i}.\]
A user typically uses the number of examples that the system predicts correctly as \(n_i\) and the total number of examples in the \(i\)-th batch as \(d_i\). This value is called the micro-average of precision.
Note that the macro-average is defined as:
\[\frac{1}{N}\sum_i^N (n_i / d_i),\]
which is equal to the micro-average when every mini-batch has the same \(d_i\).
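The difference between the two averages shows up as soon as batch sizes vary. A minimal sketch in plain Python (the per-batch counts are hypothetical, chosen only for illustration):

```python
# Hypothetical per-batch correct counts (n_i) and batch sizes (d_i);
# note the last batch is much smaller than the others.
numerators = [8, 45, 3]
denominators = [10, 50, 4]

# Micro-average: pool all counts first, then take one ratio.
micro = sum(numerators) / sum(denominators)

# Macro-average: take each batch's ratio first, then average them.
macro = sum(n / d for n, d in zip(numerators, denominators)) / len(numerators)

print(micro)  # 0.875
print(macro)  # 0.8166... -- the small batch is weighted as heavily as the large ones
```

The micro-average weights every example equally, while the macro-average weights every batch equally, so they diverge whenever the \(d_i\) differ.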
You need to report the numerator value (the number of correct examples) and the denominator value (the total number of examples) in your model.
>>> class MyModel(torch.nn.Module):
...     def forward(self, x, y):
...         loss = torch.nn.functional.cross_entropy(x, y)
...         correct = (x.argmax(dim=1) == y).sum()
...         total = len(y)
...         reporting.report({'correct': correct, 'total': total}, self)
...         return loss
And then, make an extension with corresponding reporting keys and register it.
>>> ext = extensions.MicroAverage(
...     'main/correct', 'main/total', 'main/accuracy')
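The accumulation this extension performs can be sketched in plain Python. The class and method names below (`MicroAverageSketch`, `update`, `emit`) are hypothetical stand-ins, not the library API; the real extension reads the keyed values from each iteration's observation and reports the pooled ratio when its trigger fires:

```python
# Hypothetical stand-in for illustration only -- not the library class.
# Each iteration the reported numerator/denominator are summed; when the
# trigger fires, the pooled ratio is written under result_key and the
# sums are reset for the next interval.
class MicroAverageSketch:
    def __init__(self, numerator_key, denominator_key, result_key):
        self.numerator_key = numerator_key
        self.denominator_key = denominator_key
        self.result_key = result_key
        self._numerator = 0.0
        self._denominator = 0.0

    def update(self, observation):
        # Called once per iteration with the reported observation dict.
        self._numerator += observation[self.numerator_key]
        self._denominator += observation[self.denominator_key]

    def emit(self, observation):
        # Called when the trigger fires (e.g. once per epoch).
        observation[self.result_key] = self._numerator / self._denominator
        self._numerator = 0.0
        self._denominator = 0.0

ext = MicroAverageSketch('main/correct', 'main/total', 'main/accuracy')
ext.update({'main/correct': 8, 'main/total': 10})
ext.update({'main/correct': 45, 'main/total': 50})
result = {}
ext.emit(result)
print(result['main/accuracy'])  # 53/60, the pooled accuracy over both batches
```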
- Parameters
numerator_key (str) – Key string of the observation storing the numerator value.
denominator_key (str) – Key string of the observation storing the denominator value.
result_key (str) – Key string of the observation to store the result.
trigger (Optional[Union[pytorch_pfn_extras.training._trigger_util.Trigger, Callable[[pytorch_pfn_extras.training._manager_protocol.ExtensionsManagerProtocol], bool], Tuple[float, str]]]) – Trigger that decides when to calculate the average. This is distinct from the trigger of this extension itself. If it is a tuple in the form (<int>, 'epoch') or (<int>, 'iteration'), it is passed to IntervalTrigger.
- Return type
None
- __init__(numerator_key, denominator_key, result_key, trigger=(1, 'epoch'))¶
- Parameters
numerator_key (str) –
denominator_key (str) –
result_key (str) –
trigger (Optional[Union[pytorch_pfn_extras.training._trigger_util.Trigger, Callable[[pytorch_pfn_extras.training._manager_protocol.ExtensionsManagerProtocol], bool], Tuple[float, str]]]) –
- Return type
None
Methods
__init__(numerator_key, denominator_key, …)
finalize() – Finalizes the extension.
initialize(manager) – Initializes the manager state.
load_state_dict(to_load)
on_error(manager, exc, tb) – Handles the error raised during training before finalization.
state_dict() – Serializes the extension state.
Attributes
default_name
Default name of the extension.
is_async
name
needs_model_state
priority
trigger