API Reference

Package

pytorch_pfn_extras

Training Loop

Trainer

engine.create_trainer(models, optimizers, ...)

Creates a trainer object.

engine.create_evaluator(models, *[, ...])

Creates an evaluator object.

handler.BaseLogic([options])

handler.Logic([model_name, options])

A set of methods that defines the training logic.

handler.BaseHandler(logic, options, *args, ...)

Base class of Handler.

handler.Handler(logic, entry_runtime, options)

A set of callback functions to perform device-specific operations.

runtime.BaseRuntime(device_spec, options)

A base class for collections of device-specific callback functions.

runtime.PyTorchRuntime(device_spec, options)

A collection of callback functions for the devices that PyTorch supports by default.

Extensions Manager

training.ExtensionsManager(models, ...[, ...])

Manages the extensions and the current status.

training.IgniteExtensionsManager(engine, ...)

Manages extensions and the current status in Ignite training loop.

Extensions

training.extension.make_extension([trigger, ...])

Decorator that turns the given function into an extension.
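Conceptually, a decorator like this attaches scheduling options (such as the trigger) to a plain function so that the manager knows when to call it. The stdlib sketch below is hypothetical: `make_extension_sketch` only mimics the idea of recording metadata and is not the library's implementation.

```python
def make_extension_sketch(trigger=(1, "epoch"), priority=None):
    """Hypothetical stand-in for a make_extension-style decorator.

    It only records scheduling metadata on the function; a manager would
    read these attributes to decide when to call it.
    """
    def decorator(fn):
        fn.trigger = trigger
        fn.priority = priority
        return fn
    return decorator


@make_extension_sketch(trigger=(5, "iteration"))
def print_iteration(manager):
    # An extension is called with the manager; here we only read a counter.
    print("iteration", manager.iteration)


print(print_iteration.trigger)  # (5, 'iteration')
```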

training.extension.Extension()

Base class of extensions.

training.extension.ExtensionEntry(extension, *)

Extension and options.

training.extensions.BestValue(key, compare)

An extension that traces the best value of a specific key in the observation.

training.extensions.Evaluator(self, ...[, ...])

An extension to evaluate models on a validation set.

training.extensions.LogReport([keys, ...])

An extension to output the accumulated results to a log file.
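A log report of this kind typically averages each reported value over one trigger interval and emits a single entry per interval. The sketch below illustrates only that averaging step, assuming one observation dict per iteration; `summarize_interval` is hypothetical and does not reproduce the extension's file output.

```python
from collections import defaultdict


def summarize_interval(observations, epoch, iteration):
    """Average each key over one logging interval (hypothetical helper)."""
    sums = defaultdict(float)
    counts = defaultdict(int)
    for obs in observations:
        for key, value in obs.items():
            sums[key] += value
            counts[key] += 1
    entry = {key: sums[key] / counts[key] for key in sums}
    # Record where in training this entry was produced.
    entry["epoch"] = epoch
    entry["iteration"] = iteration
    return entry


# Three iterations' worth of reported values within one interval.
obs = [{"train/loss": 0.9}, {"train/loss": 0.6}, {"train/loss": 0.3}]
entry = summarize_interval(obs, epoch=1, iteration=3)
print(entry)
```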

training.extensions.MaxValue(key[, trigger])

An extension that traces the maximum value of a specific key in the observation.

training.extensions.MicroAverage(...[, trigger])

Calculates micro-average ratio.
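Given per-batch numerators n_i and denominators d_i (for example, correct predictions and batch sizes), a micro-average ratio is Σ n_i / Σ d_i, which weights every sample equally instead of averaging the per-batch ratios. The plain-Python sketch below contrasts the two; the function names are illustrative and not part of the library.

```python
def micro_average(numerators, denominators):
    # Pool the counts across batches before dividing.
    return sum(numerators) / sum(denominators)


def macro_average(numerators, denominators):
    # Average the per-batch ratios instead (shown only for contrast).
    ratios = [n / d for n, d in zip(numerators, denominators)]
    return sum(ratios) / len(ratios)


correct = [3, 9]   # correct predictions per batch
total = [4, 10]    # samples per batch
print(micro_average(correct, total))  # 12 / 14
print(macro_average(correct, total))  # (0.75 + 0.9) / 2
```

The two differ whenever batch sizes differ, which is why a micro average is the usual choice for accuracy over a whole epoch.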

training.extensions.MinValue(key[, trigger])

An extension that traces the minimum value of a specific key in the observation.

training.extensions.observe_lr(optimizer[, ...])

Returns an extension to record the learning rate.

training.extensions.observe_value(...)

Returns an extension to continuously record a value.

training.extensions.ParameterStatistics(links)

An extension to report parameter statistics.

training.extensions.PlotReport(y_keys[, ...])

An extension to output plots.

training.extensions.PrintReport([entries, ...])

An extension to print the accumulated results.

training.extensions.ProgressBar([...])

An extension to print a progress bar and recent training status.

training.extensions.ProfileReport([...])

Writes the profile results to a file.

training.extensions.snapshot([savefun, ...])

Returns a trainer extension to take snapshots of the trainer.

training.extensions.Slack(channel[, msg, ...])

An extension to communicate with Slack.

training.extensions.SlackWebhook(url[, msg, ...])

An extension to communicate with Slack using Incoming Webhook.

training.extensions.VariableStatisticsPlot(targets)

An extension to plot statistics for Tensors.

Triggers

training.triggers.EarlyStoppingTrigger(self)

Trigger for early stopping.
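Early stopping usually monitors a validation metric and stops once it has failed to improve for a number of consecutive checks (the patience). The sketch below assumes a lower-is-better metric; `PatienceSketch` and its `patience` parameter are hypothetical and illustrate the rule only, not this trigger's exact arguments.

```python
class PatienceSketch:
    """Hypothetical early-stopping rule for a lower-is-better metric."""

    def __init__(self, patience=3):
        self.patience = patience
        self.best = float("inf")
        self.bad_checks = 0

    def should_stop(self, value):
        if value < self.best:
            # Improvement: remember it and reset the counter.
            self.best = value
            self.bad_checks = 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience


rule = PatienceSketch(patience=2)
losses = [1.0, 0.8, 0.85, 0.9, 0.7]
stops = [rule.should_stop(v) for v in losses]
print(stops)  # [False, False, False, True, False]
```

In a real loop, training would end at the first `True`; the later value is evaluated here only to show that an improvement resets the counter.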

training.triggers.IntervalTrigger(period, unit)

Trigger based on a fixed interval.
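An interval trigger fires every `period` units, where the unit is `'iteration'` or `'epoch'`. The sketch below assumes a fixed number of iterations per epoch and a check after each completed iteration; `interval_fires` is a hypothetical helper, not the library's implementation.

```python
def interval_fires(iteration, period, unit, iters_per_epoch):
    """Return True if an interval trigger would fire after `iteration`.

    `iteration` counts completed iterations, starting from 1.
    """
    if unit == "iteration":
        return iteration % period == 0
    if unit == "epoch":
        # Fire at the end of every `period`-th epoch.
        return iteration % (period * iters_per_epoch) == 0
    raise ValueError(f"unknown unit: {unit!r}")


fired = [i for i in range(1, 21) if interval_fires(i, 2, "epoch", 5)]
print(fired)  # [10, 20]
```

Throughout pytorch-pfn-extras, a `(period, unit)` tuple such as `(1, 'epoch')` is commonly accepted where a trigger is expected and is treated as an interval trigger.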

training.triggers.ManualScheduleTrigger(...)

Trigger invoked at specified point(s) of iterations or epochs.
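A manual schedule fires only at the listed points. The hypothetical helper below assumes the points are given in a single unit and converts epochs to iterations with a fixed iterations-per-epoch count; it illustrates the behavior and is not the trigger's code.

```python
def manual_fires(iteration, points, unit, iters_per_epoch):
    """Hypothetical check: fire when the current position is a listed point."""
    if isinstance(points, int):
        points = [points]
    if unit == "iteration":
        return iteration in points
    if unit == "epoch":
        # Only the last iteration of an epoch completes that epoch.
        if iteration % iters_per_epoch != 0:
            return False
        return iteration // iters_per_epoch in points
    raise ValueError(f"unknown unit: {unit!r}")


fired = [i for i in range(1, 31) if manual_fires(i, [1, 3], "epoch", 10)]
print(fired)  # [10, 30]
```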

training.triggers.BestValueTrigger(key, compare)

Trigger invoked when a specific value becomes the best so far.

training.triggers.MaxValueTrigger(key[, trigger])

Trigger invoked when a specific value reaches a new maximum.

training.triggers.MinValueTrigger(key[, trigger])

Trigger invoked when a specific value reaches a new minimum.
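These triggers compare the monitored value at each check against the best value seen so far and fire only on improvement; maximum and minimum are the two common comparisons. In the hypothetical sketch below, the comparison receives `(best, new)`, and the check interval (the `trigger` argument in the signatures above) is ignored, so every value passed in is evaluated.

```python
class BestValueSketch:
    """Hypothetical best-value check driven by a comparison function."""

    def __init__(self, compare):
        self.compare = compare  # compare(best, new) -> True if new is better
        self.best = None

    def __call__(self, value):
        if self.best is None or self.compare(self.best, value):
            self.best = value
            return True
        return False


max_trigger = BestValueSketch(lambda best, new: new > best)
min_trigger = BestValueSketch(lambda best, new: new < best)

max_fired = [max_trigger(a) for a in [0.70, 0.75, 0.74, 0.80]]
min_fired = [min_trigger(v) for v in [0.5, 0.6, 0.4]]
print(max_fired)  # [True, True, False, True]
print(min_fired)  # [True, False, True]
```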

training.triggers.OnceTrigger([call_on_resume])

Trigger that fires only once, at the starting point of the iteration (optionally again when training resumes).

training.triggers.TimeTrigger(period)

Trigger based on a fixed time interval.
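A time-based trigger fires whenever at least `period` seconds have passed since its last scheduled firing. To keep the sketch deterministic, it takes the current time as an argument instead of reading a clock; `TimeTriggerSketch` is hypothetical.

```python
class TimeTriggerSketch:
    """Hypothetical trigger that fires every `period` seconds."""

    def __init__(self, period, start=0.0):
        self.period = period
        self.next_time = start + period

    def __call__(self, now):
        if now >= self.next_time:
            # Schedule the next firing relative to the planned time,
            # so occasional slow iterations do not accumulate drift.
            self.next_time += self.period
            return True
        return False


trigger = TimeTriggerSketch(period=10.0)
fired = [trigger(t) for t in [3.0, 9.9, 10.2, 15.0, 20.5]]
print(fired)  # [False, False, True, False, True]
```

In practice `now` would come from a monotonic clock such as `time.monotonic()`.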

Reporting

reporting.Reporter()

Object to which observed values are reported.

reporting.report(values[, observer])

Reports observed values with the current reporter object.

reporting.report_scope(observation)

Returns a report scope with the current reporter.
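The reporter pattern lets code deep inside a model record values without holding a reference to the logging machinery: a scope installs a dictionary as the current observation, and each report call writes into it. The stdlib sketch below mimics that flow with `contextvars`; `report_scope_sketch` and `report_sketch` are hypothetical stand-ins, not the library functions.

```python
import contextlib
import contextvars

# The observation dict that report calls currently write into, if any.
_current = contextvars.ContextVar("current_observation", default=None)


@contextlib.contextmanager
def report_scope_sketch(observation):
    token = _current.set(observation)
    try:
        yield observation
    finally:
        _current.reset(token)


def report_sketch(values):
    observation = _current.get()
    if observation is not None:  # Outside any scope, values are dropped.
        observation.update(values)


def training_step():
    # This code needs only report_sketch, not the observation itself.
    report_sketch({"train/loss": 0.25})


observation = {}
with report_scope_sketch(observation):
    training_step()
print(observation)  # {'train/loss': 0.25}
```

With the library, the corresponding call inside a training iteration is `ppe.reporting.report({'train/loss': loss.item()})`, with the extensions manager providing the scope.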

Logging

logging.get_logger(name)

Returns a child logger to be used by applications.

Profiler

profiler.TimeSummary.report(tag[, use_cuda])

Context manager to automatically report execution times.
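Timing a block with a context manager amounts to reading a monotonic clock on entry and exit and accumulating the difference under a tag. The sketch below shows only that idea with the standard library; `time_block` is hypothetical and does nothing GPU-specific (the `use_cuda` option in the signature above is not modeled).

```python
import contextlib
import time
from collections import defaultdict

elapsed = defaultdict(float)  # tag -> total seconds


@contextlib.contextmanager
def time_block(tag):
    start = time.perf_counter()
    try:
        yield
    finally:
        # Accumulate even if the block raises.
        elapsed[tag] += time.perf_counter() - start


with time_block("data-loading"):
    time.sleep(0.01)
print(elapsed["data-loading"] > 0)  # True
```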

profiler.clear_tracer()

Resets the status of the global tracer.

profiler.enable_global_trace(enable)

Enable or disable tracing for all the threads.

profiler.enable_thread_trace(enable)

Enable or disable tracing for the current thread.

profiler.get_tracer([tracer_cls])

Gets the current global tracer.

profiler.ChromeTracer([max_event_count, ...])

Tracer object that outputs a timeline in Chrome format.
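The Chrome trace format is a JSON document whose `traceEvents` list holds timed events; a "complete" event (`"ph": "X"`) carries a start timestamp `ts` and a duration `dur`, both in microseconds, and the file can be opened in `chrome://tracing` or Perfetto. The sketch below writes such a file with the standard library to show the shape of the data; it is not the output routine of `ChromeTracer`, whose exact fields may differ.

```python
import json
import os
import tempfile
import time


def record(events, name, start_s, end_s):
    # Complete ("X") event; timestamps are converted to microseconds.
    events.append({
        "name": name,
        "ph": "X",
        "ts": start_s * 1e6,
        "dur": (end_s - start_s) * 1e6,
        "pid": os.getpid(),
        "tid": 0,  # single-threaded example
    })


events = []
t0 = time.perf_counter()
sum(range(10000))  # some work to time
record(events, "sum", t0, time.perf_counter())

path = os.path.join(tempfile.mkdtemp(), "trace.json")
with open(path, "w") as f:
    json.dump({"traceEvents": events}, f)
print(path)
```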

profiler.TraceableDataset(dataset, tag[, tracer])

Utility class to trace a Dataset inside the DataLoader worker threads.

Distributed Training

nn.parallel.DistributedDataParallel(module)

Module for distributed data parallelism.

distributed.initialize_ompi_environment(*[, ...])

Initializes the torch.distributed environment with values taken from OpenMPI.
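When a job is launched with OpenMPI's `mpirun`, each process sees environment variables such as `OMPI_COMM_WORLD_SIZE`, `OMPI_COMM_WORLD_RANK`, and `OMPI_COMM_WORLD_LOCAL_RANK`, which carry the world size and ranks a torch.distributed setup needs. The sketch below only reads them; `read_ompi_env` is hypothetical and does not call `torch.distributed`, choose a backend, or configure the master address as the real function may.

```python
import os


def read_ompi_env(environ=os.environ):
    """Hypothetical helper: (world_size, rank, local_rank) from OpenMPI vars.

    Falls back to a single-process setup when not launched by mpirun.
    """
    world_size = int(environ.get("OMPI_COMM_WORLD_SIZE", "1"))
    rank = int(environ.get("OMPI_COMM_WORLD_RANK", "0"))
    local_rank = int(environ.get("OMPI_COMM_WORLD_LOCAL_RANK", "0"))
    return world_size, rank, local_rank


fake_env = {
    "OMPI_COMM_WORLD_SIZE": "4",
    "OMPI_COMM_WORLD_RANK": "2",
    "OMPI_COMM_WORLD_LOCAL_RANK": "0",
}
print(read_ompi_env(fake_env))  # (4, 2, 0)
```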

Checkpointing

utils.checkpoint

Lazy Modules

nn.Ensure(*[, shape, dtype, broadcastable, ...])

Module to check the shape of a tensor.

nn.ensure(tensor[, shape, dtype, ...])

Checks the shape and type of a tensor.
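A shape check of this kind compares each dimension against an expected size; treating `None` as "any size" is a common convention for dimensions such as the batch. The pure-Python sketch below works on shape tuples only; `check_shape` is hypothetical and ignores the dtype, broadcasting, and other options listed in the signatures above.

```python
def check_shape(actual, expected):
    """Hypothetical check: raise if `actual` does not match `expected`.

    `None` in `expected` matches any size in that dimension.
    """
    if len(actual) != len(expected):
        raise ValueError(f"expected {len(expected)} dims, got {len(actual)}")
    for dim, (a, e) in enumerate(zip(actual, expected)):
        if e is not None and a != e:
            raise ValueError(f"dim {dim}: expected {e}, got {a}")


check_shape((32, 3, 224, 224), (None, 3, 224, 224))  # passes silently
try:
    check_shape((32, 1, 28, 28), (None, 3, 28, 28))
except ValueError as err:
    message = str(err)
print(message)  # dim 1: expected 3, got 1
```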

nn.LazyLinear(in_features, *args, **kwargs)

Linear module with lazy weight initialization.
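Lazy initialization defers creating the weights until the first forward pass, when the input size becomes known; for `nn.LazyLinear`, `in_features` is the size that the signature above lets you defer. The pure-Python sketch below shows the idea with nested lists; `LazyLinearSketch` is hypothetical, omits the bias for brevity, and is not the module's implementation.

```python
import random


class LazyLinearSketch:
    """Hypothetical linear layer whose weights are created on first call."""

    def __init__(self, out_features, seed=0):
        self.out_features = out_features
        self.weight = None  # (out_features x in_features) once known
        self._rng = random.Random(seed)

    def __call__(self, x):
        if self.weight is None:
            in_features = len(x)  # inferred from the first input
            self.weight = [
                [self._rng.uniform(-0.1, 0.1) for _ in range(in_features)]
                for _ in range(self.out_features)
            ]
        # Matrix-vector product; no bias term in this sketch.
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weight]


layer = LazyLinearSketch(out_features=2)
y = layer([1.0, 2.0, 3.0])
print(len(y), len(layer.weight[0]))  # 2 3
```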

nn.LazyConv1d(in_channels, *args, **kwargs)

Conv1d module with lazy weight initialization.

nn.LazyConv2d(in_channels, *args, **kwargs)

Conv2d module with lazy weight initialization.

nn.LazyConv3d(in_channels, *args, **kwargs)

Conv3d module with lazy weight initialization.

nn.LazyBatchNorm1d(num_features, *args, **kwargs)

BatchNorm1d module with lazy weight initialization.

nn.LazyBatchNorm2d(num_features, *args, **kwargs)

BatchNorm2d module with lazy weight initialization.

nn.LazyBatchNorm3d(num_features, *args, **kwargs)

BatchNorm3d module with lazy weight initialization.

ONNX

Export

onnx.export(model, args, f[, return_output, ...])

Export model into ONNX Graph.

onnx.export_testcase(model, args, out_dir, *)

Export model and I/O tensors of the model in protobuf format.

Annotation

onnx.annotate(**attrs)

Annotates the target function with the given parameters.

onnx.apply_annotation(fn, *args, **attrs)

Applies annotations to the target function.

onnx.scoped_anchor(**attrs)

Adds anchor nodes to the scoped modules.

Datasets

dataset.SharedDataset(sm_size[, cache_type])

Dataset that caches the loaded samples in shared memory.

dataset.TabularDataset(*args, **kwds)

An abstract class that represents a tabular dataset.

dataset.ItemNotFoundException

Config

config.Config(config[, types])

config_types.optuna_types(trial)

config_types.load_path_with_optuna_types(...)

NumPy/CuPy Compatibility

from_ndarray(ndarray)

Creates a torch.Tensor from a numpy.ndarray or cupy.ndarray.

as_ndarray(tensor)

Creates a numpy.ndarray or cupy.ndarray from torch.Tensor.

get_xp(obj)

Returns the ndarray implementation module (numpy or cupy) for the given obj.

as_numpy_dtype(torch_dtype)

Returns NumPy dtype for the given PyTorch dtype.

from_numpy_dtype(numpy_dtype)

Returns PyTorch dtype for the given NumPy dtype.

cuda.stream(stream)

Context manager that selects a given stream.

cuda.use_torch_mempool_in_cupy()

Use the PyTorch memory pool in CuPy.

cuda.use_default_mempool_in_cupy()

Use the default memory pool in CuPy.