pytorch_pfn_extras.get_xp

pytorch_pfn_extras.get_xp(obj: Union[pytorch_pfn_extras._tensor._NDArray, torch.Tensor]) -> Any

Returns the ndarray implementation module (numpy or cupy) that corresponds to the given obj.

obj can be a torch.Tensor, a torch.device, or a NumPy/CuPy ndarray.