pytorch_pfn_extras.training.extensions.ParameterStatistics

class pytorch_pfn_extras.training.extensions.ParameterStatistics(links, statistics='default', report_params=True, report_grads=True, prefix=None, trigger=(1, 'epoch'), skip_nan_params=False)

An extension to report parameter statistics.

Statistics are collected and reported for a given Module or an iterable of Modules. If a link (that is, one of the given modules) contains child modules, the statistics are reported separately for each child.

Any function that takes a one-dimensional torch.Tensor and returns one or more real numbers can be registered to compute a statistic, e.g. numpy.ndarray.mean().

The keys of the reported statistics consist of the link name followed by the parameter name, attribute name and function name, e.g. VGG16Layers/conv1_1/W/data/mean. An optional prefix is prepended to each key, and integer indices are appended when a statistics function returns multiple values.
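The key convention described above can be sketched with a small helper. build_report_entries is hypothetical and not part of pytorch_pfn_extras; using "/" both between the prefix and the rest of the key and before the integer index is an assumption made for illustration.

```python
from typing import Dict, Sequence, Union

Number = Union[int, float]


def build_report_entries(
    link_name: str,
    param_name: str,
    attr_name: str,
    function_name: str,
    value: Union[Number, Sequence[Number]],
    prefix: str = "",
) -> Dict[str, Number]:
    """Hypothetical helper mimicking the report-key convention (not the library API)."""
    # Join the non-empty components with "/" (assumed separator);
    # the optional prefix comes first.
    parts = [prefix, link_name, param_name, attr_name, function_name]
    key = "/".join(p for p in parts if p)
    # A statistic that returns several values yields one key per value,
    # suffixed with the value's integer index.
    if isinstance(value, (list, tuple)):
        return {f"{key}/{i}": v for i, v in enumerate(value)}
    return {key: value}
```

For example, a scalar 'mean' yields the single key VGG16Layers/conv1_1/W/data/mean, while a two-valued statistic with prefix "stats" yields keys ending in /0 and /1.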

Parameters
  • links (instance or iterable of torch.nn.Module) – Module(s) containing the parameters to observe. Each link is expected to have a name attribute, which is used as part of the report key.

  • statistics (dict or 'default') – Dictionary mapping statistic names to functions. Each name is a string used as part of the report key, and each function computes the corresponding statistic. If the special value 'default' is given, the default statistic functions are used.

  • report_params (bool) – If True, report statistics for parameter values such as weights and biases.

  • report_grads (bool) – If True, report statistics for parameter gradients.

  • prefix (str) – Optional prefix to prepend to the report keys.

  • trigger (Optional[Union[pytorch_pfn_extras.training._trigger_util.Trigger, Callable[[pytorch_pfn_extras.training._manager_protocol.ExtensionsManagerProtocol], bool], Tuple[float, str]]]) – Trigger that decides when to aggregate the results and report the values.

  • skip_nan_params (bool) – If True, statistics are not computed for parameters containing NaNs, and a single NaN value is reported immediately instead. Otherwise, this extension simply tries to compute the statistics without checking for NaNs.
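The skip_nan_params behavior can be sketched on a plain Python list. compute_statistics is a hypothetical helper, not part of pytorch_pfn_extras, and reporting one NaN per statistic name is an assumption about how "a single NaN value" maps onto the report keys.

```python
import math
from typing import Callable, Dict, Sequence


def compute_statistics(
    values: Sequence[float],
    statistics: Dict[str, Callable[[Sequence[float]], float]],
    skip_nan_params: bool = False,
) -> Dict[str, float]:
    """Hypothetical sketch of skip_nan_params on a flat list of values."""
    # With skip_nan_params=True, check for NaNs once up front and report
    # NaN for every statistic without calling the functions (assumption).
    if skip_nan_params and any(math.isnan(v) for v in values):
        return {name: float("nan") for name in statistics}
    # Otherwise apply each function directly; how NaNs propagate is left
    # entirely to the function itself.
    return {name: fn(values) for name, fn in statistics.items()}
```

Skipping the computation avoids calling statistic functions whose behavior on NaN input may be order-dependent or may raise.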

Note

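The default statistic functions are as follows, where x is the flattened one-dimensional array of parameter values or gradients and xp denotes a NumPy-compatible array module: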

  • 'mean' (xp.mean(x))

  • 'std' (xp.std(x))

  • 'min' (xp.min(x))

  • 'max' (xp.max(x))

  • 'zeros' (xp.count_nonzero(x == 0))

  • 'percentile' (xp.percentile(x, (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87)))
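Assuming xp behaves like NumPy (numpy.std with its default ddof=0, numpy.percentile with its default linear interpolation), the default statistics can be approximated in plain Python as below. The seven percentile levels appear to be the share of a standard normal distribution lying below -3, -2, -1, 0, 1, 2 and 3 standard deviations; normal_cdf_percent checks that reading. All function names here are hypothetical.

```python
import math
import statistics
from typing import Dict, Sequence, Tuple

# Percentile levels taken from the note above.
PERCENTILE_LEVELS: Tuple[float, ...] = (0.13, 2.28, 15.87, 50, 84.13, 97.72, 99.87)


def percentile(x: Sequence[float], q: float) -> float:
    """Linear-interpolation percentile (NumPy's default method), q in [0, 100]."""
    s = sorted(x)
    pos = (len(s) - 1) * q / 100.0
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def default_statistics(x: Sequence[float]) -> Dict[str, object]:
    """Hypothetical stdlib approximation of the default statistic functions."""
    return {
        "mean": statistics.fmean(x),
        "std": statistics.pstdev(x),  # population std, like numpy.std (ddof=0)
        "min": min(x),
        "max": max(x),
        "zeros": sum(1 for v in x if v == 0),  # like count_nonzero(x == 0)
        "percentile": [percentile(x, q) for q in PERCENTILE_LEVELS],
    }


def normal_cdf_percent(k: float) -> float:
    """Percentage of a standard normal distribution lying below k."""
    return 50.0 * (1.0 + math.erf(k / math.sqrt(2.0)))
```

Rounding normal_cdf_percent(k) to two decimals for k = -3, ..., 3 reproduces the levels in the note, so the multi-valued 'percentile' statistic roughly summarizes the spread of x in sigma-like bands.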

__init__(links, statistics='default', report_params=True, report_grads=True, prefix=None, trigger=(1, 'epoch'), skip_nan_params=False)

Parameters
  • links (Any)

  • statistics (Any)

  • report_params (bool)

  • report_grads (bool)

  • prefix (Optional[str])

  • trigger (Optional[Union[pytorch_pfn_extras.training._trigger_util.Trigger, Callable[[pytorch_pfn_extras.training._manager_protocol.ExtensionsManagerProtocol], bool], Tuple[float, str]]])

  • skip_nan_params (bool)

Methods

__init__(links[, statistics, report_params, …])

finalize()

Finalizes the extension.

initialize(manager)

Initializes the manager state.

load_state_dict(to_load)

on_error(manager, exc, tb)

Handles the error raised during training before finalization.

register_statistics(name, function)

Register a function to compute a certain statistic.

state_dict()

Serializes the extension state.

Attributes

default_name

default_statistics

is_async

name

needs_model_state

priority

report_key_template

trigger