""" Loader Factory, Fast Collate, CUDA Prefetcher
|
|
|
|
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
|
|
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
|
|
|
|
Hacked together by / Copyright 2019, Ross Wightman
|
|
"""
|
|
import logging
import random
from contextlib import suppress
from functools import partial
from itertools import repeat
from typing import Callable

import torch
import torch.utils.data
import numpy as np

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .dataset import IterableImageDataset
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
from .transforms_factory import create_transform

_logger = logging.getLogger(__name__)


def fast_collate(batch):
    """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
    assert isinstance(batch[0], tuple)
    batch_size = len(batch)
    if isinstance(batch[0][0], tuple):
        # This branch 'deinterleaves' and flattens tuples of input tensors into one tensor, ordered by
        # position such that all tuple elements at position n end up in the nth chunk of
        # torch.split(tensor, batch_size)
        inner_tuple_size = len(batch[0][0])
        flattened_batch_size = batch_size * inner_tuple_size
        targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
        tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            assert len(batch[i][0]) == inner_tuple_size  # all input tensor tuples must be same length
            for j in range(inner_tuple_size):
                targets[i + j * batch_size] = batch[i][1]
                tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
        return tensor, targets
    elif isinstance(batch[0][0], np.ndarray):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i] += torch.from_numpy(batch[i][0])
        return tensor, targets
    elif isinstance(batch[0][0], torch.Tensor):
        targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        assert len(targets) == batch_size
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        for i in range(batch_size):
            tensor[i].copy_(batch[i][0])
        return tensor, targets
    else:
        assert False
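

# Illustrative sketch, not part of this module's API: fast_collate is what create_loader()
# below passes as `collate_fn` when `use_prefetcher=True`. Assuming a dataset that yields
# (uint8 CHW numpy array, int label) pairs:
#
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, collate_fn=fast_collate)
#   images, targets = next(iter(loader))  # images: uint8 [64, C, H, W], targets: int64 [64]

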
def adapt_to_chs(x, n):
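    """ Adapt per-channel mean/std stats to the requested channel count `n`.

    A scalar is repeated `n` times; a sequence of the wrong length is replaced by its mean
    repeated `n` times (with a warning); a sequence of length `n` is returned unchanged.
    """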
    if not isinstance(x, (tuple, list)):
        x = tuple(repeat(x, n))
    elif len(x) != n:
        x_mean = np.mean(x).item()
        x = (x_mean,) * n
        _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.')
    else:
        assert len(x) == n, 'normalization stats must match image channels'
    return x


class PrefetchLoader:
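    """ Wraps a DataLoader and prefetches batches to the target device.

    Batches are copied to `device` (on a side CUDA stream when available, overlapping the copy
    of the next batch with compute on the current one), converted to `img_dtype`, and normalized
    with `mean`/`std` scaled by 255, since inputs are expected as uint8 tensors from fast_collate.
    Optional GPU random erasing is applied after normalization.
    """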

    def __init__(
            self,
            loader,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            channels=3,
            device=torch.device('cuda'),
            img_dtype=torch.float32,
            fp16=False,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            re_num_splits=0):

        mean = adapt_to_chs(mean, channels)
        std = adapt_to_chs(std, channels)
        normalization_shape = (1, channels, 1, 1)

        self.loader = loader
        self.device = device
        if fp16:
            # fp16 arg is deprecated, but will override dtype arg if set for bwd compat
            img_dtype = torch.float16
        self.img_dtype = img_dtype
        self.mean = torch.tensor(
            [x * 255 for x in mean], device=device, dtype=img_dtype).view(normalization_shape)
        self.std = torch.tensor(
            [x * 255 for x in std], device=device, dtype=img_dtype).view(normalization_shape)
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits,
                device=device,
            )
        else:
            self.random_erasing = None
        self.is_cuda = torch.cuda.is_available() and device.type == 'cuda'

    def __iter__(self):
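        # Yield batches one step behind the device copy: the first batch is transferred and
        # normalized on the side stream before anything is yielded, then each later iteration
        # stages batch N+1 while the caller is still consuming batch N; the last batch is
        # yielded after the loop.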
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream()
            stream_context = partial(torch.cuda.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

        for next_input, next_target in self.loader:

            with stream_context():
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_target = next_target.to(device=self.device, non_blocking=True)
                next_input = next_input.to(self.img_dtype).sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            if stream is not None:
                torch.cuda.current_stream().wait_stream(stream)

            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x


def _worker_init(worker_id, worker_seeding='all'):
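    """ Per-worker seeding hook passed to the DataLoader via `worker_init_fn`.

    With `worker_seeding='all'`, numpy is re-seeded from the worker's torch seed; with 'part',
    numpy is left alone to reproduce older results. A callable may be passed instead to derive
    the seed from `worker_info` directly, in which case random, torch, and numpy are all seeded.
    """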
    worker_info = torch.utils.data.get_worker_info()
    assert worker_info.id == worker_id
    if isinstance(worker_seeding, Callable):
        seed = worker_seeding(worker_info)
        random.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed % (2 ** 32 - 1))
    else:
        assert worker_seeding in ('all', 'part')
        # random / torch seed already called in dataloader iter class w/ worker_info.seed
        # to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
        if worker_seeding == 'all':
            np.random.seed(worker_info.seed % (2 ** 32 - 1))


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_repeats=0,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        fp16=False,  # deprecated, use img_dtype
        img_dtype=torch.float32,
        device=torch.device('cuda'),
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
        worker_seeding='all',
):
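    """ Build a train/eval DataLoader for `dataset`.

    Installs the transform pipeline on the dataset, selects an appropriate (distributed) sampler,
    wires in fast_collate when the prefetcher is enabled, and wraps the resulting DataLoader in
    PrefetchLoader if `use_prefetcher` is True.
    """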
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2
    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_prefetcher=use_prefetcher,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    if isinstance(dataset, IterableImageDataset):
        # give Iterable datasets early knowledge of num_workers so that sample estimates
        # are correct before worker processes are launched
        dataset.set_loader_cfg(num_workers=num_workers)

    sampler = None
    if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            if num_aug_repeats:
                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
            else:
                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(dataset)
    else:
        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

    if collate_fn is None:
        collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
        persistent_workers=persistent_workers
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)
    if use_prefetcher:
        prefetch_re_prob = re_prob if is_training and not no_aug else 0.
        loader = PrefetchLoader(
            loader,
            mean=mean,
            std=std,
            channels=input_size[0],
            device=device,
            fp16=fp16,  # deprecated, use img_dtype
            img_dtype=img_dtype,
            re_prob=prefetch_re_prob,
            re_mode=re_mode,
            re_count=re_count,
            re_num_splits=re_num_splits
        )

    return loader
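
# Illustrative usage sketch; `train_ds` and the argument values are examples, not part of this module:
#
#   loader = create_loader(
#       train_ds,                      # dataset yielding (image, label) pairs
#       input_size=(3, 224, 224),
#       batch_size=128,
#       is_training=True,
#       use_prefetcher=True,           # wraps the DataLoader in PrefetchLoader above
#       num_workers=4,
#   )
#   for images, targets in loader:
#       ...                            # images arrive as normalized img_dtype tensors on `device`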


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
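    """ DataLoader that reuses its worker processes across epochs.

    The batch sampler is wrapped in a repeating sampler so the underlying worker iterator is
    created once and never exhausted; __iter__ then draws one epoch's worth of batches per pass.
    """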

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)