You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
38 lines
1.8 KiB
38 lines
1.8 KiB
import numpy as np
|
|
|
|
import torch
|
|
|
|
|
|
def fast_collate(batch):
    """Collate ``(input, target)`` samples into one input tensor and one target tensor.

    A fast collation function optimized for uint8 images (np array or torch)
    and int64 targets (labels).

    The layout of the batch is decided by the type of the first sample's
    input, ``batch[0][0]``:

    * ``tuple`` -- every sample carries a tuple of NumPy arrays (they are
      passed to ``torch.from_numpy``). The arrays are "deinterleaved": the
      array at tuple position ``j`` of sample ``i`` is written to row
      ``i + j * batch_size``, so ``torch.split(tensor, batch_size)`` yields
      one chunk per tuple position. Each sample's target is repeated once
      per tuple position using the same row index. Output dtype is uint8.
    * ``np.ndarray`` -- arrays are stacked along a new leading dimension
      into a uint8 tensor.
    * ``torch.Tensor`` -- tensors are copied into a new tensor whose dtype
      is taken from the first sample.

    Args:
        batch: Non-empty sequence of samples; each sample is a tuple whose
            element 0 is the input and element 1 is the label. All inputs
            are expected to share the shape of the first sample's input,
            because the output buffer is sized from it.

    Returns:
        A ``(tensor, targets)`` tuple. ``targets`` is an int64 tensor;
        ``tensor`` has one leading row per (flattened) sample.

    Raises:
        AssertionError: If the first sample is not a tuple, if the input
            tuples differ in length, or if the input type is not one of
            the three supported kinds.
    """
    assert isinstance(batch[0], tuple), "each sample must be an (input, target) tuple"
    batch_size = len(batch)
    first_input = batch[0][0]

    if isinstance(first_input, tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into
        # one tensor ordered by position, such that all tuples of position n
        # end up in the nth chunk of torch.split(tensor, batch_size).
        inner_tuple_size = len(first_input)
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros(
            (flattened_batch_size, *first_input[0].shape), dtype=torch.uint8
        )
        for i, sample in enumerate(batch):
            # all input tensor tuples must be same length
            assert len(sample[0]) == inner_tuple_size, (
                "all input tuples must have the same length"
            )
            for j in range(inner_tuple_size):
                row = i + j * batch_size
                targets[row] = sample[1]
                # In-place add into the zeroed uint8 buffer (kept as in the
                # original implementation rather than switching to copy_).
                tensor[row] += torch.from_numpy(sample[0][j])
        return tensor, targets

    if isinstance(first_input, np.ndarray):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            # In-place add into the zeroed uint8 buffer, as above.
            tensor[i] += torch.from_numpy(sample[0])
        return tensor, targets

    if isinstance(first_input, torch.Tensor):
        targets = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        # Unlike the NumPy branches, the output keeps the input's own dtype.
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=first_input.dtype)
        for i, sample in enumerate(batch):
            tensor[i].copy_(sample[0])
        return tensor, targets

    # Explicit raise instead of `assert False`: a bare assert is removed under
    # `python -O`, which would make this function silently return None.
    raise AssertionError(
        f"fast_collate: unsupported input type {type(first_input).__name__!r}; "
        "expected a tuple of np.ndarray, an np.ndarray, or a torch.Tensor"
    )