pytorch-pfn-extras¶
pytorch-pfn-extras (PPE) is a collection of supplementary components to accelerate research and development in PyTorch.
User Guide¶
Trainer (technical preview)¶
Trainer and Evaluator¶
Note
The Trainer/Evaluator APIs are currently under technical preview and may be subject to change in future versions.
The Trainer and Evaluator provide the device-agnostic training framework for PyTorch. These APIs abstract the training process using different runtimes, handlers, and logics.
Concepts¶
- Trainer (ppe.engine.create_trainer()) abstracts the training loop, built on top of the ExtensionsManager.
- Evaluator (ppe.engine.create_evaluator()) abstracts the evaluation step and is invoked from the Trainer (usually once in every epoch).
- Runtime (ppe.runtime.BaseRuntime) represents an environment used to execute models. Device-specific implementations reside here. PPE provides the default Runtime that supports the PyTorch-native devices (ppe.runtime.PyTorchRuntime).
- Handler (ppe.handler.Handler) is a layer to support device-agnostic training. This is considered a low-level API and in most cases users can just use the Handler provided by PPE.
- Logic (ppe.handler.Logic) is a set of callback functions that define the training logic (optimizer.zero_grad(), forward, backward, optimizer.step()). You can inherit the class and define your own training flow in case you need more complex training processes such as GANs.
- Model is a torch.nn.Module used for training and evaluation, whose inputs are dicts or keyword arguments and whose forward pass output is a dict.
Note that the default logic will perform backward on the tensors returned by model.forward, so you will need to perform the loss calculation inside the model itself.
Trainer at a glance¶
import torch
import torch.nn.functional as F
import pytorch_pfn_extras as ppe
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.LazyLinear(1)

    def forward(self, *, x, target):
        y = self.w(x)
        loss = F.nll_loss(y, target)
        prefix = 'train' if self.training else 'val'
        ppe.reporting.report({f'{prefix}/loss': loss.item()})
        return {'loss': loss}
model = MyModel()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
extensions = [
    ppe.training.extensions.LogReport(),
    ppe.training.extensions.ProgressBar(),
    ppe.training.extensions.PrintReport(
        ['epoch', 'iteration', 'train/loss', 'val/loss']),
]
device = 'cuda:0' # or any other PyTorch devices ('cpu', etc.) or PPE runtime names
epochs = 10
trainer = ppe.engine.create_trainer(
    model,
    optim,
    epochs,
    evaluator=ppe.engine.create_evaluator(
        model,
        device=device,
        progress_bar=True,
    ),
    device=device,
    extensions=extensions,
)
# Move the model to the device. This is almost equivalent to
# `model.to(device)`, but supports PPE runtimes as well as the PyTorch's
# built-in devices.
ppe.to(model, device)
# Using dummy data to illustrate the minimal working example.
# Notice that dict keys match with the kwargs of the forward method.
train_loader = torch.utils.data.DataLoader(
    [{'x': torch.rand(10, 64), 'target': torch.tensor([1])} for _ in range(1)],
    num_workers=8)
val_loader = torch.utils.data.DataLoader(
    [{'x': torch.rand(10, 64), 'target': torch.tensor([1])} for _ in range(1)],
    num_workers=8)
trainer.run(train_loader, val_loader)
Snapshot¶
To obtain and save the trained model for later use, you can use the Snapshot extension or directly invoke state_dict on the trainer itself.
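A minimal sketch building on the example above (the checkpoint filename is arbitrary):
# Include the snapshot extension in the list passed to create_trainer so the
# trainer state is saved periodically.
extensions = [
    ppe.training.extensions.snapshot(),
    # ... other extensions ...
]

# Alternatively, save the state manually after training finishes.
torch.save(trainer.state_dict(), 'trainer_checkpoint.pt')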
Handler¶
The ppe.handler.Handler object is used to help the trainer and evaluator objects in the Logic and Runtime manipulation. This class should ideally never be overridden by the user if the desired functionality can be achieved through subclassing BaseLogic or BaseRuntime.
The handler object's main responsibility is to inspect all the submodules of a module to obtain the runtimes they have associated, and then execute their callbacks accordingly. In addition, it drives the actual model execution by using the user-provided Logic object and deals with asynchronous execution in runtimes that provide support for it.
Runtime¶
By inheriting ppe.runtime.BaseRuntime
and implementing your own runtime, you can use your non-standard devices with the training loop.
class MyRuntime(BaseRuntime):
    ...
# Register MyRuntime with device name "mydev"
ppe.runtime.runtime_registry.register('mydev', MyRuntime)
ppe.to(module_or_tensor, 'mydev')
See Runtimes for Custom Devices if you are interested in implementing your own runtime.
Logic for Custom Training and Evaluation¶
In the training and evaluation engines, ppe.handler.BaseLogic
API is in charge of abstracting the algorithmic details of the training and evaluation loops.
Logic is an object that defines multiple callbacks used through the training and evaluation processes. With logic, we can implement training of complex models such as GANs.
Users wanting to define their own Logic for training can inherit from
ppe.handler.Logic
which implements the training and evaluation steps to train
a single module.
Logic functions are not expected to be called directly by the user. They will be invoked by the Trainer and Evaluator engines.
Default Logic (ppe.handler.Logic)¶
PPE provides a default logic that performs the forward/backward/optimizer loop for a single model. This logic allows using some torch features such as AMP autocast and GradScaler and performs the backward pass on the outputs specified by the config option backward_outputs.
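As a sketch, a logic instance with a custom backward_outputs setting could be passed when creating the trainer, reusing the objects from the example above; the options dictionary and the logic keyword argument shown here are assumptions about the current PPE API, so check the API Reference for the exact signatures:
# Assumed sketch: run backward only on the 'loss' output of the model.
logic = ppe.handler.Logic(options={'backward_outputs': ['loss']})
trainer = ppe.engine.create_trainer(
    model, optim, epochs,
    logic=logic,
    device=device,
    extensions=extensions,
)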
CodeBlock Logic (ppe.handler.CodeBlockLogic)¶
With the CodeBlock API, we provide a basic logic that uses it to perform the training of a single model. Similarly to the default logic, AMP features are supported, but by means of the Runtime. For more information, check the CodeBlock documentation.
Runtimes for Custom Devices¶
Note
This documentation is intended for those implementing their own device backends for the PPE training framework. Most users can skip this chapter.
The ppe.runtime.BaseRuntime
API is in charge of abstracting the device details
and performing the movement of data and modules to the corresponding device.
A runtime is an object that defines multiple callbacks used through the training, evaluation, and regular model calls. With runtimes, we can implement training on devices other than CPUs or GPUs with minimal changes to the user code.
Users wanting to override only a few callbacks can inherit from
ppe.runtime.PyTorchRuntime
which implements the basic functionality for CPU and GPU devices.
Runtimes must be registered by calling the
ppe.runtime.runtime_registry.register(device_name, runtime_class)
function
for them to be discoverable.
Use of ppe.to to transfer modules and batches to custom devices¶
If you have defined a new runtime for a custom device, the ppe.to
function
allows moving a module or a tensor to the new device by invoking the
Runtime.move_tensor
and Runtime.move_module
when needed.
The module will be tagged by adding an attribute named _ppe_runtime that holds the needed runtime. It is the responsibility of the user's custom runtime to perform the actual movement to the device and apply all the transformations needed to a module so it can be correctly executed.
Usually, runtime writers will need to replace the given module's forward function with a new one that performs the actual device execution.
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

class MyMagicDeviceRuntime(ppe.runtime.BaseRuntime):
    def _device_forward(self, args):
        return run_batch_in_my_device(args)

    def move_module(self, module):
        # Registers a hook to initialize the module on the first batch
        # execution
        def hook(module, *args):
            module._ppe_runtime.initialize_module(module, args)

        self.hook = module.register_forward_pre_hook(hook)
        # Change the module forward to do the computation in the device
        module.forward = self._device_forward

    def initialize_module(self, module, loader_or_batch, optimizer=None):
        create_the_module_in_my_device(module, loader_or_batch, optimizer)
# Register the runtime class
ppe.runtime.runtime_registry.register('my_device', MyMagicDeviceRuntime)
# Create a regular module
module = MyModule()
# Move the module to the device
ppe.to(module, device='my_device')
for x in my_dataloader:
    # The first iteration will create the module in the device
    # and the next ones will directly execute the module in the device instead
    # of executing the regular pytorch `forward` call.
    y = module(x)
Please note that this is an oversimplified description and that developing a runtime that is 100% compatible with PyTorch requires wrapping the substitute forward function with torch.autograd.Function, among several other concerns such as state_dict manipulation to ensure correctness.
Runtime Registry¶
When creating a new Runtime class for custom needs, it needs to be registered in the global runtime_registry object as detailed above. This object is of the _RuntimeRegistry type and it maintains a map from strings to Runtime types. The keys are the devices passed to ppe.to and the values are the types of the Runtime objects that ppe.to will use to treat the module or tensor. Beware that users are not supposed to interact directly with this class, only with runtime_registry.register to register new runtimes.
CodeBlocks for Abstracting Logic Steps¶
The ppe.handler.CodeBlock API provides a means of abstracting, in a device-agnostic way, the actions that can be performed on a model.
Currently there is support for two different actions using CodeBlock:
- ppe.handler.update_parameters takes a model and an optimizer and returns a CodeBlock object that performs the forward, backward and optimizer step at once.
- ppe.handler.forward takes a model and returns a CodeBlock object that performs only the forward pass.
Executing CodeBlocks¶
For executing CodeBlock objects we need to add a ppe.runtime.BaseRuntime.execute method to the corresponding Runtime class. This method takes a CodeBlock and uses the information in the object to execute the CodeBlock in the device. Note that the ppe.runtime.PyTorchRuntime.execute method provides support for using PyTorch AMP with autocast or gradient scaling if needed.
Moreover, you can execute CodeBlock objects outside the training API.
ppe.to(model, "cuda:0")
cblock = ppe.handler.update_parameters(model, optimizer)
outs = cblock(input_batch)
The only requirement is that the associated model has been assigned a device using ppe.to
.
Extensions¶
Extensions Manager¶
The Extensions Manager provides an interface to extend your training loop; it can be integrated into your manual training loop or into Ignite.
Extensions¶
See the API Reference for the list of built-in extensions.
How to use¶
Create an ExtensionsManager
object and then wrap the iteration of your
training loop inside the manager.run_iteration()
context manager.
An example follows:
import pytorch_pfn_extras as ppe
from pytorch_pfn_extras.training import extensions
import time
import math
max_epoch = 10
iters_per_epoch = 938
# manager.extend(...) also works
my_extensions = [
    extensions.LogReport(),
    extensions.ProgressBar(),
    extensions.PrintReport(['epoch', 'iteration', 'sin', 'cos']),
]

models = {}
optimizers = []
manager = ppe.training.ExtensionsManager(
    models, optimizers, max_epoch,
    extensions=my_extensions,
    iters_per_epoch=iters_per_epoch)

for epoch in range(max_epoch):
    for i in range(iters_per_epoch):
        with manager.run_iteration():
            ppe.reporting.report({
                'sin': math.sin(i * 2 * math.pi / iters_per_epoch),
                'cos': math.cos(i * 2 * math.pi / iters_per_epoch),
            })
            time.sleep(0.001)
In the examples folder there is an MNIST example using all the available extensions.
Usage with Ignite¶
Ignite is supported by using the IgniteExtensionsManager with the trainer as the first argument.
The user needs to define an Ignite event to report the appropriate metrics for the extensions to use them.
manager = ppe.training.IgniteExtensionsManager(
    trainer, models, optimizers, epochs,
    extensions=my_extensions)

@trainer.on(Events.ITERATION_COMPLETED)
def report_loss(engine):
    ppe.reporting.report({'train/loss': engine.state.output})
Using Evaluators¶
In order to report the results of the evaluation so they can be
accessed by other extensions, an Evaluation
extension
needs to be created with the argument eval_func
set to a function
that gets the current data and target batches as parameters and
reports the needed metrics. Example
The test function looks has the following signature
def test(args, model, device, data, target):
and is invoked once per batch in the validation dataloader. It is important to report the current validation loss or accuracy in order to the log report to see it.
def test(args, model, device, data, target):
    ...
    # Final result will be average of averages of the same size
    test_loss += F.nll_loss(output, target, reduction='mean').item()
    ppe.reporting.report({'val/loss': test_loss})
    pred = output.argmax(dim=1, keepdim=True)
    correct += pred.eq(target.view_as(pred)).sum().item()
    ppe.reporting.report({'val/acc': correct / len(data)})
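A sketch of wiring such a test function into an extension; this assumes the extension class is ppe.training.extensions.Evaluator and that it accepts an eval_func callable, while val_loader, model, device and args are the objects from your own training script:
# Hypothetical wiring: report validation metrics once per epoch.
evaluator_ext = ppe.training.extensions.Evaluator(
    val_loader, model,
    eval_func=lambda data, target: test(args, model, device, data, target))
my_extensions.append(evaluator_ext)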
When using Ignite, just use the IgniteEvaluator extension with the Ignite-created evaluator as the first parameter and you are ready to go.
The metrics defined when creating the evaluator with create_supervised_evaluator will be automatically reported:
create_supervised_evaluator(model, metrics={'acc': Accuracy(), 'loss': Loss(F.nll_loss)}, device=device)
Snapshots¶
It is possible to take snapshots by using the snapshot training extension just as in Chainer.
Whenever the extension is triggered, it saves the status of the optimizer, model and extensions to the output folder in the same way as Chainer.
To load a snapshot and continue training, call torch.load and use the ExtensionsManager.load_state_dict method to resume the training.
The snapshots can be used outside the pytorch-pfn-extras module just by accessing the models or optimizers fields of the loaded state.
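A minimal sketch of resuming from a snapshot file (the filename is arbitrary):
# Restore the manager (and thus the registered models, optimizers and
# extensions) from a snapshot written by the snapshot extension.
state = torch.load('snapshot_iter_938')
manager.load_state_dict(state)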
Extensions execution order¶
The supported extensions honour the Chainer priorities for execution. However, when using Ignite, Chainer extensions are executed after any user-defined Ignite events. The idea is to use Ignite events to report the metrics of the model, and after this, Chainer extensions will be executed in the Chainer-defined order.
If you want to execute an event handler in between Chainer extensions, create a Chainer-like extension and access the Ignite engine through the .engine attribute of the manager object passed as a parameter when your extension is called.
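For example, such a Chainer-like extension could look like this (a sketch; the trigger interval is arbitrary):
def my_ignite_aware_extension(manager):
    # When using IgniteExtensionsManager, the Ignite engine is reachable
    # through the manager's `engine` attribute.
    engine = manager.engine
    print('Ignite iteration:', engine.state.iteration)

manager.extend(my_ignite_aware_extension, trigger=(1, 'iteration'))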
Creating Extensions¶
It is possible to create an extension just by passing a function which receives the manager object as an argument to the manager's extend call:
def my_extension(manager):
    print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))

manager.extend(my_extension, trigger=(1, 'iteration'))
It is also possible to create extensions using the ppe.training.extension.make_extension decorator to add a specific trigger, default_name, and priority. In addition, initializer, finalizer and on_error functions can be specified as well.
@ppe.training.extension.make_extension(finalizer=lambda: print('done'))
def my_extension(manager):
    print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))
Finally, it is possible to create an extension by subclassing the ppe.training.extension.Extension class as shown below.
import pytorch_pfn_extras as ppe
class MyExtension(ppe.training.extension.Extension):
    def __init__(self, args):
        self.args = args

    def initialize(self, manager):
        """
        Automatically called before training. Optional.
        """
        pass

    def __call__(self, manager):
        """
        Called when the associated trigger is fired.
        """
        print('Epoch-Iteration: {}-{}'.format(manager.epoch, manager.iteration))

    def state_dict(self):
        """
        Used to serialize the state. Optional.
        """
        return {'args': self.args}

    def load_state_dict(self, state):
        """
        Used to deserialize the state. Optional.
        """
        self.args = state['args']
Reporting¶
reporting.Reporter
is used to collect values that users want to watch.
The reporter object holds a mapping from value names to the actually observed values. We call this mapping observations.
When a value is passed to the reporter, an object called observer can be optionally attached. In this case, the name of the observer is added as the prefix of the value name. The observer name should be registered beforehand.
import pytorch_pfn_extras as ppe
reporter = ppe.reporting.Reporter()
observer = object()
reporter.add_observer('my_observer', observer)
observation = {}
with reporter.scope(observation):
    reporter.report({'x': 1}, observer)
print(observation)
# outputs: {'my_observer/x': 1}
There is also a global API to add values:
import pytorch_pfn_extras as ppe
reporter = ppe.reporting.Reporter()
observer = object()
reporter.add_observer('my_observer', observer)
observation = {}
with reporter:
    with ppe.reporting.report_scope(observation):
        ppe.reporting.report({'x': 1}, observer)
print(observation)
# outputs: {'my_observer/x': 1}
The most important application of Reporter is to report observed values from different parts of the model in the training
and validation procedures. ExtensionsManager
objects hold their own Reporter
object with the parameters of the target
module registered as observers. report()
can be used inside the modules to report the observed values (e.g., training loss,
accuracy, activation statistics, etc.).
Distributed Snapshot¶
To take snapshots when using torch.distributed, the only needed step is to provide the saver_rank keyword argument to the regular snapshot extension.
# saver_rank is the MPI rank which will write the actual snapshot.
snapshot = extensions.snapshot(saver_rank=saver_rank)
To resume the training, snapshots are loaded in every worker by using the
ExtensionsManager.load_state_dict
method, or the extensions.snapshot
autoload
keyword argument.
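A minimal sketch of the distributed snapshot flow (assuming torch.distributed is already initialized and rank 0 is used as the saver):
# Every worker registers the same extension; only the saver rank writes the
# file, and autoload restores the latest snapshot on restart.
snapshot = extensions.snapshot(saver_rank=0, autoload=True)
manager.extend(snapshot, trigger=(1, 'epoch'))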
Utilities¶
Lazy Modules¶
Lazy modules can automatically infer shapes of parameters based on the shape of the data given to the first forward invocation.
The following modules are provided:

- ppe.nn.LazyBatchNorm1d, ppe.nn.LazyBatchNorm2d, ppe.nn.LazyBatchNorm3d: modules that behave as torch.nn.BatchNorm[123]d but num_features can be set to None. These modules are now included as part of the PyTorch 1.9 release (torch.nn.LazyBatchNormXd, pull-request).

The following modules are now considered deprecated, as they are included as part of the PyTorch 1.8 release:

- ppe.nn.LazyLinear: module that behaves as torch.nn.Linear but in_features can be set to None. PyTorch-native implementation: torch.nn.LazyLinear (pull-request).
- ppe.nn.LazyConv1d, ppe.nn.LazyConv2d, ppe.nn.LazyConv3d: modules that behave as torch.nn.Conv[123]d but in_channels can be set to None. PyTorch-native implementation: torch.nn.LazyConvXd (pull-request).

Now that all lazy modules are merged into upstream PyTorch, we encourage you to migrate to PyTorch's lazy modules. We will keep these implementations only for backward compatibility.
Note that you need to run a “dummy” forward to initialize lazy parameters. See the example below:
import torch
import torch.nn.functional as F
import pytorch_pfn_extras as ppe
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = ppe.nn.LazyConv2d(None, 20, 5, 1)
        self.conv2 = ppe.nn.LazyConv2d(None, 50, 5, 1)
        self.fc1 = ppe.nn.LazyLinear(None, 500)
        self.fc2 = ppe.nn.LazyLinear(None, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

model = Net()

# Initialize lazy parameters.
dummy_input = ...
model(dummy_input)

# Pass parameters to the optimizer.
optimizer = torch.optim.SGD(
    model.parameters(), lr=args.lr, momentum=args.momentum)
# Run training loop.
# ...
You need to run a dummy forward before passing parameters to optimizers; otherwise optimizers cannot refer to lazily-initialized parameters. You will get a warning if you pass uninitialized lazy parameters to optimizers:
>>> model = ppe.nn.LazyLinear(None, 10)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
/.../pytorch-pfn-extras/pytorch_pfn_extras/nn/modules/lazy.py:127: UserWarning:
Use of uninitialized lazy parameter in Optimizer has been detected.
Maybe you forgot to run forward before passing `module.parameters()` to the optimizer?
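For reference, the same network written with PyTorch's native lazy modules (available in recent PyTorch releases) only changes the layer definitions; this is a sketch of the migration, not a PPE API:
class NativeLazyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Only output sizes are given; input sizes are inferred on the
        # first forward call, just like the PPE lazy modules.
        self.conv1 = torch.nn.LazyConv2d(20, 5, 1)
        self.conv2 = torch.nn.LazyConv2d(50, 5, 1)
        self.fc1 = torch.nn.LazyLinear(500)
        self.fc2 = torch.nn.LazyLinear(10)
    # The forward pass is the same as in the example above.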
Config¶
Basic¶
from pytorch_pfn_extras.config import Config
import yaml
pre_eval_config = yaml.load('''
foo:
  bar: 'bar_value'
  ls:
    - 'first'
    - key0: 'value0'
      key1: 'value1'
baz: 'baz_value'
''')
config = Config(pre_eval_config)
Accessing config values:
print(config['/foo/ls/0'])
# 'first'
print(config['/foo/ls/1/key0'])
# 'value0'
print(config['/foo/ls'])
# ['first', {'key0': 'value0', 'key1': 'value1'}]
print(config['/baz'])
# 'baz_value'
Substitution¶
Callable Substitution¶
You can replace a value with the return value of a callable:
- types is an additional input to Config. types is a mapping from a callable's name to the actual callable.
- A sub-dictionary containing the key type invokes callable substitution.
pre_eval_config = yaml.load('''
name:
  type: concat
  x0: 'First'
  x1: 'Last'
''')
types = {
    'concat': lambda x0, x1: x0 + ' ' + x1,
}
config = Config(pre_eval_config, types)
# the value returned by
# concat(x0='First', x1='Last')
print(config['/name'])
# 'First Last'
pre_eval_config = yaml.load('''
name:
  type: concat
  x0: 'First'
  x1:
    type: concat
    x0: 'Middle'
    x1: 'Last'
''')
types = {
    'concat': lambda x0, x1: x0 + ' ' + x1,
}
config = Config(pre_eval_config, types)
print(config['/name'])
# First Middle Last
pre_eval_config = yaml.load('''
dataset:
  type: Dataset
  n_class: 10
''')

class Dataset(object):
    def __init__(self, n_class):
        self.n_class = n_class

types = {
    'Dataset': Dataset,
}

config = Config(pre_eval_config, types)
print(isinstance(config['/dataset'], Dataset))
# True
Substitution by Path¶
@/absolute/path
is replaced by the value at /absolute/path
.
pre_eval_config = yaml.load('''
foo: 'FOO'
boo:
  baz: '@/foo'
''')
config = Config(pre_eval_config)
print(config['/boo/baz'])
# FOO
Relative paths are also possible using @relative/path.
pre_eval_config = yaml.load('''
foo: 'FOO'
boo:
  baz: '@../foo'
''')
config = Config(pre_eval_config)
print(config['/boo/baz'])
# FOO
Substitution by Attribute¶
@/path/to/obj.attr_name is replaced by the following steps:
1. Use substitution by path to get an object at /path/to/obj.
2. Replace the config value by getattr(obj, attr_name), where obj is obtained at step 1.
pre_eval_config = yaml.load('''
dataset:
  type: Dataset
  n_class: 10
n_data: '@/dataset.n_data'
''')

class Dataset(object):
    def __init__(self, n_class):
        self.n_class = n_class
        self.n_data = 4

types = {
    'Dataset': Dataset,
}
config = Config(pre_eval_config, types)
print(config['/n_data'])
# 4
Default Value by Path Substitution¶
customize_type
is a decorator that sets default argument values by path substitution.
from pytorch_pfn_extras.config import customize_type
pre_eval_config = yaml.load('''
dataset:
  type: Dataset
n_class: 5
''')

# If n_class is not passed, the value would be config['/n_class'].
# Both absolute and relative paths are allowed.
@customize_type(n_class='/n_class')
class Dataset(object):
    def __init__(self, n_class):
        self.n_class = n_class

types = {
    'Dataset': Dataset,
}
config = Config(pre_eval_config, types)
print(config['/dataset'].n_class)
# 5
Ignore Substitution¶
Access using config['!/path']
instead of config['/path']
.
pre_eval_config = yaml.load('''
name:
  type: concat
  x0: 'First'
  x1: 'Last'
''')
types = {
    'concat': lambda x0, x1: x0 + ' ' + x1,
}
config = Config(pre_eval_config, types)
print(config['!/name'])
# {'type': 'concat', 'x0': 'First', 'x1': 'Last'}
Lazy Evaluation¶
Callable substitution is lazily executed. This means that callables that are not dependent on the accessed value do not get executed.
pre_eval_config = yaml.load('''
foo:
  - type: f0
  - '@/bar'
bar:
  type: f1
baz:
  type: f2
''')

def f0():
    print('f0 called')
    return 'f0_return'

def f1():
    print('f1 called')
    return 'f1_return'

def f2():
    print('f2 called')
    return 'f2_return'

types = {
    'f0': f0,
    'f1': f1,
    'f2': f2,
}
config = Config(pre_eval_config, types)
config['/foo'] # f2 does not get called
# f0 called
# f1 called
pytorch_pfn_extras.onnx¶
Extensions to torch.onnx.export
.
Installation¶
pip3 install "pytorch-pfn-extras[onnx]"
Or:
1. Install pytorch-pfn-extras normally
2. Install onnx with pip install onnx==1.7.0
API¶
pytorch_pfn_extras.onnx.export_testcase¶
Instead of specifying a file name as in torch.onnx.export, pytorch_pfn_extras.onnx.export_testcase specifies a directory to which the ONNX model and the test case inputs/outputs are written.
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))
import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output')
The following directory structure will be generated under /path/to/output:
$ tree /path/to/output
/path/to/output
├── meta.json
├── model.onnx
└── test_data_set_0
├── input_0.pb
└── output_0.pb
This directory structure format is inspired by ONNX official test data set: (Example: node). PyTorch’s ONNX tests use this format too. (Reference: export_onnx_tests_generator.py)
There are scripts in chainer-compiler/utils to run inference in major runtimes with this directory structure. For example, to run inference with ONNX Runtime:
$ python run_onnx_onnxruntime.py /path/to/output
This uses input_N.pb as the input and compares the result numerically with the expected output output_N.pb (N is the index of the test case).
By default, meta.json is generated too to track git information, date and time, etc. Add the metadata=False argument to suppress this.
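For example, the following call skips writing meta.json:
tou.export_testcase(model, x, '/path/to/output', metadata=False)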
out_grad option¶
If out_grad=True is specified, gradients are dumped too, which is useful for debugging backward passes. gradient_N.pb and gradient_input_N.pb are dumped to the test case directory along with the input/output data. gradient_input_N.pb is the initial value of the backward pass, and its default value is a ones tensor with the same shape as the output. Use out_grad to specify a custom initial value (of torch.Tensor type) for it.
model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))
import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output', out_grad=True)
$ tree /path/to/output
/path/to/output
├── meta.json
├── model.onnx
└── test_data_set_0
├── gradient_0.pb
├── gradient_input_0.pb
├── input_0.pb
└── output_0.pb
model_overwrite option¶
Use the model_overwrite option to create multiple data sets like the following:
import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x1, '/path/to/output')
tou.export_testcase(model, x2, '/path/to/output', model_overwrite=False)
The following shows the test cases generated by the above: test_data_set_0 contains the input x1 and its output, and test_data_set_1 contains the input x2 and its output.
$ tree /path/to/output
├── meta.json
├── model.onnx
├── test_data_set_0
│ ├── input_0.pb
│ └── output_0.pb
└── test_data_set_1
├── input_0.pb
└── output_0.pb
strip_large_tensor_data option¶
This option strips large tensors from the dumped files, which is useful to reduce file sizes for usages such as benchmarking. Not only model.onnx but also the input/output and gradient data are affected. large_tensor_threshold can be used to specify the size threshold above which a tensor is considered large.
import torchvision
model = torchvision.models.resnet50(pretrained=True)
x = torch.zeros((1, 3, 224, 224))
import pytorch_pfn_extras.onnx as tou
tou.export_testcase(model, x, '/path/to/output')
tou.export_testcase(model, x, '/path/to/output2', strip_large_tensor_data=True)
$ ls -lh /path/to/output/model.onnx
-rwxrwxrwx 1 user user 98M Jun 24 23:34 /path/to/output/model.onnx
$ ls -lh /path/to/output2/model.onnx
-rwxrwxrwx 1 user user 64K Jun 24 23:34 /path/to/output2/model.onnx
This feature can also be invoked from the CLI:
$ python -m pytorch_pfn_extras.onnx.strip_large_tensor resnet50.onnx --out_onnx_path resnet50_slim.onnx
$ ls -lh
-rwxrwxrwx 1 user user 98M Jun 30 09:13 resnet50.onnx
-rwxrwxrwx 1 user user 64K Jun 30 09:16 resnet50_slim.onnx
See $ python -m pytorch_pfn_extras.onnx.strip_large_tensor -h for help.
Notes:
If an ONNX runtime does not support tensors without raw_data, unstrip_tensor.py will restore the stripped data. See $ python -m pytorch_pfn_extras.onnx.unstrip_tensor -h for help.
pytorch_pfn_extras.onnx.export¶
A function with the same interface as torch.onnx.export. Unlike torch.onnx.export, you can use the annotation feature (described below), the strip_large_tensor_data option, or other torch.onnx extensions.
- strip_large_tensor_data: same as in export_testcase. Useful for reducing file sizes.
- return_output: returns the output value of the model execution. Note: most outputs will be of torch.Tensor type (not onnx.TensorProto).
model = nn.Sequential(nn.Linear(5, 10, bias=False))
x = torch.zeros((2, 5))
import io, onnx
import pytorch_pfn_extras.onnx

bytesio = io.BytesIO()
pytorch_pfn_extras.onnx.export(model, x, bytesio)
onnx_proto = onnx.load(io.BytesIO(bytesio.getvalue()))
annotate¶
A feature to add custom ONNX attributes to specified nn.Module instances.
Notes:
- The annotated ONNX would be an invalid ONNX format that doesn't pass the check of onnx.checker.check_model.
- Only valid with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.
- For modules like nn.Linear, nn.GroupNorm, etc. that expand to multiple ONNX nodes, only the first ONNX node is annotated. For example, nn.Linear with bias is split into a MatMul -> Add graph, and only the MatMul is annotated. The same applies to apply_annotation (described later).
- Use apply_annotation instead when the annotation target isn't an nn.Module.
import pytorch_pfn_extras.onnx as tou
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.annotate(key='value'):
            h = self.conv2(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
%input.1[FLOAT, 1x6x32x32]
) initializers (
%17[FLOAT, 28x20]
%18[FLOAT, 20x15]
%conv.bias[FLOAT, 9]
%conv.weight[FLOAT, 9x6x3x3]
%conv2.bias[FLOAT, 12]
%conv2.weight[FLOAT, 12x9x3x3]
%linear.bias[FLOAT, 20]
%linear2.bias[FLOAT, 15]
) {
%9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
%10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
%12 = MatMul[key = 'value'](%10, %17)
%13 = Add(%12, %linear.bias)
%15 = MatMul(%13, %18)
%16 = Add(%15, %linear2.bias)
return %16
}
In the above example, %10 = Conv and %12 = MatMul have the key='value' attribute annotated.
apply_annotation¶
This annotates a function call instead of annotating with the with statement.
The target of annotate is nn.Module, so torch.nn.functional calls cannot be annotated:
import torch.nn.functional as F
import pytorch_pfn_extras.onnx as tou
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.annotate(key='value'):
            h = self.conv2(h)
            h = F.relu(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h
model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
%input.1[FLOAT, 1x6x32x32]
) initializers (
%18[FLOAT, 28x20]
%19[FLOAT, 20x15]
%conv.bias[FLOAT, 9]
%conv.weight[FLOAT, 9x6x3x3]
%conv2.bias[FLOAT, 12]
%conv2.weight[FLOAT, 12x9x3x3]
%linear.bias[FLOAT, 20]
%linear2.bias[FLOAT, 15]
) {
%9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
%10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
%11 = Relu(%10)
%13 = MatMul[key = 'value'](%11, %18)
%14 = Add(%13, %linear.bias)
%16 = MatMul(%14, %19)
%17 = Add(%16, %linear2.bias)
return %17
}
%10 = Conv and %13 = MatMul have the key='value' attribute, but %11 = Relu does not.
By using apply_annotation, all nodes in the function are annotated.
import pytorch_pfn_extras.onnx as tou
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)

        def _f(x):
            h = self.conv2(x)
            h = F.relu(h)
            h = self.linear(h)
            return h

        h = tou.apply_annotation(_f, h, key='value')
        h = self.linear2(h)
        return h

model = Net()
x = torch.randn((1, 6, 32, 32))
tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
%input.1[FLOAT, 1x6x32x32]
) initializers (
%18[FLOAT, 28x20]
%19[FLOAT, 20x15]
%conv.bias[FLOAT, 9]
%conv.weight[FLOAT, 9x6x3x3]
%conv2.bias[FLOAT, 12]
%conv2.weight[FLOAT, 12x9x3x3]
%linear.bias[FLOAT, 20]
%linear2.bias[FLOAT, 15]
) {
%9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
%10 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], key = 'value', pads = [0, 0, 0, 0], strides = [1, 1]](%9, %conv2.weight, %conv2.bias)
%11 = Relu[key = 'value'](%10)
%13 = MatMul[key = 'value'](%11, %18)
%14 = Add(%13, %linear.bias)
%16 = MatMul(%14, %19)
%17 = Add(%16, %linear2.bias)
return %17
}
Now %11 = Relu is annotated with the key='value' attribute too.
scoped_anchor¶
This annotates the beginning and end of a scope of one or more modules by adding Anchor nodes. The nodes are named Anchor_N_start or Anchor_N_end (N is an index) with op_type Identity. Adding a custom parameter adds an ONNX attribute, which generates ONNX that is invalid for the checker.
- Use this with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.
- When the scope has multiple inputs/outputs, only the first input/output gets an Anchor node added.
- The N in the node name is the index of the beginning/end Anchor node pair, like Anchor_0_start, Anchor_0_end.
import pytorch_pfn_extras.onnx as tou
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d(6, 9, 3)
        self.conv2 = nn.Conv2d(9, 12, 3)
        self.linear = nn.Linear(28, 20)
        self.linear2 = nn.Linear(20, 15)

    def forward(self, x):
        h = self.conv(x)
        with tou.scoped_anchor(key='value'):
            h = self.conv2(h)
            h = self.linear(h)
        h = self.linear2(h)
        return h
model = Net()
x = torch.randn((1, 6, 32, 32))
out_dir = tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
%input.1[FLOAT, 1x6x32x32]
) initializers (
%23[FLOAT, 28x20]
%24[FLOAT, 20x15]
%conv.bias[FLOAT, 9]
%conv.weight[FLOAT, 9x6x3x3]
%conv2.bias[FLOAT, 12]
%conv2.weight[FLOAT, 12x9x3x3]
%linear.bias[FLOAT, 20]
%linear2.bias[FLOAT, 15]
) {
%9 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%input.1, %conv.weight, %conv.bias)
%11 = Identity[key = 'value'](%9)
%12 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [1, 1]](%11, %conv2.weight, %conv2.bias)
%16 = MatMul(%12, %23)
%17 = Add(%16, %linear.bias)
%19 = Identity[key = 'value'](%17)
%21 = MatMul(%19, %24)
%22 = Add(%21, %linear2.bias)
return %22
}
%11 = Identity (node name = Anchor_0_start) and %19 = Identity (node name = Anchor_0_end) are added. key='value' is added as an ONNX attribute.
nn.Module¶
The target of the scope is only nn.Module. You can add a sub nn.Module instead if the scope boundary does not match an nn.Module.
import pytorch_pfn_extras.onnx as tou
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        class _Net(nn.Module):
            def forward(self, x):
                return x + torch.ones((1,))

        self.add = _Net()

    def forward(self, x):
        with tou.scoped_anchor(key='value'):
            return self.add(x)
model = Net()
x = torch.randn((1, 6, 32, 32))
out_dir = tou.export_testcase(model, x, '/path/to/output')
onnx_proto = onnx.load(os.path.join('/path/to/output', 'model.onnx'))
print(onnx.helper.printable_graph(onnx_proto.graph))
graph torch-jit-export (
%x.1[FLOAT, 1x6x32x32]
) {
%2 = Identity[key = 'value'](%x.1)
%3 = Constant[value = <Tensor>]()
%4 = Add(%2, %3)
%6 = Identity[key = 'value'](%4)
return %6
}
Or you can use anchor
(described below) instead.
anchor (Future work)¶
Inserts an Anchor node at each arbitrary position of an nn.Module. The node name will be Anchor and the op_type Identity.
- Note: adding an extra parameter results in an extended ONNX format because it becomes an attribute.
- Please use it with pytorch_pfn_extras.onnx.export_testcase or pytorch_pfn_extras.onnx.export.
CUDA (CuPy Interoperability)¶
pytorch_pfn_extras.cuda.stream(stream)
Context-manager that selects a given stream. This context manager also changes CuPy's default stream if CuPy is available. When CuPy is not available, the functionality is the same as PyTorch's counterpart, torch.cuda.stream().
pytorch_pfn_extras.cuda.use_torch_mempool_in_cupy()
Use PyTorch's memory pool in CuPy. If you want to use PyTorch's memory pool and non-default CUDA streams, streams must be created and managed using PyTorch (using torch.cuda.Stream() and pytorch_pfn_extras.cuda.stream(stream)). This feature requires CuPy v8.0+ and PyTorch v1.5+.
pytorch_pfn_extras.cuda.use_default_mempool_in_cupy()
Use CuPy’s default memory pool in CuPy.
pytorch_pfn_extras.from_ndarray(ndarray)
Creates a Tensor from NumPy/CuPy ndarray.
pytorch_pfn_extras.as_ndarray(tensor)
Creates a NumPy/CuPy ndarray from Tensor.
pytorch_pfn_extras.get_xp(tensor_device_or_ndarray)
Returns the numpy or cupy module for the given object.
pytorch_pfn_extras.as_numpy_dtype(torch_dtype)
Returns NumPy dtype for the given torch dtype.
pytorch_pfn_extras.from_numpy_dtype(numpy_dtype)
Returns torch dtype for the given NumPy dtype.
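A short sketch combining these utilities (assuming CuPy and a CUDA device are available):
import torch
import pytorch_pfn_extras as ppe

# Let CuPy allocate from PyTorch's memory pool, then run both libraries
# on the same (non-default) stream.
ppe.cuda.use_torch_mempool_in_cupy()
stream = torch.cuda.Stream()
with ppe.cuda.stream(stream):
    t = torch.arange(4, dtype=torch.float32, device='cuda')
    arr = ppe.as_ndarray(t)       # cupy.ndarray for a CUDA tensor
    xp = ppe.get_xp(arr)          # the cupy module in this case
    t2 = ppe.from_ndarray(xp.sqrt(arr))  # back to a torch.Tensor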
API Reference¶
Training Loop¶
Trainer (technical preview)¶
- Creates a trainer object.
- Creates an evaluator object.
- A base class for collections of device-specific callback functions.
- A collection of callback functions for the devices that PyTorch supports by default.
Extensions Manager¶
- Manages the extensions and the current status.
- Manages extensions and the current status in an Ignite training loop.
Extensions¶
- Decorator to make given function into an extension.
- Base class of extensions.
- Extension and options.
- Extension traces the best value of a specific key in the observation.
- An extension to evaluate models on a validation set.
- An extension to output the accumulated results to a log file.
- Extension traces the maximum value of a specific key in the observation.
- Calculates micro-average ratio.
- Extension traces the maximum value of a specific key in the observation.
- Returns an extension to record the learning rate.
- Returns an extension to continuously record a value.
- An extension to report parameter statistics.
- An extension to output plots.
- An extension to print the accumulated results.
- An extension to print a progress bar and recent training status.
- Writes the profile results to a file.
- Returns a trainer extension to take snapshots of the trainer.
- An extension to communicate with Slack.
- An extension to communicate with Slack using Incoming Webhook.
- An extension to plot statistics for
Triggers¶
- Trigger for Early Stopping.
- Trigger based on a fixed interval.
- Trigger invoked at specified point(s) of iterations or epochs.
- Trigger invoked when specific value becomes best.
- Trigger invoked when specific value becomes maximum.
- Trigger invoked when specific value becomes minimum.
- Trigger based on the starting point of the iteration.
- Trigger based on a fixed time interval.
Reporting¶
- Object to which observed values are reported.
- Reports observed values with the current reporter object.
- Returns a report scope with the current reporter.
Logging¶
- Returns a child logger to be used by applications.
Profiler¶
- Context manager to automatically report execution times.
Distributed Training¶
- Module for distributed data parallelism.
- Initialize torch.distributed environments with values taken from OpenMPI.
Check Pointing¶
Lazy Modules¶
- Module to check the shape of a tensor.
- Checks the shape and type of a tensor.
- Linear module with lazy weight initialization.
- Conv1d module with lazy weight initialization.
- Conv2d module with lazy weight initialization.
- Conv3d module with lazy weight initialization.
- BatchNorm1d module with lazy weight initialization.
- BatchNorm2d module with lazy weight initialization.
- BatchNorm3d module with lazy weight initialization.
ONNX¶
Export¶
- Export model into ONNX Graph.
- Export model and I/O tensors of the model in protobuf format.
Annotation¶
- Annotation parameters to the target function.
- Annotation applier to the target function.
- Add anchor node to the scoped modules.
- Export model into ONNX Graph.
- Export model and I/O tensors of the model in protobuf format.
Datasets¶
- Dataset that caches the load samples in shared memory.
- An abstract class that represents tabular dataset.
NumPy/CuPy Compatibility¶
- Creates a torch.Tensor from a numpy.ndarray or cupy.ndarray.
- Creates a numpy.ndarray or cupy.ndarray from torch.Tensor.
- Returns a module of ndarray implementation (numpy or cupy) for the given obj.
- Returns NumPy dtype for the given PyTorch dtype.
- Returns PyTorch dtype for the given NumPy dtype.
- Context-manager that selects a given stream.
- Use the PyTorch memory pool in CuPy.
- Use the default memory pool in CuPy.