You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/loader.py

187 lines
6.3 KiB

import torch.utils.data
import numpy as np
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
def fast_collate(batch):
    """Collate ``(input, target)`` samples into a uint8 input tensor and an int64 target tensor.

    Optimized for uint8 image inputs (numpy arrays or torch tensors) with integer labels.
    Three input layouts are supported:

    * ``np.ndarray`` inputs: stacked into a ``(batch_size, *input.shape)`` uint8 tensor.
    * ``torch.Tensor`` inputs: copied into a ``(batch_size, *input.shape)`` uint8 tensor.
    * tuples of ``np.ndarray`` inputs: de-interleaved and flattened so that tuple element ``j``
      of every sample lands in the ``j``-th chunk of ``torch.split(tensor, batch_size)``.
      Each sample's target is repeated in every chunk at the same position.

    Args:
        batch: non-empty sequence of tuples whose element 0 is the input and element 1 the target.

    Returns:
        ``(tensor, targets)``: the uint8 input tensor and the int64 target tensor.

    Raises:
        AssertionError: if samples are not tuples, the input tuples differ in length, or the
            input type is not one of the supported kinds.
    """
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    first_input = batch[0][0]

    if isinstance(first_input, tuple):
        # 'Deinterleave' and flatten tuples of inputs into one tensor ordered by tuple position,
        # so all inputs at tuple position j end up in the j-th torch.split(tensor, batch_size) chunk.
        # The tuple length is len(first_input) (number of elements per sample), not the leading
        # dimension of its first array.
        inner_tuple_size = len(first_input)
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *first_input[0].shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            inputs, target = sample[0], sample[1]
            assert len(inputs) == inner_tuple_size  # all input tuples must be the same length
            for j, x in enumerate(inputs):
                targets[i + j * batch_size] = target
                tensor[i + j * batch_size] += torch.from_numpy(x)
        return tensor, targets

    if isinstance(first_input, np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            tensor[i] += torch.from_numpy(sample[0])
        return tensor, targets

    if isinstance(first_input, torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *first_input.shape), dtype=torch.uint8)
        for i, sample in enumerate(batch):
            tensor[i].copy_(sample[0])
        return tensor, targets

    # Explicit raise (instead of `assert False`) so this is not stripped under `python -O`,
    # which would otherwise make the function silently return None.
    raise AssertionError(f'fast_collate: unsupported input type {type(first_input)}')
class PrefetchLoader:
    """Wrap a DataLoader so batches are copied to CUDA and normalized on a side stream.

    While the caller consumes batch ``k``, batch ``k + 1`` is moved to the GPU with
    ``non_blocking=True``, converted to float (or half), normalized with ``mean``/``std`` and
    optionally random-erased inside a separate ``torch.cuda.Stream``.

    Args:
        loader: the wrapped loader; each iteration must yield ``(input, target)`` tensors.
        rand_erase_prob: ``probability`` passed to ``RandomErasing``; ``0.`` disables erasing.
        rand_erase_mode: ``mode`` passed to ``RandomErasing``.
        rand_erase_count: ``max_count`` passed to ``RandomErasing``.
        mean: per-channel mean; multiplied by 255 here, so presumably given on a [0, 1] scale
            for uint8-range inputs (e.g. from ``fast_collate``).
        std: per-channel std; multiplied by 255 here (same scale as ``mean``).
        fp16: if True, inputs are normalized and yielded as half precision.
    """

    def __init__(self,
                 loader,
                 rand_erase_prob=0.,
                 rand_erase_mode='const',
                 rand_erase_count=1,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False):
        self.loader = loader
        # view(1, C, 1, 1) broadcasts the per-channel stats over (batch, channel, H, W) inputs;
        # -1 lets the channel count follow len(mean) / len(std).
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, -1, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, -1, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if rand_erase_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=rand_erase_prob, mode=rand_erase_mode, max_count=rand_erase_count)
        else:
            self.random_erasing = None

    def __iter__(self):
        """Yield ``(input, target)`` CUDA tensors, prefetching the next batch on a side stream."""
        stream = torch.cuda.Stream()
        first = True
        batch_input = batch_target = None
        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            # Hand out the previously prepared batch while this one is still in flight.
            if not first:
                yield batch_input, batch_target
            else:
                first = False

            # Make the default stream wait for the side-stream work before the batch is used.
            torch.cuda.current_stream().wait_stream(stream)
            batch_input = next_input
            batch_target = next_target

        # Flush the last prefetched batch; skipped entirely when the wrapped loader was empty
        # (previously this raised UnboundLocalError).
        if not first:
            yield batch_input, batch_target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        """The wrapped loader's sampler."""
        return self.loader.sampler

    @property
    def dataset(self):
        """The wrapped loader's dataset."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Mixup flag of the wrapped loader's ``FastCollateMixup`` collate_fn; False otherwise."""
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Silently ignored when the collate_fn is not a FastCollateMixup.
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        rand_erase_prob=0.,
        rand_erase_mode='const',
        rand_erase_count=1,
        color_jitter=0.4,
        auto_augment=None,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        fp16=False,
        tf_preprocessing=False,
        separate_transforms=False,
        pin_memory=False,
):
    """Build a DataLoader (optionally wrapped in ``PrefetchLoader``) for ``dataset``.

    Side effect: assigns ``dataset.transform`` from ``create_transform(...)``.

    Args:
        dataset: dataset object; its ``transform`` attribute is overwritten.
        input_size: forwarded to ``create_transform``.
        batch_size: DataLoader batch size.
        is_training: enables shuffling (when no sampler is used), ``drop_last`` and, with the
            prefetcher, random erasing.
        use_prefetcher: if True, use ``fast_collate`` by default and wrap the loader in
            ``PrefetchLoader`` (GPU copy + normalization).
        rand_erase_prob, rand_erase_mode, rand_erase_count: random-erasing settings. NOTE: they
            are only applied through ``PrefetchLoader`` and only when ``is_training``; they are
            unused when ``use_prefetcher=False``.
        color_jitter, auto_augment, interpolation, crop_pct, tf_preprocessing,
        separate_transforms: forwarded to ``create_transform``.
        mean, std: forwarded to ``create_transform`` and ``PrefetchLoader``.
        num_workers: DataLoader worker count.
        distributed: use ``DistributedSampler`` (training) or ``OrderedDistributedSampler``
            (evaluation).
        collate_fn: custom collate function; defaults to ``fast_collate`` with the prefetcher,
            else the DataLoader default collate.
        fp16: forwarded to ``PrefetchLoader``.
        pin_memory: forwarded to ``DataLoader``; pinned host memory lets the prefetcher's
            ``non_blocking=True`` copies overlap with compute.

    Returns:
        The ``DataLoader``, or a ``PrefetchLoader`` wrapping it when ``use_prefetcher`` is True.
    """
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        separate=separate_transforms,
    )

    sampler = None
    if distributed:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in an equal number
            # of samples per process, which slightly alters validation results.
            sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # DataLoader forbids shuffle=True together with a sampler.
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
    )
    if use_prefetcher:
        loader = PrefetchLoader(
            loader,
            rand_erase_prob=rand_erase_prob if is_training else 0.,
            rand_erase_mode=rand_erase_mode,
            rand_erase_count=rand_erase_count,
            mean=mean,
            std=std,
            fp16=fp16)
    return loader