pytorch_pfn_extras.nn.ensure

pytorch_pfn_extras.nn.ensure(tensor, shape=None, dtype=None, broadcastable=False, can_cast=False)

Checks the shape and type of a tensor.

Parameters
  • shape (Optional[Tuple[Optional[int], ...]]) – Tuple with the desired shape. If the shape of the input tensor is not compatible with it, a ValueError is raised. A dimension set to None is ignored, so it matches any size.

  • dtype (Optional[torch.dtype]) – Check whether the dtype of the input tensor matches the provided one.

  • broadcastable (bool) – Check if the shapes are compatible using broadcasting rules.

  • can_cast (bool) – Check whether the input tensor can be cast to the provided dtype.

  • tensor (torch.Tensor) – The tensor to check.
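
The snippet below is a minimal, pure-Python sketch of the shape rules described for the shape and broadcastable parameters. `shape_matches` is a hypothetical helper written for illustration; it is not part of pytorch_pfn_extras and not its implementation. The broadcastable branch assumes standard NumPy/PyTorch broadcasting compatibility (shapes aligned from the trailing dimension, each pair equal or containing a 1). It returns a bool instead of raising ValueError so the rules are easy to inspect.

```python
from typing import Optional, Sequence


def shape_matches(
    actual: Sequence[int],
    expected: Sequence[Optional[int]],
    broadcastable: bool = False,
) -> bool:
    """Hypothetical helper illustrating the documented shape rules.

    A None entry in ``expected`` matches any size. Without ``broadcastable``
    the ranks and all non-None sizes must match exactly. With
    ``broadcastable`` (assumed NumPy/PyTorch-style broadcasting) the shapes
    are aligned from the last dimension and each pair must be equal or
    contain a 1.
    """
    if not broadcastable:
        # Exact check: same rank, and every non-None size must be equal.
        if len(actual) != len(expected):
            return False
        return all(e is None or a == e for a, e in zip(actual, expected))
    # Broadcast check: walk both shapes from the trailing dimension;
    # extra leading dimensions on either side are allowed.
    for a, e in zip(reversed(actual), reversed(expected)):
        if e is None or a == e or a == 1 or e == 1:
            continue
        return False
    return True


if __name__ == "__main__":
    print(shape_matches((2, 3), (2, None)))                 # True
    print(shape_matches((2, 3), (2, 4)))                    # False
    print(shape_matches((3,), (2, 3), broadcastable=True))  # True
    print(shape_matches((2, 3), (2, 3, 4)))                 # False
```

With the real function, the equivalent check is a call such as `ppe.nn.ensure(x, shape=(None, 3), dtype=torch.float32)` (after `import pytorch_pfn_extras as ppe`), which returns None when the tensor passes and raises an exception otherwise.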

Return type

None