pytorch_pfn_extras.from_numpy_dtype

pytorch_pfn_extras.from_numpy_dtype(numpy_dtype)

Returns the PyTorch dtype that corresponds to the given NumPy dtype.

Parameters

numpy_dtype (numpy.dtype) – NumPy’s dtype object.

Returns

The corresponding PyTorch dtype object.

Return type

torch.dtype
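For illustration, the sketch below shows one way to obtain the same NumPy-to-PyTorch dtype mapping using only public NumPy and PyTorch APIs. It assumes both libraries are installed. The helper name `numpy_to_torch_dtype` is hypothetical, and this is not necessarily how pytorch_pfn_extras implements `from_numpy_dtype`. The sketch wraps an empty array with `torch.from_numpy` and reads back the resulting tensor's dtype.

```python
import numpy as np
import torch


def numpy_to_torch_dtype(numpy_dtype):
    """Hypothetical helper: map a NumPy dtype to the matching torch.dtype.

    Builds a zero-length NumPy array of the requested dtype, shares it with
    PyTorch via torch.from_numpy, and returns the tensor's dtype. NumPy dtypes
    that PyTorch cannot represent (e.g. object) make torch.from_numpy raise
    TypeError.
    """
    # np.empty accepts both numpy.dtype instances and scalar types like np.int64
    empty = np.empty(0, dtype=numpy_dtype)
    return torch.from_numpy(empty).dtype


# Usage: np.dtype("float32") maps to torch.float32
print(numpy_to_torch_dtype(np.dtype("float32")))
```

With the real library, the equivalent call would be `pytorch_pfn_extras.from_numpy_dtype(np.dtype("float32"))`, which per the return type above yields a `torch.dtype`.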