pytorch_pfn_extras.runtime.PyTorchRuntime

class pytorch_pfn_extras.runtime.PyTorchRuntime(device_spec, options=None)

A collection of callback functions for the devices that PyTorch supports by default.

Parameters
  • device_spec (torch.device or str) – The device to run on, for example "cpu" or torch.device("cuda:0").

  • options (Optional[Dict[str, Any]]) –

Return type

None

__init__(device_spec, options=None)
Parameters
  • device_spec (Union[str, torch.device]) – The device, as for the class parameter above.

  • options (Optional[Dict[str, Any]]) –

Return type

None
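The class groups device handling for PyTorch's built-in devices. As a minimal sketch of the idea, assuming that moving data amounts to plain `.to(device)` calls, the hypothetical `SketchRuntime` below mirrors `move_module`, `move_tensor` and `convert_batch`. It is not the pytorch_pfn_extras implementation, and the real `convert_batch` may treat containers (for example namedtuples) differently.

```python
from typing import Any, Union

import torch


class SketchRuntime:
    """Hypothetical stand-in, NOT pytorch_pfn_extras code.

    Sends modules, tensors and nested batches to one device, given either
    as a string such as "cpu" or as a torch.device.
    """

    def __init__(self, device_spec: Union[str, torch.device]) -> None:
        # Normalize the spec so both str and torch.device inputs work
        if isinstance(device_spec, torch.device):
            self.device = device_spec
        else:
            self.device = torch.device(device_spec)

    def move_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # nn.Module.to moves parameters/buffers in place and returns the module
        return module.to(self.device)

    def move_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        # Tensor.to returns a tensor on the target device
        return tensor.to(self.device)

    def convert_batch(self, args: Any) -> Any:
        # Recurse into common containers and move every tensor leaf
        if isinstance(args, torch.Tensor):
            return self.move_tensor(args)
        if isinstance(args, dict):
            return {k: self.convert_batch(v) for k, v in args.items()}
        if isinstance(args, list):
            return [self.convert_batch(v) for v in args]
        if isinstance(args, tuple):
            # Assumption: namedtuples become plain tuples in this sketch
            return tuple(self.convert_batch(v) for v in args)
        return args  # non-tensor values pass through unchanged


rt = SketchRuntime("cpu")
model = rt.move_module(torch.nn.Linear(2, 1))
batch = rt.convert_batch({"x": torch.zeros(4, 2), "ids": [torch.arange(3), "meta"]})
```

Non-tensor entries such as the string `"meta"` are returned untouched, so a batch can mix tensors with metadata.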

Methods

__init__(device_spec[, options])

convert_batch(args)

Transfers the given batch to the device given by device_spec.

eval_post_step(evaluator, module, batch_idx, …)

Called at the end of each evaluation step.

eval_pre_step(evaluator, module, batch_idx, …)

Called at the beginning of each evaluation step.

initialize_module(module, loader_or_batch[, …])

Initializes the module at the beginning of training or inference.

move_module(module)

Transfers the module to the device given by device_spec.

move_tensor(tensor)

Transfers the tensor to the device given by device_spec.

train_epoch_begin(module)

Preprocessing hook called at the beginning of each training epoch.

train_post_step(trainer, module, batch_idx, …)

Postprocessing hook called at the end of each training step.

train_pre_step(trainer, module, batch_idx, batch)

Preprocessing hook called at the beginning of each training step.

train_validation_begin(module)

Called before each validation run performed during training.

train_validation_end(module)

Called after each validation run performed during training.
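The Methods table lists the hooks but not the order in which a training loop invokes them. The sketch below is a hypothetical stand-in (`RecordingRuntime` and `run` are illustrative names, not pytorch_pfn_extras API) that records each hook call made by a simplified loop. The call order is an assumption inferred from the method descriptions, not taken from the library's trainer, and the parameters elided as "…" in the table are absorbed by `*args`.

```python
from typing import Any, List


class RecordingRuntime:
    """Hypothetical runtime that only records which hook ran, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def initialize_module(self, module: Any, loader_or_batch: Any, *args: Any) -> None:
        self.calls.append("initialize_module")

    def train_epoch_begin(self, module: Any) -> None:
        self.calls.append("train_epoch_begin")

    def train_pre_step(self, trainer: Any, module: Any, batch_idx: int, batch: Any) -> None:
        self.calls.append("train_pre_step")

    def train_post_step(self, trainer: Any, module: Any, batch_idx: int, *args: Any) -> None:
        self.calls.append("train_post_step")

    def train_validation_begin(self, module: Any) -> None:
        self.calls.append("train_validation_begin")

    def eval_pre_step(self, evaluator: Any, module: Any, batch_idx: int, *args: Any) -> None:
        self.calls.append("eval_pre_step")

    def eval_post_step(self, evaluator: Any, module: Any, batch_idx: int, *args: Any) -> None:
        self.calls.append("eval_post_step")

    def train_validation_end(self, module: Any) -> None:
        self.calls.append("train_validation_end")


def run(runtime: RecordingRuntime, module: Any, train_batches: List[Any],
        val_batches: List[Any], epochs: int) -> None:
    """Hypothetical loop; the hook order here is an assumption."""
    runtime.initialize_module(module, train_batches)
    for _ in range(epochs):
        runtime.train_epoch_begin(module)
        for idx, batch in enumerate(train_batches):
            runtime.train_pre_step(None, module, idx, batch)
            # forward pass, backward pass and optimizer step would run here
            runtime.train_post_step(None, module, idx, batch)
        runtime.train_validation_begin(module)
        for idx, batch in enumerate(val_batches):
            runtime.eval_pre_step(None, module, idx, batch)
            # evaluation forward pass would run here
            runtime.eval_post_step(None, module, idx, batch)
        runtime.train_validation_end(module)


rt = RecordingRuntime()
run(rt, module=None, train_batches=[0, 1], val_batches=[0], epochs=1)
print(rt.calls)
```

With two training batches and one validation batch, each step hook pair fires once per batch, while the epoch and validation hooks fire once per epoch.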