""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2020 Ross Wightman
"""

import torch.utils.data

from timm.bits import DeviceEnv

from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
from .collate import fast_collate
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler


def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        dev_env=None,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
):
    """Build a data loader for `dataset` and wrap it in a device fetcher.

    The function assigns a freshly created transform to `dataset.transform`
    (the dataset is mutated in place), builds a `DataLoader` (or
    `MultiEpochsDataLoader`) around it, and finally wraps that loader in
    `PrefetcherCuda` or `Fetcher` depending on the device environment.

    Args:
        dataset: Dataset to load from; its `transform` attribute is overwritten.
        input_size: Forwarded to `create_transform`.
        batch_size: Batch size handed to the loader.
        is_training: Selects the training sampler, enables shuffling (when no
            sampler is used and the dataset is not iterable) and `drop_last`.
        dev_env: Device environment; defaults to `DeviceEnv.instance()`.
        no_aug: Forwarded to `create_transform`; also disables random erasing
            in the fetcher.
        re_prob, re_mode, re_count: Random-erasing settings forwarded to both
            `create_transform` and the fetcher (the fetcher receives
            `re_prob=0.` unless training with augmentation).
        re_split: When true, random erasing is applied to the second half of
            the batch if there is no aug split, otherwise it lines up with the
            aug split.
        scale, ratio, hflip, vflip, color_jitter, auto_augment, interpolation,
        crop_pct, tf_preprocessing: Forwarded to `create_transform`.
        num_aug_splits: Number of augmentation splits; values > 0 request
            `separate=True` from `create_transform`.
        mean, std: Normalization statistics forwarded to `create_transform`
            and the fetcher.
        num_workers: Number of loader worker processes.
        collate_fn: Collate function; defaults to `fast_collate`.
        pin_memory: Forwarded to the loader.
        use_multi_epochs_loader: Use `MultiEpochsDataLoader` instead of the
            stock `DataLoader`.
        persistent_workers: Request persistent loader workers. Ignored when
            `num_workers` is 0, and dropped on PyTorch versions whose
            `DataLoader` does not accept the argument.

    Returns:
        A `PrefetcherCuda` or `Fetcher` wrapping the constructed loader.
    """
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2

    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_fetcher=True,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    if dev_env is None:
        dev_env = DeviceEnv.instance()

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)

    sampler = None
    if dev_env.distributed and not is_iterable:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)

    if collate_fn is None:
        collate_fn = fast_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        shuffle=not is_iterable and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers=0, so only request persistence when workers exist.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    fetcher_kwargs = dict(
        mean=mean,
        std=std,
        re_prob=re_prob if is_training and not no_aug else 0.,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
    )
    if dev_env.type_cuda:
        loader = PrefetcherCuda(loader, **fetcher_kwargs)
    else:
        loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)

    return loader


class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses a single underlying iterator across epochs.

    The batch sampler is wrapped in `_RepeatSampler` so the iterator created
    in `__init__` never runs dry; each call to `__iter__` draws `len(self)`
    batches from that same iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Toggle DataLoader's name-mangled `__initialized` flag around the
        # assignment; presumably needed because DataLoader rejects changes to
        # `batch_sampler` once initialized (verify against the PyTorch version in use).
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        self.iterator = super().__iter__()

    def __len__(self):
        # `self.batch_sampler` is the _RepeatSampler; its `.sampler` is the
        # original batch sampler, whose length defines one epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for _ in range(len(self)):
            yield next(self.iterator)


class _RepeatSampler(object):
    """ Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        # Restart the wrapped sampler each time it is exhausted.
        while True:
            yield from iter(self.sampler)