pytorch_pfn_extras.dataloaders.dataloader.default_collate
- pytorch_pfn_extras.dataloaders.dataloader.default_collate(batch)
Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
The exact output type can be a torch.Tensor, a Sequence of torch.Tensor, a Collection of torch.Tensor, or left unchanged, depending on the input type. This is used as the default function for collation when batch_size or batch_sampler is defined in DataLoader.
Here is the general input type (based on the type of the element within the batch) to output type mapping:
torch.Tensor -> torch.Tensor (with an added outer dimension batch size)
NumPy Arrays -> torch.Tensor
float -> torch.Tensor
int -> torch.Tensor
str -> str (unchanged)
bytes -> bytes (unchanged)
Mapping[K, V_i] -> Mapping[K, default_collate([V_1, V_2, …])]
NamedTuple[V1_i, V2_i, …] -> NamedTuple[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
Sequence[V1_i, V2_i, …] -> Sequence[default_collate([V1_1, V1_2, …]), default_collate([V2_1, V2_2, …]), …]
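The recursive rows of the mapping (Mapping, NamedTuple, Sequence) can be easier to follow in code. The following is a minimal pure-Python sketch of that recursion only, not the library's implementation: `sketch_collate` is a hypothetical name, plain lists stand in for the `torch.Tensor` outputs, and tensor and NumPy inputs are not handled.

```python
from collections import namedtuple
from collections.abc import Mapping, Sequence


def sketch_collate(batch):
    """Illustrative stand-in for the type mapping; lists replace tensors."""
    elem = batch[0]
    if isinstance(elem, (str, bytes)):
        # str and bytes batches are left unchanged.
        return batch
    if isinstance(elem, (int, float)):
        # The real function would build a torch.Tensor here.
        return list(batch)
    if isinstance(elem, Mapping):
        # Collate the values found under each key across the batch.
        return {key: sketch_collate([d[key] for d in batch]) for key in elem}
    if isinstance(elem, tuple) and hasattr(elem, "_fields"):
        # Named tuple: collate field by field and rebuild the same type.
        return type(elem)(*(sketch_collate(list(col)) for col in zip(*batch)))
    if isinstance(elem, Sequence):
        # Plain sequence: transpose, then collate each position.
        return [sketch_collate(list(col)) for col in zip(*batch)]
    raise TypeError(f"unsupported element type: {type(elem)!r}")


Point = namedtuple("Point", ["x", "y"])
nested = sketch_collate(
    [{"p": Point(0, 1), "tag": "a"}, {"p": Point(2, 3), "tag": "b"}]
)
```

Here `nested` keeps the outer dictionary and the `Point` type, while each leaf position is gathered across the batch. The `str` check comes before the `Sequence` check because strings are themselves sequences.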
- Parameters:
batch – a single batch to be collated
Examples
>>> # xdoctest: +SKIP
>>> from collections import namedtuple
>>> # Example with a batch of `int`s:
>>> default_collate([0, 1, 2, 3])
tensor([0, 1, 2, 3])
>>> # Example with a batch of `str`s:
>>> default_collate(["a", "b", "c"])
['a', 'b', 'c']
>>> # Example with `Mapping` inside the batch:
>>> default_collate([{"A": 0, "B": 1}, {"A": 100, "B": 100}])
{'A': tensor([  0, 100]), 'B': tensor([  1, 100])}
>>> # Example with `NamedTuple` inside the batch:
>>> Point = namedtuple("Point", ["x", "y"])
>>> default_collate([Point(0, 0), Point(1, 1)])
Point(x=tensor([0, 1]), y=tensor([0, 1]))
>>> # Example with `Tuple` inside the batch:
>>> default_collate([(0, 1), (2, 3)])
[tensor([0, 2]), tensor([1, 3])]
>>> # Example with `List` inside the batch:
>>> default_collate([[0, 1], [2, 3]])
[tensor([0, 2]), tensor([1, 3])]
>>> # Two options to extend `default_collate` to handle a specific type
>>> # Option 1: Write a custom collate function and invoke `default_collate`
>>> def custom_collate(batch):
...     elem = batch[0]
...     if isinstance(elem, CustomType):  # Some custom condition
...         return ...
...     else:  # Fall back to `default_collate`
...         return default_collate(batch)
>>> # Option 2: Modify `default_collate_fn_map` in place
>>> def collate_customtype_fn(batch, *, collate_fn_map=None):
...     return ...
>>> # `default_collate_fn_map` is a dict, so `update` takes a mapping
>>> default_collate_fn_map.update({CustomType: collate_customtype_fn})
>>> default_collate(batch)  # Handle `CustomType` automatically
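To make Option 1 concrete, here is a small runnable sketch. It assumes the function behaves like `torch.utils.data.default_collate`, which is imported here instead of the `pytorch_pfn_extras` path. `Span` is a hypothetical custom type invented for this illustration, and stacking it into a `(batch_size, 2)` tensor is one possible choice rather than a prescribed one.

```python
from dataclasses import dataclass

import torch
from torch.utils.data import DataLoader, default_collate


@dataclass
class Span:  # hypothetical custom type, not part of any library
    start: int
    end: int


def custom_collate(batch):
    elem = batch[0]
    if isinstance(elem, Span):
        # Custom branch: one row per sample, columns are (start, end).
        return torch.tensor([[s.start, s.end] for s in batch])
    # Every other element type falls back to the default behaviour.
    return default_collate(batch)


samples = [Span(0, 3), Span(2, 5), Span(4, 9), Span(1, 2)]
# A plain list works as a map-style dataset; shuffling is off by default.
loader = DataLoader(samples, batch_size=2, collate_fn=custom_collate)
batches = list(loader)
```

With `batch_size=2` and four samples, `batches` holds two tensors of shape `(2, 2)`. Passing the function through `collate_fn` leaves the global `default_collate_fn_map` untouched, which keeps the change local to this loader.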