pytorch_pfn_extras.get_xp

pytorch_pfn_extras.get_xp(obj)

Returns the ndarray implementation module (numpy or cupy) for the given obj.

The obj can be a torch.Tensor, a torch.device, or a NumPy/CuPy ndarray.
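The following minimal sketch illustrates the kind of dispatch described above. It is a hypothetical re-creation named get_xp_sketch, not the library's own implementation. It assumes that a CUDA torch.device maps to cupy, that any other torch.device maps to numpy, and that a torch.Tensor is resolved through its device. All other objects are resolved with cupy.get_array_module when CuPy is installed, with numpy as the fallback. Both torch and CuPy are treated as optional imports, so the sketch also runs on a CPU-only machine.

```python
import numpy

# Optional dependencies: this sketch degrades gracefully when they are absent.
try:
    import torch
except ImportError:
    torch = None

try:
    import cupy
except ImportError:
    cupy = None


def get_xp_sketch(obj):
    """Hypothetical stand-in for get_xp: returns numpy or cupy for obj.

    Assumption: the device type decides the backend for torch objects, and
    cupy.get_array_module decides it for everything else.
    """
    # A tensor is resolved through the device it lives on.
    if torch is not None and isinstance(obj, torch.Tensor):
        obj = obj.device
    if torch is not None and isinstance(obj, torch.device):
        if obj.type == "cuda":
            if cupy is None:
                raise RuntimeError("CuPy is required for CUDA objects")
            return cupy
        return numpy
    # NumPy/CuPy arrays (and anything else): let CuPy pick if it is available.
    if cupy is not None:
        return cupy.get_array_module(obj)
    return numpy


# Usage: write the array code once against `xp`.
x = numpy.arange(4.0)
xp = get_xp_sketch(x)
print(xp.sqrt(x))
```

Returning a module rather than a flag lets calling code write xp.sqrt(x) or xp.asarray(...) once and have it run on whichever backend holds the data.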

Parameters:

obj (Union[Any, Tensor]) – The object to inspect: a torch.Tensor, a torch.device, or a NumPy/CuPy ndarray.

Return type:

Any