pytorch-pfn-extras

pytorch-pfn-extras (PPE) is a collection of supplementary components to accelerate research and development in PyTorch.

User Guide

Trainer (technical preview)

Trainer and Evaluator

Note

The Trainer/Evaluator APIs are currently under technical preview and may be subject to change in future versions.

The Trainer and Evaluator provide a device-agnostic training framework for PyTorch. These APIs abstract the training process using different runtimes, handlers, and logics.

Concepts
  • Trainer (ppe.engine.create_trainer()) abstracts the training loop, built on top of the ExtensionsManager.

  • Evaluator (ppe.engine.create_evaluator()) abstracts the evaluation step and is invoked from the Trainer (usually once every epoch).

  • Runtime (ppe.runtime.BaseRuntime) represents an environment used to execute models. Device-specific implementations will reside here. PPE provides the default Runtime that supports the PyTorch-native devices (ppe.runtime.PyTorchRuntime).

  • Handler (ppe.handler.Handler) is a layer to support device-agnostic training. This is considered a low-level API, and in most cases users can just use the Handler provided by PPE.

  • Logic (ppe.handler.Logic) is a set of callback functions that define the training logic (optimizer.zero_grad(), forward, backward, optimizer.step()). You can inherit the class and define your own training flow in case you need more complex training processes such as GANs.

  • Model is a torch.nn.Module used for training and evaluation, whose inputs are dicts or keyword arguments and whose forward pass output is a dict.

Note that the default logic will perform backward on the tensors returned by model.forward, so you will need to perform the loss calculation inside the model itself.

Trainer at a glance
import torch
import torch.nn.functional as F

import pytorch_pfn_extras as ppe


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.LazyLinear(1)

    def forward(self, *, x, target):
        y = self.w(x)
        loss = F.nll_loss(y, target)
        prefix = 'train' if self.training else 'val'
        ppe.reporting.report({f'{prefix}/loss': loss.item()})
        return {'loss': loss}


model = MyModel()
optim = torch.optim.SGD(model.parameters(), lr=0.01)

extensions = [
    ppe.training.extensions.LogReport(),
    ppe.training.extensions.ProgressBar(),
    ppe.training.extensions.PrintReport(
        ['epoch', 'iteration', 'train/loss', 'val/loss']),
]

device = 'cuda:0'  # or any other PyTorch devices ('cpu', etc.) or PPE runtime names
epochs = 10
trainer = ppe.engine.create_trainer(
    model,
    optim,
    epochs,
    evaluator=ppe.engine.create_evaluator(
        model,
        device=device,
        progress_bar=True,
    ),
    device=device,
    extensions=extensions,
)

# Move the model to the device. This is almost equivalent to
# `model.to(device)`, but supports PPE runtimes as well as the PyTorch's
# built-in devices.
ppe.to(model, device)

# Using dummy data to illustrate the minimal working example.
# Notice that dict keys match with the kwargs of the forward method.
train_loader = torch.utils.data.DataLoader(
    [{'x': torch.rand(10, 64), 'target': torch.tensor([1])} for _ in range(1)],
    num_workers=8)
val_loader = torch.utils.data.DataLoader(
    [{'x': torch.rand(10, 64), 'target': torch.tensor([1])} for _ in range(1)],
    num_workers=8)

trainer.run(train_loader, val_loader)
Snapshot

To obtain and save the trained model for later use you can use the Snapshot extension, or directly invoke state_dict on the trainer itself.
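
A minimal sketch of both approaches (the snapshot extension is described later in this guide):

# Option 1: include the snapshot extension in the `extensions` list passed to
# `create_trainer`; it saves the trainer state whenever its trigger fires.
extensions.append(ppe.training.extensions.snapshot())

# Option 2: dump the state manually after training (the file name is illustrative).
trainer.run(train_loader, val_loader)
torch.save(trainer.state_dict(), 'trainer_checkpoint.pt')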

Handler

The ppe.handler.Handler object helps the trainer and evaluator objects manipulate the Logic and Runtime. This class should ideally never be overridden by the user if the desired functionality can be achieved through subclassing BaseLogic or BaseRuntime.

The handler object’s main responsibility is to inspect all the submodules of a module to obtain the runtimes they have associated, and then execute their callbacks accordingly. In addition, it drives the actual model execution by using the user provided Logic object and deals with asynchronous execution in runtimes that provide support for it.

Runtime

By inheriting ppe.runtime.BaseRuntime and implementing your own runtime, you can use your non-standard devices with the training loop.

class MyRuntime(BaseRuntime):
    ...

# Register MyRuntime with device name "mydev"
ppe.runtime.runtime_registry.register('mydev', MyRuntime)

ppe.to(module_or_tensor, 'mydev')

See Runtimes for Custom Devices if you are interested in implementing your own runtime.

Logic for Custom Training and Evaluation

In the training and evaluation engines, ppe.handler.BaseLogic API is in charge of abstracting the algorithmic details of the training and evaluation loops.

Logic is an object that defines multiple callbacks used through the training and evaluation processes. With logic, we can implement training of complex models such as GANs.

Users wanting to define their own Logic for training can inherit from ppe.handler.Logic which implements the training and evaluation steps to train a single module.

Logic functions are not expected to be called directly by the user. They will be invoked by the Trainer and Evaluator engines.

Default Logic (ppe.handler.Logic)

PPE provides a default logic that performs the forward/backward/optimizer-step loop for a single model. This logic allows using torch features such as AMP autocast and GradScaler, and performs the backward pass on the outputs specified by the config option backward_outputs.
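
For example, to run backward only on a specific entry of the dict returned by model.forward, you can configure the default logic with the backward_outputs option. The sketch below is hedged: passing a Logic instance through the logic keyword of create_trainer and the exact option format are assumptions, so check the API Reference for the precise signatures.

import pytorch_pfn_extras as ppe

# Hedged sketch: run backward only on the 'loss' entry of the output dict.
# `backward_outputs` is the option named above; `logic=` is assumed to be
# accepted by `create_trainer`.
logic = ppe.handler.Logic(options={'backward_outputs': ['loss']})
trainer = ppe.engine.create_trainer(
    model, optim, epochs,
    logic=logic,
    device=device,
    extensions=extensions,
)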

Runtimes for Custom Devices

Note

This documentation is intended for those implementing their own device backends for the PPE training framework. Most users can skip this chapter.

The ppe.runtime.BaseRuntime API is in charge of abstracting the device details and performing the movement of data and modules to the corresponding device.

A runtime is an object that defines multiple callbacks used through the training, evaluation, and regular model calls. With runtimes, we can implement training in devices other than cpus or gpus with minimal changes to the user code.

Users wanting to override only a few callbacks can inherit from ppe.runtime.PyTorchRuntime which implements the basic functionality for cpu and gpu devices.

Runtimes must be registered by calling the ppe.runtime.runtime_registry.register(device_name, runtime_class) function for them to be discoverable.

Use of ppe.to to transfer modules and batches to custom devices

If you have defined a new runtime for a custom device, the ppe.to function allows moving a module or a tensor to the new device by invoking Runtime.move_tensor and Runtime.move_module when needed.

The module will be tagged by adding an attribute named _ppe_runtime that holds the needed runtime. It is the responsibility of the user's custom runtime to perform the actual movement to the device and apply all the transformations needed so the module can be correctly executed.

Usually, runtime writers will need to replace the given module's forward function with a new one that performs the actual device execution.

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

class MyMagicDeviceRuntime(ppe.runtime.BaseRuntime):
    def _device_forward(self, args):
        return run_batch_in_my_device(args)

    def move_module(self, module):
        # Registers a hook to initialize the module on the first batch
        # execution
        def hook(module, *args):
            module._ppe_runtime.initialize_module(module, args)

        self.hook = module.register_forward_pre_hook(hook)
        # Change the module forward to do the computation in the device
        module.forward = self._device_forward

    def initialize_module(self, module, loader_or_batch, optimizer=None):
        create_the_module_in_my_device(module, loader_or_batch, optimizer)

# Register the runtime class
ppe.runtime.runtime_registry.register('my_device', MyMagicDeviceRuntime)

# Create a regular module
module = MyModule()
# Move the module to the device
ppe.to(module, device='my_device')

for x in my_dataloader:
    # The first iteration will create the module in the device
    # and the next ones will directly execute the module in the device instead
    # of executing the regular pytorch `forward` call.
    y = module(x)

Please note that this is an oversimplified description and that developing a runtime that is 100% compatible with PyTorch requires wrapping the substitute forward function with torch.autograd.Function, among several other concerns such as state_dict manipulation, to ensure correctness.

Extensions

Extensions Manager

Extensions Manager provides an interface to extend your training loop, whether it is a manual training loop or one driven by Ignite.

Extensions

See the API Reference for the list of built-in extensions.

How to use

Create an ExtensionsManager object and then wrap the iteration of your training loop inside the manager.run_iteration() context manager.

An example follows:

import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions

import time
import math

max_epoch = 10
iters_per_epoch = 938

# manager.extend(...) also works
my_extensions = [extensions.LogReport(),
                 extensions.ProgressBar(),
                 extensions.PrintReport(['epoch', 'iteration', 'sin', 'cos'])]

models = {}
optimizers = []
manager = ppe.training.ExtensionsManager(
    models, optimizers, max_epoch,
    extensions=my_extensions,
    iters_per_epoch=iters_per_epoch)

for epoch in range(max_epoch):
    for i in range(iters_per_epoch):
        with manager.run_iteration():
            ppe.reporting.report({
                'sin': math.sin(i * 2 * math.pi / iters_per_epoch),
                'cos': math.cos(i * 2 * math.pi / iters_per_epoch),
            })
            time.sleep(0.001)

In the examples folder there is an MNIST example using all the available extensions.

Usage with Ignite

Ignite is supported by using the IgniteExtensionsManager with the trainer as the first argument.

The user needs to define an Ignite event handler to report the appropriate metrics for the extensions to use.

manager = ppe.training.IgniteExtensionsManager(
    trainer, models, optimizers, epochs,
    extensions=my_extensions)

@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
    ppe.reporting.report({'train/loss': engine.state.output})
Using Evaluators
Regular PyTorch

In order to report the results of the evaluation so they can be accessed by other extensions, an Evaluator extension needs to be created with the argument eval_func set to a function that gets the current data and target batches as parameters and reports the needed metrics.

The test function has the following signature

def test(args, model, device, data, target):

and is invoked once per batch of the validation dataloader. It is important to report the current validation loss or accuracy so that the log report can see it.

def test(args, model, device, data, target):
    ...
    # Final result will be average of averages of the same size
    test_loss += F.nll_loss(output, target, reduction='mean').item()
    ppe.reporting.report({'val/loss': test_loss})
    pred = output.argmax(dim=1, keepdim=True)
    correct += pred.eq(target.view_as(pred)).sum().item()
    ppe.reporting.report({'val/acc': correct/len(data)})
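
A hedged sketch of registering such an extension (the argument order, validation loader first and then the model, as well as the progress_bar option, are assumptions; see the Evaluator entry in the API Reference):

from pytorch_pfn_extras.training import extensions

my_extensions = [
    extensions.Evaluator(
        val_loader, model,
        eval_func=lambda data, target: test(args, model, device, data, target),
        progress_bar=True),
    extensions.LogReport(),
    extensions.PrintReport(['epoch', 'iteration', 'val/loss', 'val/acc']),
]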
Ignite

Just use the IgniteEvaluator extension with the Ignite-created evaluator as the first parameter and you are ready to go. The metrics defined when creating the evaluator with create_supervised_evaluator will be automatically reported.

create_supervised_evaluator(model, metrics={'acc': Accuracy(), 'loss': Loss(F.nll_loss)}, device=device)
Snapshots

It is possible to take snapshots by using the snapshot training extension just as in Chainer.

Whenever the extension is triggered, it saves the state of the optimizer, model, and extensions to the output folder in the same way as Chainer. To load the snapshot and continue training, call torch.load and use ExtensionsManager.load_state_dict to resume. The snapshots can be used outside the pytorch-pfn-extras module just by accessing the models or optimizers fields of the loaded state.

Extensions execution order

The supported extensions honour the Chainer priorities for execution. However, when using Ignite, Chainer-style extensions are executed after any user-defined Ignite events. The idea is to use Ignite events to report the metrics of the model, and after this, the extensions will be executed in the Chainer-defined order.

If you want to execute an event handler in between the extensions, create a Chainer-like extension and access the Ignite engine through the .engine attribute of the manager object passed as a parameter when your extension is called, as shown below.
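
A minimal sketch (the trigger value is illustrative):

def my_ignite_aware_extension(manager):
    # When using IgniteExtensionsManager, the Ignite engine is available
    # through the manager's `engine` attribute.
    engine = manager.engine
    print(engine.state.epoch, engine.state.iteration)

manager.extend(my_ignite_aware_extension, trigger=(1, 'iteration'))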

Creating Extensions

It is possible to create an extension just by passing a function which receives the manager object as an argument to the manager's extend call:

def my_extension(manager):
    print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))

manager.extend(my_extension, trigger=(1, 'iteration'))

It is also possible to create extensions using the ppe.training.extension.make_extension decorator to add a specific trigger, default_name, or priority. In addition, initializer, finalizer, and on_error functions can be specified as well.

@ppe.training.extension.make_extension(finalizer=lambda: print('done'))
def my_extension(manager):
    print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))

Finally, it is possible to create an extension by subclassing the ppe.training.extension.Extension class as shown below.

import pytorch_pfn_extras as ppe

class MyExtension(ppe.training.extension.Extension):
    def __init__(self, args):
        self.args = args

    def initialize(self, manager):
        """
        Automatically called before training. Optional.
        """
        pass

    def __call__(self, manager):
        """
        Called when the associated trigger is fired.
        """
        print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))

    def state_dict(self):
        """ 
        Used to serialize the state. Optional.
        """
        return {'args': self.args}

    def load_state_dict(self, state):
        """ 
        Used to deserialize the state. Optional.
        """
        self.args = state['args']

Reporting

reporting.Reporter is used to collect values that users want to watch. The reporter object holds a mapping from value names to the actually observed values. We call this mapping observations.

When a value is passed to the reporter, an object called observer can be optionally attached. In this case, the name of the observer is added as the prefix of the value name. The observer name should be registered beforehand.

import pytorch_pfn_extras as ppe

reporter = ppe.reporting.Reporter()
observer = object()
reporter.add_observer('my_observer', observer)
observation = {}

with reporter.scope(observation):
    reporter.report({'x': 1}, observer)

print(observation)
# outputs: {'my_observer/x': 1}

There is also a global API to add values:

import pytorch_pfn_extras as ppe

reporter = ppe.reporting.Reporter()
observer = object()
reporter.add_observer('my_observer', observer)

observation = {}
with reporter:
    with ppe.reporting.report_scope(observation):
         ppe.reporting.report({'x': 1}, observer)

print(observation)
# outputs: {'my_observer/x': 1}

The most important application of Reporter is to report observed values from different parts of the model in the training and validation procedures. ExtensionsManager objects hold their own Reporter object with the parameters of the target module registered as observers. report() can be used inside the modules to report the observed values (e.g., training loss, accuracy, activation statistics, etc.).

Distributed Snapshot

To take snapshots when using torch.distributed, the only needed step is to provide the saver_rank keyword argument to the regular snapshot extension.

# saver_rank is the MPI rank which will write the actual snapshot.
snapshot = extensions.snapshot(saver_rank=saver_rank)

To resume the training, snapshots are loaded in every worker by using the ExtensionsManager.load_state_dict method, or the extensions.snapshot autoload keyword argument.
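
For example (the saver_rank value and trigger are illustrative):

# Rank 0 writes the snapshot; with autoload=True every worker loads the
# latest snapshot automatically when the manager starts.
snapshot = extensions.snapshot(saver_rank=0, autoload=True)
manager.extend(snapshot, trigger=(10, 'iteration'))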

Utilities

Lazy Modules

Lazy modules can automatically infer shapes of parameters based on the shape of the data given to the first forward invocation.

The following modules are provided:

  • ppe.nn.LazyBatchNorm1d, ppe.nn.LazyBatchNorm2d, ppe.nn.LazyBatchNorm3d

    • Module that behaves as torch.nn.BatchNorm[123]d but num_features can be set to None.

    • These modules are now included as part of the PyTorch 1.9 release (torch.nn.LazyBatchNormXd).

The following modules are now considered deprecated, as they are included as part of the PyTorch 1.8 release:

  • ppe.nn.LazyLinear

  • ppe.nn.LazyConv1d, ppe.nn.LazyConv2d, ppe.nn.LazyConv3d

Now that all lazy modules have been merged into upstream PyTorch, we encourage you to migrate to PyTorch's lazy modules. We will keep these implementations only for backward compatibility.

Note that you need to run a “dummy” forward to initialize lazy parameters. See the example below:

import torch
import torch.nn.functional as F

import pytorch_pfn_extras as ppe


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = ppe.nn.LazyConv2d(None, 20, 5, 1)
        self.conv2 = ppe.nn.LazyConv2d(None, 50, 5, 1)
        self.fc1 = ppe.nn.LazyLinear(None, 500)
        self.fc2 = ppe.nn.LazyLinear(None, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()

# Initialize lazy parameters.
dummy_input = ...
model(dummy_input)

# Pass parameters to the optimizer.
optimizer = torch.optim.SGD(
    model.parameters(), lr=args.lr, momentum=args.momentum)

# Run training loop.
# ...

You need to run a dummy forward before passing parameters to optimizers; otherwise optimizers cannot refer to lazily-initialized parameters. You will get a warning if you pass uninitialized lazy parameters to optimizers:

>>> model = ppe.nn.LazyLinear(None, 10)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
/.../pytorch-pfn-extras/pytorch_pfn_extras/nn/modules/lazy.py:127: UserWarning:
    Use of uninitialized lazy parameter in Optimizer has been detected.
    Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?

Config

Basic
from pytorch_pfn_extras.config import Config
import yaml
pre_eval_config = yaml.safe_load('''
foo:
  bar: 'bar_value'
  ls:
    - 'first'
    - key0: 'value0'
      key1: 'value1'
baz: 'baz_value'
''')
config = Config(pre_eval_config)

Accessing config values:

print(config['/foo/ls/0'])
# 'first'
print(config['/foo/ls/1/key0'])
# 'value0'
print(config['/foo/ls'])
# ['first', {'key0': 'value0', 'key1': 'value1'}]
print(config['/baz'])
# 'baz_value'
Substitution
Callable Substitution

You can replace a value with the return value of a callable.

  • types is an additional input to Config. types is a mapping from a callable’s name to the actual callable.

  • A sub-dictionary containing the key type invokes callable substitution.

pre_eval_config = yaml.safe_load('''
name:
  type: concat
  x0: 'First'
  x1: 'Last'
''')

types = {
  'concat': lambda x0, x1: x0 + ' ' + x1
}

config = Config(pre_eval_config, types)
# the value returned by
# concat(x0='First', x1='Last')
print(config['/name'])
# 'First Last'
Nested
pre_eval_config = yaml.safe_load('''
name:
  type: concat
  x0: 'First'
  x1:
    type: concat
    x0: 'Middle'
    x1: 'Last'
''')
types = {
  'concat': lambda x0, x1: x0 + ' ' + x1
}
config = Config(pre_eval_config, types)
print(config['/name'])
# First Middle Last
Class
pre_eval_config = yaml.safe_load('''
dataset:
  type: Dataset
  n_class: 10
''')

class Dataset(object):

    def __init__(self, n_class):
        self.n_class = n_class

types = {
  'Dataset': Dataset,
}

config = Config(pre_eval_config, types)
print(isinstance(config['/dataset'], Dataset))
# True
Substitution by Path
Absolute

@/absolute/path is replaced by the value at /absolute/path.

pre_eval_config = yaml.safe_load('''
foo: 'FOO'
boo:
  baz: '@/foo'
''')
config = Config(pre_eval_config)
print(config['/boo/baz'])
# FOO
Relative

Relative paths are also possible using @relative/path.

pre_eval_config = yaml.safe_load('''
foo: 'FOO'
boo:
  baz: '@../foo'
''')
config = Config(pre_eval_config)
print(config['/boo/baz'])
# FOO
Substitution by Attribute

@/path/to/obj.attr_name is replaced by:

  1. Use substitution by path to get an object at /path/to/obj.

  2. Replace the config value by getattr(obj, attr_name), where obj is obtained at step 1.

pre_eval_config = yaml.safe_load('''
dataset:
  type: Dataset
  n_class: 10
n_data: '@/dataset.n_data'
''')

class Dataset(object):

    def __init__(self, n_class):
        self.n_class = n_class
        self.n_data = 4

types = {
  'Dataset': Dataset,
}

config = Config(pre_eval_config, types)
print(config['/n_data'])
# 4
Default Value by Path Substitution

customize_type is a decorator that sets default argument values by path substitution.

from pytorch_pfn_extras.config import customize_type

pre_eval_config = yaml.safe_load('''
dataset:
  type: Dataset
n_class: 5
''')

# If n_class is not passed, the value would be config['/n_class'].
# Both absolute and relative paths are allowed.
@customize_type(n_class='/n_class')
class Dataset(object):

    def __init__(self, n_class):
        self.n_class = n_class

types = {
  'Dataset': Dataset,
}

config = Config(pre_eval_config, types)
print(config['/dataset'].n_class)
# 5
Ignore Substitution

Access using config['!/path'] instead of config['/path'].

pre_eval_config = yaml.safe_load('''
name:
  type: concat
  x0: 'First'
  x1: 'Last'
''')

types = {
  'concat': lambda x0, x1: x0 + ' ' + x1
}

config = Config(pre_eval_config, types)
print(config['!/name'])
# {'type': 'concat', 'x0': 'First', 'x1': 'Last'}
Lazy Evaluation

Callable substitution is lazily executed. This means that callables that are not dependent on the accessed value do not get executed.

pre_eval_config = yaml.safe_load('''
foo:
  - type: f0
  - '@/bar'
bar:
  type: f1
baz:
  type: f2
''')

def f0():
    print('f0 called')
    return 'f0_return'

def f1():
    print('f1 called')
    return 'f1_return'

def f2():
    print('f2 called')
    return 'f2_return'

types = {
  'f0': f0,
  'f1': f1,
  'f2': f2,
}

config = Config(pre_eval_config, types)
config['/foo']  # f2 does not get called
# f0 called
# f1 called

pytorch_pfn_extras.onnx

Extensions to torch.onnx.export.

Installation
pip3 install "pytorch-pfn-extras[onnx]"

Or

  1. Install pytorch-pfn-extras normally

  2. Install onnx with pip install onnx==1.7.0

API
pytorch_pfn_extras.onnx.export_testcase

Instead of specifying a file name as in torch.onnx.export, pytorch_pfn_extras.onnx.export_testcase specifies a directory to output the ONNX model and test case inputs/outputs.

import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))

import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output')

The following directory structure will be generated in /path/to/output:

$ tree /path/to/output
/path/to/output
├── meta.json
├── model.onnx
└── test_data_set_0
    ├── input_0.pb
    └── output_0.pb
  • This directory structure format is inspired by the ONNX official test data set (example: node). PyTorch’s ONNX tests use this format too (reference: export_onnx_tests_generator.py).

    • There are scripts in chainer-compiler/utils to run inference in major runtimes with this directory structure. For example, to run inference with ONNX Runtime, run $ python run_onnx_onnxruntime.py /path/to/output to use input_N.pb as input and compare numerically with its output output_N.pb (N is the index of the test case).

  • By default meta.json is also generated to track git information, timestamps, etc. Add the metadata=False argument to suppress this.

out_grad option

If out_grad=True is specified, gradients will be dumped too, which is useful for debugging backward passes. gradient_N.pb and gradient_input_N.pb will be dumped to the test case directory along with the input/output data. gradient_input_N.pb is the initial value of the backward pass; its default value is a ones tensor with the same shape as the output. Pass a torch.Tensor to out_grad to specify a custom initial value.

model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))

import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output', out_grad=True)
$ tree /path/to/output
/path/to/output
├── meta.json
├── model.onnx
└── test_data_set_0
    ├── gradient_0.pb
    ├── gradient_input_0.pb
    ├── input_0.pb
    └── output_0.pb
model_overwrite option

Use the model_overwrite option to create multiple test data sets as follows:

import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x1, '/path/to/output')
tou.export_testcase(model, x2, '/path/to/output', model_overwrite=False)

The following test cases are generated by the above. test_data_set_0 contains the input x1 and its output, and test_data_set_1 contains the input x2 and its output.

$ tree /path/to/output
├── meta.json
├── model.onnx
├── test_data_set_0
│   ├── input_0.pb
│   └── output_0.pb
└── test_data_set_1
    ├── input_0.pb
    └── output_0.pb
strip_large_tensor_data option

This option strips large tensors from the dumped files, which is useful to reduce file sizes for use cases such as benchmarking. Not only model.onnx but also the input/output and gradient data are affected. large_tensor_threshold can be used to specify the size threshold of a large tensor.

import torch
import torchvision
model = torchvision.models.resnet50(pretrained=True)
x = torch.zeros((1, 3, 224, 224))

import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output')
tou.export_testcase(model, x, '/path/to/output2', strip_large_tensor_data=True)
$ ls -lh /path/to/output/model.onnx
-rwxrwxrwx 1 user user 98M Jun 24 23:34 /path/to/output/model.onnx
$ ls -lh /path/to/output2/model.onnx
-rwxrwxrwx 1 user user 64K Jun 24 23:34 /path/to/output2/model.onnx

This feature can also be invoked from the CLI:

$ python -m pytorch_pfn_extras.onnx.strip_large_tensor resnet50.onnx --out_onnx_path resnet50_slim.onnx
$ ls -lh
-rwxrwxrwx 1 user user 98M Jun 30 09:13 resnet50.onnx
-rwxrwxrwx 1 user user 64K Jun 30 09:16 resnet50_slim.onnx

See $ python -m pytorch_pfn_extras.onnx.strip_large_tensor -h for help

Notes:

If an ONNX runtime does not support tensors without raw_data, unstrip_tensor.py can restore the stripped tensor data. See $ python -m pytorch_pfn_extras.onnx.unstrip_tensor -h for help.

pytorch_pfn_extras.onnx.export

A function with the same interface as torch.onnx.export. Unlike torch.onnx.export, you can use the annotation feature (described below), the strip_large_tensor_data option, and other torch.onnx extensions.

  • strip_large_tensor_data: Same as in export_testcase. Useful for reducing file sizes.

  • return_output: Returns the output value of the model execution. Note: the output type is usually torch.Tensor (not onnx.TensorProto).

import io, onnx
import torch
import torch.nn as nn
import pytorch_pfn_extras.onnx

model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))

bytesio = io.BytesIO()
pytorch_pfn_extras.onnx.export(model, x, bytesio)
onnx_proto = onnx.load(io.BytesIO(bytesio.getvalue()))
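
A hedged sketch of the options above (assuming return_output=True makes export also return the traced model's outputs):

bytesio = io.BytesIO()
# Strip large initializers from the serialized graph and keep the outputs of
# the traced run.
out = pytorch_pfn_extras.onnx.export(
    model, x, bytesio,
    return_output=True,
    strip_large_tensor_data=True)
print(type(out))  # usually torch.Tensor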
annotate

A feature to add custom ONNX attributes to a specified nn.Module.

Notes:

  • The annotated ONNX would be in an invalid ONNX format that does not pass the onnx.checker.check_model check.

  • Only valid with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.

  • For modules that map to multiple ONNX nodes, such as nn.Linear and nn.GroupNorm, only the first ONNX node will be annotated.

    • For example, nn.Linear with bias is split into a MatMul -> Add graph. Only MatMul would be annotated. The same applies to apply_annotation (described later).

  • Use apply_annotation instead when the annotation target isn’t nn.Module.

import os

import onnx
import torch
import torch.nn as nn

import pytorch_pfn_extras.onnx as tou

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.annotate(key='value'):
            h = self.conv2(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
  %input.1[FLOAT, 1x6x32x32]
) initializers (
  %17[FLOAT, 28x20]
  %18[FLOAT, 20x15]
  %conv.bias[FLOAT, 9]
  %conv.weight[FLOAT, 9x6x3x3]
  %conv2.bias[FLOAT, 12]
  %conv2.weight[FLOAT, 12x9x3x3]
  %linear.bias[FLOAT, 20]
  %linear2.bias[FLOAT, 15]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
  %10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
  %12 = MatMul[key = 'value'](%10, %17)
  %13 = Add(%12, %linear.bias)
  %15 = MatMul(%13, %18)
  %16 = Add(%15, %linear2.bias)
  return %16
}

In the above example, %10 = Conv and %12 = MatMul have the key='value' attribute annotated.

apply_annotation

This annotates a function call instead of using the with statement.

The target of annotate is nn.Module, so torch.nn.functional calls cannot be annotated:

import os

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_pfn_extras.onnx as tou

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.annotate(key='value'):
            h = self.conv2(h)
            h = F.relu(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
  %input.1[FLOAT, 1x6x32x32]
) initializers (
  %18[FLOAT, 28x20]
  %19[FLOAT, 20x15]
  %conv.bias[FLOAT, 9]
  %conv.weight[FLOAT, 9x6x3x3]
  %conv2.bias[FLOAT, 12]
  %conv2.weight[FLOAT, 12x9x3x3]
  %linear.bias[FLOAT, 20]
  %linear2.bias[FLOAT, 15]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
  %10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
  %11 = Relu(%10)
  %13 = MatMul[key = 'value'](%11, %18)
  %14 = Add(%13, %linear.bias)
  %16 = MatMul(%14, %19)
  %17 = Add(%16, %linear2.bias)
  return %17
}

%10 = Conv and %13 = MatMul have the key='value' attribute, but %11 = Relu does not. By using apply_annotation, all nodes in the function are annotated.

import os

import onnx
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_pfn_extras.onnx as tou

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        def _f(x):
            h = self.conv2(x)
            h = F.relu(h)
            h = self.linear(h)
            return h
        h = tou.apply_annotation(_f, h, key='value')
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
  %input.1[FLOAT, 1x6x32x32]
) initializers (
  %18[FLOAT, 28x20]
  %19[FLOAT, 20x15]
  %conv.bias[FLOAT, 9]
  %conv.weight[FLOAT, 9x6x3x3]
  %conv2.bias[FLOAT, 12]
  %conv2.weight[FLOAT, 12x9x3x3]
  %linear.bias[FLOAT, 20]
  %linear2.bias[FLOAT, 15]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
  %10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
  %11 = Relu[key = 'value'](%10)
  %13 = MatMul[key = 'value'](%11, %18)
  %14 = Add(%13, %linear.bias)
  %16 = MatMul(%14, %19)
  %17 = Add(%16, %linear2.bias)
  return %17
}

Now %11 = Relu is annotated with the key='value' attribute too.

scoped_anchor

This annotates the beginning and end of a scope of one or more modules by adding Anchor nodes. The nodes are named Anchor_N_start or Anchor_N_end (N is an index) and have op_type Identity.

  • Adding custom parameters adds ONNX attributes, which makes the generated ONNX fail the checker.

  • Use this with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.

  • When the scope has multiple inputs/outputs, only the first input/output gets an Anchor node.

  • The N in the node name is the index of the matching start/end Anchor pair, e.g., Anchor_0_start, Anchor_0_end.

import os

import onnx
import torch
import torch.nn as nn

import pytorch_pfn_extras.onnx as tou

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.scoped_anchor(key='value'):
            h = self.conv2(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
out_dir = tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
  %input.1[FLOAT, 1x6x32x32]
) initializers (
  %23[FLOAT, 28x20]
  %24[FLOAT, 20x15]
  %conv.bias[FLOAT, 9]
  %conv.weight[FLOAT, 9x6x3x3]
  %conv2.bias[FLOAT, 12]
  %conv2.weight[FLOAT, 12x9x3x3]
  %linear.bias[FLOAT, 20]
  %linear2.bias[FLOAT, 15]
) {
  %9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
  %11 = Identity[key = 'value'](%9)
  %12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%11, %conv2.weight, %conv2.bias)
  %16 = MatMul(%12, %23)
  %17 = Add(%16, %linear.bias)
  %19 = Identity[key = 'value'](%17)
  %21 = MatMul(%19, %24)
  %22 = Add(%21, %linear2.bias)
  return %22
}

%11 = Identity (node name = Anchor_0_start) and %19 = Identity (node name = Anchor_0_end) are added. key='value' is added as an ONNX attribute.

non-nn.Module

The target of the scope must be an nn.Module. If the scope boundary does not match an nn.Module, you can add a sub nn.Module instead.

import os

import onnx
import torch
import torch.nn as nn

import pytorch_pfn_extras.onnx as tou

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        class _Net(nn.Module):
            def forward(self, x):
                return x + torch.ones((1,))
        self.add = _Net()

    def forward(self, x):
        with tou.scoped_anchor(key='value'):
            return self.add(x)

model = Net()
x = torch.randn((1, 6, 32, 32))
out_dir = tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
  %x.1[FLOAT, 1x6x32x32]
) {
  %2 = Identity[key = 'value'](%x.1)
  %3 = Constant[value = <Tensor>]()
  %4 = Add(%2, %3)
  %6 = Identity[key = 'value'](%4)
  return %6
}

Or you can use anchor (described below) instead.

anchor (Future work)

Inserts an Anchor node at each arbitrary position of an nn.Module. The node name would be Anchor and the op_type would be Identity.

  • Note: adding extra parameters would produce an extended ONNX format because they become attributes.

  • Please use it with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.

CUDA (CuPy Interoperability)

  • pytorch_pfn_extras.cuda.stream(stream)

    • Context-manager that selects a given stream. This context manager also changes CuPy's default stream if CuPy is available. When CuPy is not available, the functionality is the same as PyTorch's counterpart, torch.cuda.stream().

  • pytorch_pfn_extras.cuda.use_torch_mempool_in_cupy()

    • Use PyTorch’s memory pool in CuPy. If you want to use PyTorch’s memory pool and non-default CUDA streams, streams must be created and managed using PyTorch (using torch.cuda.Stream() and pytorch_pfn_extras.cuda.stream(stream)). This feature requires CuPy v8.0+ and PyTorch v1.5+.

  • pytorch_pfn_extras.cuda.use_default_mempool_in_cupy()

    • Use CuPy’s default memory pool in CuPy.

  • pytorch_pfn_extras.from_ndarray(ndarray)

    • Creates a Tensor from NumPy/CuPy ndarray.

  • pytorch_pfn_extras.as_ndarray(tensor)

    • Creates a NumPy/CuPy ndarray from Tensor.

  • pytorch_pfn_extras.get_xp(tensor_device_or_ndarray)

    • Returns numpy or cupy module for the given object.

  • pytorch_pfn_extras.as_numpy_dtype(torch_dtype)

    • Returns NumPy dtype for the given torch dtype.

  • pytorch_pfn_extras.from_numpy_dtype(numpy_dtype)

    • Returns torch dtype for the given NumPy dtype.
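
A usage sketch combining these helpers (assumes a CUDA device and CuPy are installed):

import torch
import pytorch_pfn_extras as ppe

# Share PyTorch's memory pool with CuPy so the two libraries do not keep
# separate pools on the same device.
ppe.cuda.use_torch_mempool_in_cupy()

stream = torch.cuda.Stream()
with ppe.cuda.stream(stream):
    # Work issued below by either PyTorch or CuPy runs on `stream`.
    t = torch.arange(4.0, device='cuda')
    a = ppe.as_ndarray(t)               # CuPy ndarray created from the tensor
    xp = ppe.get_xp(a)                  # returns the cupy module here
    t2 = ppe.from_ndarray(xp.sqrt(a))   # back to a torch.Tensor

# Restore CuPy's default memory pool when done.
ppe.cuda.use_default_mempool_in_cupy()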

API Reference

Training Loop

Trainer (technical preview)

engine.create_trainer(models, optimizers, …)

Creates a trainer object.

engine.create_evaluator(models, *[, …])

Creates an evaluator object.

handler.BaseLogic([options])

handler.Logic([model_name, options])

handler.BaseHandler(logic, options, *args, …)

handler.Handler(logic, entry_runtime, options)

runtime.BaseRuntime(device_spec, options)

A base class for collections of device-specific callback functions.

runtime.PyTorchRuntime(device_spec, options)

A collection of callback functions for the devices that PyTorch supports by default.

Extensions Manager

training.ExtensionsManager(models, …[, …])

Manages the extensions and the current status.

training.IgniteExtensionsManager(engine, …)

Manages extensions and the current status in Ignite training loop.

Extensions

training.extension.make_extension([trigger, …])

Decorator to make given function into an extension.

training.extension.Extension()

Base class of extensions.

training.extension.ExtensionEntry(extension, *)

Extension and options.

training.extensions.Evaluator(self, …[, …])

An extension to evaluate models on a validation set.

training.extensions.LogReport([keys, …])

An extension to output the accumulated results to a log file.

training.extensions.MicroAverage(…[, trigger])

Calculates micro-average ratio.

training.extensions.observe_lr(optimizer[, …])

Returns an extension to record the learning rate.

training.extensions.observe_value(…)

Returns an extension to continuously record a value.

training.extensions.ParameterStatistics(links)

An extension to report parameter statistics.

training.extensions.PlotReport(y_keys[, …])

An extension to output plots.

training.extensions.PrintReport([entries, …])

An extension to print the accumulated results.

training.extensions.ProgressBar([…])

An extension to print a progress bar and recent training status.

training.extensions.ProfileReport([…])

Writes the profile results to a file.

training.extensions.snapshot([savefun, …])

Returns a trainer extension to take snapshots of the trainer.

training.extensions.VariableStatisticsPlot(targets)

An extension to plot statistics for Tensors.

Triggers

training.triggers.EarlyStoppingTrigger(self)

Trigger for Early Stopping

training.triggers.IntervalTrigger(period, unit)

Trigger based on a fixed interval.

training.triggers.ManualScheduleTrigger(…)

Trigger invoked at specified point(s) of iterations or epochs.

training.triggers.BestValueTrigger(key, compare)

Trigger invoked when specific value becomes best.

training.triggers.MaxValueTrigger(key[, trigger])

Trigger invoked when specific value becomes maximum.

training.triggers.MinValueTrigger(key[, trigger])

Trigger invoked when specific value becomes minimum.

training.triggers.OnceTrigger([call_on_resume])

Trigger based on the starting point of the iteration.

training.triggers.TimeTrigger(period)

Trigger based on a fixed time interval.

Reporting

reporting.Reporter()

Object to which observed values are reported.

reporting.report(values[, observer])

Reports observed values with the current reporter object.

reporting.report_scope(observation)

Returns a report scope with the current reporter.

Logging

logging.get_logger(name)

Returns a child logger to be used by applications.

Profiler

profiler.TimeSummary.report(tag[, use_cuda])

Context manager to automatically report execution times.

Distributed Training

nn.parallel.DistributedDataParallel(module)

Module for distributed data parallelism

distributed.initialize_ompi_environment(*[, …])

Initialize torch.distributed environments with values taken from OpenMPI.

Check Pointing

utils.checkpoint

Lazy Modules

nn.Ensure(*[, shape, dtype, broadcastable, …])

Module to check the shape of a tensor.

nn.ensure(tensor[, shape, dtype, …])

Checks the shape and type of a tensor.

nn.LazyLinear(in_features, *args, **kwargs)

Linear module with lazy weight initialization.

nn.LazyConv1d(in_channels, *args, **kwargs)

Conv1d module with lazy weight initialization.

nn.LazyConv2d(in_channels, *args, **kwargs)

Conv2d module with lazy weight initialization.

nn.LazyConv3d(in_channels, *args, **kwargs)

Conv3d module with lazy weight initialization.

nn.LazyBatchNorm1d(num_features, *args, **kwargs)

BatchNorm1d module with lazy weight initialization.

nn.LazyBatchNorm2d(num_features, *args, **kwargs)

BatchNorm2d module with lazy weight initialization.

nn.LazyBatchNorm3d(num_features, *args, **kwargs)

BatchNorm3d module with lazy weight initialization.

ONNX

Export

onnx.export(model, args, f[, return_output, …])

Export model into ONNX Graph.

onnx.export_testcase(model, args, out_dir, *)

Export model and I/O tensors of the model in protobuf format.

Annotation

onnx.annotate(**attrs)

Annotation parameters to the target function.

onnx.apply_annotation(fn, *args, **attrs)

Annotation applier to the target function

onnx.scoped_anchor(**attrs)

Add anchor node to the scoped modules

Datasets

dataset.SharedDataset(sm_size[, cache_type])

Dataset that caches the loaded samples in shared memory.

dataset.TabularDataset(*args, **kwds)

An abstract class that represents a tabular dataset.

dataset.ItemNotFoundException

NumPy/CuPy Compatibility

from_ndarray(ndarray)

Creates a torch.Tensor from a numpy.ndarray or cupy.ndarray.

as_ndarray(tensor)

Creates a numpy.ndarray or cupy.ndarray from torch.Tensor.

get_xp(obj)

Returns a module of ndarray implementation (numpy or cupy) for the given obj.

as_numpy_dtype(torch_dtype)

Returns NumPy dtype for the given PyTorch dtype.

from_numpy_dtype(numpy_dtype)

Returns PyTorch dtype for the given NumPy dtype.

cuda.stream(stream)

Context-manager that selects a given stream.

cuda.use_torch_mempool_in_cupy()

Use the PyTorch memory pool in CuPy.

cuda.use_default_mempool_in_cupy()

Use the default memory pool in CuPy.