pytorch_pfn_extras.from_numpy_dtype

pytorch_pfn_extras.from_numpy_dtype(numpy_dtype: numpy.dtype) → torch.dtype

Returns PyTorch dtype for the given NumPy dtype.

Parameters

numpy_dtype – NumPy’s dtype object.

Returns

PyTorch dtype object (torch.dtype).
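The following is a minimal sketch of the same NumPy-to-PyTorch dtype mapping that needs only NumPy and PyTorch. The helper name numpy_to_torch_dtype is hypothetical and is not part of pytorch_pfn_extras. It assumes that letting torch.from_numpy infer the dtype of an empty array gives the correspondence this function documents. It is not the library's actual implementation.

```python
import numpy
import torch


def numpy_to_torch_dtype(numpy_dtype: numpy.dtype) -> torch.dtype:
    """Hypothetical stand-in for pytorch_pfn_extras.from_numpy_dtype.

    Allocates a zero-length array of the requested NumPy dtype and lets
    torch.from_numpy choose the matching PyTorch dtype. torch.from_numpy
    raises an exception if PyTorch has no equivalent dtype (e.g. object).
    """
    empty = numpy.empty(0, dtype=numpy.dtype(numpy_dtype))
    return torch.from_numpy(empty).dtype


if __name__ == "__main__":
    # Print the PyTorch dtype chosen for a few common NumPy dtypes.
    for name in ("float32", "float64", "int64", "uint8", "bool"):
        print(name, "->", numpy_to_torch_dtype(numpy.dtype(name)))
```

With pytorch_pfn_extras installed, the documented function is called directly: `pytorch_pfn_extras.from_numpy_dtype(numpy.dtype("float32"))` is expected to return `torch.float32`.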