pytorch_pfn_extras.compile#

pytorch_pfn_extras.compile(module, optimizer=None, backend=None, *, generate_backward=True, decompositions=None)#

Compiles a module and an optimizer into a single graph using the provided backend.

Note

The backend object needs to be a callable that accepts a torch.fx.GraphModule and a list of torch.Tensor and returns a Callable, as specified by https://pytorch.org/docs/2.0/dynamo/custom-backends.html#custom-backends
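A minimal sketch of that contract, shown with torch.compile directly (the model shape and the backend name are illustrative, not part of this API): the backend receives the traced graph and example inputs, and returns the callable to execute.

```python
import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Inspect or transform the graph here; this sketch just
    # returns the graph's own forward to run it unchanged.
    return gm.forward

model = torch.nn.Linear(4, 2)
compiled = torch.compile(model, backend=my_backend)
out = compiled(torch.randn(3, 4))
```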

Note

Modules that are split into multiple graphs are not supported: torch.compile is called with the fullgraph=True argument.

Parameters:
  • module (Module) – torch.nn.Module to be compiled

  • optimizer (Optional[Optimizer]) – Optimizer object associated with the module. It will be traced and its operations included in the module graph. Some dry-run operations may be performed to fully initialize the optimizer state.

  • backend (optional) – Object that processes the graph and compiles it for custom devices. Defaults to PyTorch Dynamo if not specified.

  • generate_backward (bool) – Add the backward pass to the graph. Default is True.

  • decompositions (optional) – Custom mapping for decomposing a torch op into simpler ops. Default is None, which resorts to torch._decomp.core_aten_decompositions()

Return type:

Callable[[…], Any]