Runtimes for Custom Devices#

Note

This documentation is intended for those implementing their own device backend for the PPE training framework. Most users can skip this chapter.

The ppe.runtime.BaseRuntime API abstracts the device details and performs the movement of data and modules to the corresponding device.

A runtime is an object that defines multiple callbacks used throughout training, evaluation, and regular model calls. With runtimes, training can be implemented on devices other than CPUs or GPUs with minimal changes to the user code.

Users wanting to override only a few callbacks can inherit from ppe.runtime.PyTorchRuntime, which implements the basic functionality for CPU and GPU devices.

Runtimes must be registered by calling the ppe.runtime.runtime_registry.register(device_name, runtime_class) function for them to be discoverable.
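For instance, a runtime that only customizes one training callback can inherit from ppe.runtime.PyTorchRuntime and be registered under a new device name. The following is a minimal sketch, assuming BaseRuntime's (device_spec, options) constructor and the train_pre_step callback signature; log_step and the 'traced-cpu' name are hypothetical:

import pytorch_pfn_extras as ppe

class MyTracingRuntime(ppe.runtime.PyTorchRuntime):
    def __init__(self, device_spec, options):
        # Reuse the stock CPU data movement regardless of the name
        # the runtime was registered under.
        super().__init__('cpu', options)

    def train_pre_step(self, trainer, module, batch_idx, batch):
        log_step(batch_idx)  # hypothetical tracing helper
        super().train_pre_step(trainer, module, batch_idx, batch)

# Make the runtime discoverable by device name.
ppe.runtime.runtime_registry.register('traced-cpu', MyTracingRuntime)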

Use of ppe.to to transfer modules and batches to custom devices#

If you have defined a new runtime for a custom device, the ppe.to function allows moving a module or a tensor to the new device by invoking Runtime.move_tensor and Runtime.move_module when needed.

The module will be tagged by adding an attribute named _ppe_runtime that holds the needed runtime. It is the responsibility of the user's custom runtime to perform the actual movement to the device and apply all the transformations needed so the module can be executed correctly.
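As a quick illustration, the tag can be inspected after a ppe.to call; this sketch assumes the default registry, where 'cpu' maps to ppe.runtime.PyTorchRuntime:

import torch
import pytorch_pfn_extras as ppe

module = ppe.to(torch.nn.Linear(10, 10), device='cpu')
# ppe.to attached the runtime instance that now manages this module.
assert isinstance(module._ppe_runtime, ppe.runtime.PyTorchRuntime)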

Usually, runtime writers will need to replace the given module's forward function with a new one that performs the actual device execution.

import torch
import pytorch_pfn_extras as ppe

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)

class MyMagicDeviceRuntime(ppe.runtime.BaseRuntime):
    def _device_forward(self, args):
        # Placeholder for the actual device execution.
        return run_batch_in_my_device(args)

    def move_module(self, module):
        # Register a hook to initialize the module on the first batch
        # execution
        def hook(module, *args):
            module._ppe_runtime.initialize_module(module, args)
            self.hook.remove()  # Initialize only once

        self.hook = module.register_forward_pre_hook(hook)
        # Change the module forward to do the computation in the device
        module.forward = self._device_forward
        return module

    def initialize_module(self, module, loader_or_batch, optimizer=None):
        create_the_module_in_my_device(module, loader_or_batch, optimizer)

# Register the runtime class
ppe.runtime.runtime_registry.register('my_device', MyMagicDeviceRuntime)

# Create a regular module
module = MyModule()
# Move the module to the device
ppe.to(module, device='my_device')

for x in my_dataloader:
    # The first iteration will create the module in the device
    # and the next ones will directly execute the module in the device instead
    # of executing the regular pytorch `forward` call.
    y = module(x)

Please note that this is an oversimplified description and that developing a runtime that is 100% compatible with PyTorch requires wrapping the substitute forward function with torch.autograd.Function, among several other concerns such as state_dict manipulation to ensure correctness.
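As a rough sketch of the autograd concern, the substituted forward can route through a torch.autograd.Function so that gradients flow back through the device; launch_forward and launch_backward are hypothetical device kernels:

import torch

class _MyDeviceFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return launch_forward(x)  # hypothetical device kernel

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return launch_backward(x, grad_output)  # hypothetical device kernel

def _device_forward(x):
    # A forward substitute that keeps the module differentiable.
    return _MyDeviceFunction.apply(x)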

Runtime Registry#

When creating a new Runtime class for custom needs, it needs to be registered in the global runtime_registry object as detailed above. This object is of the _RuntimeRegistry type and maintains a map from strings to Runtime types. The keys are the devices passed to ppe.to, and the values are the types of the Runtime objects that ppe.to will use to treat the module or tensor. Note that users are not supposed to interact with this class directly, only through runtime_registry.register to register new runtimes.
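In short, the registry maps the device string passed to ppe.to to the Runtime type that handles it; reusing the classes from the example above:

ppe.runtime.runtime_registry.register('my_device', MyMagicDeviceRuntime)
# ppe.to looks up 'my_device' and instantiates MyMagicDeviceRuntime
# to move and tag the module.
module = ppe.to(MyModule(), device='my_device')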