You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
177 lines
5.5 KiB
177 lines
5.5 KiB
""" Loader Factory, Fast Collate, CUDA Prefetcher

Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf

Hacked together by / Copyright 2019 Ross Wightman
"""
|
|
|
|
from typing import Optional, Callable
|
|
|
|
import numpy as np
|
|
import torch.utils.data
|
|
|
|
from timm.bits import DeviceEnv
|
|
from .collate import fast_collate
|
|
from .config import PreprocessCfg, MixupCfg
|
|
from .distributed_sampler import OrderedDistributedSampler
|
|
from .fetcher import Fetcher
|
|
from .mixup import FastCollateMixup
|
|
from .prefetcher_cuda import PrefetcherCuda
|
|
from .transforms_factory import create_transform_v2
|
|
|
|
|
|
def _worker_init(worker_id):
    """Seed NumPy's global RNG from the seed torch assigned to this worker.

    Used as the ``worker_init_fn`` of the loaders built in this module.

    Args:
        worker_id: Worker index handed in by the DataLoader; it is checked
            against the id reported by ``get_worker_info()``.
    """
    info = torch.utils.data.get_worker_info()
    # Sanity check: the info object must describe the worker being initialised.
    assert info.id == worker_id
    # Fold the torch-provided worker seed into the 32-bit range that
    # np.random.seed accepts.
    numpy_seed = info.seed % (2 ** 32 - 1)
    np.random.seed(numpy_seed)
|
|
|
|
|
|
def create_loader_v2(
        dataset: torch.utils.data.Dataset,
        batch_size: int,
        is_training: bool = False,
        dev_env: Optional[DeviceEnv] = None,
        pp_cfg: Optional[PreprocessCfg] = None,
        mix_cfg: Optional[MixupCfg] = None,
        create_transform: bool = True,
        normalize_in_transform: bool = True,
        separate_transform: bool = False,
        num_workers: int = 1,
        collate_fn: Optional[Callable] = None,
        pin_memory: bool = False,
        use_multi_epochs_loader: bool = False,
        persistent_workers: bool = True,
):
    """Build a data loader for `dataset` and wrap it in a device fetcher/prefetcher.

    Args:
        dataset: Dataset to load from. When `create_transform` is True its
            `transform` attribute is overwritten.
        batch_size: Batch size passed to the underlying DataLoader.
        is_training: Selects the training variant of the transform, the
            shuffling `DistributedSampler` (vs `OrderedDistributedSampler`) in
            distributed mode, shuffling when no sampler is used, `drop_last`,
            and enables the random-erasing settings of the fetcher.
        dev_env: Device environment; `DeviceEnv.instance()` is used when None.
        pp_cfg: Pre-processing config handed to `create_transform_v2`; its
            `mean` / `std` (and `aug` random-erasing fields) are also forwarded
            to the fetcher. A fresh `PreprocessCfg()` is created when None.
        mix_cfg: Mixup config. When given with `prob > 0` and no explicit
            `collate_fn`, batches are collated with `FastCollateMixup`.
        create_transform: Whether to create a transform and assign it to the dataset.
        normalize_in_transform: Passed as `normalize` to the transform factory.
            When False, normalization is requested from the fetcher instead.
        separate_transform: Passed as `separate` to the transform factory.
        num_workers: Number of DataLoader worker processes.
        collate_fn: Custom collate function; overrides the mixup / fast collate default.
        pin_memory: Passed to the underlying DataLoader.
        use_multi_epochs_loader: Use `MultiEpochsDataLoader` instead of the
            stock `torch.utils.data.DataLoader`.
        persistent_workers: Passed to the underlying DataLoader; ignored when
            `num_workers` is 0 or the installed PyTorch does not accept it.

    Returns:
        A `PrefetcherCuda` (when `dev_env.type_cuda`) or a `Fetcher` wrapping
        the created loader.
    """
    if dev_env is None:
        dev_env = DeviceEnv.instance()

    # Build the default config per call rather than once at definition time,
    # so calls never share a single default instance.
    if pp_cfg is None:
        pp_cfg = PreprocessCfg()

    if create_transform:
        dataset.transform = create_transform_v2(
            cfg=pp_cfg,
            is_training=is_training,
            normalize=normalize_in_transform,
            separate=separate_transform,
        )

    sampler = None
    if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)

    if collate_fn is None:
        if mix_cfg is not None and mix_cfg.prob > 0:
            collate_fn = FastCollateMixup(
                mixup_alpha=mix_cfg.mixup_alpha,
                cutmix_alpha=mix_cfg.cutmix_alpha,
                cutmix_minmax=mix_cfg.cutmix_minmax,
                prob=mix_cfg.prob,
                switch_prob=mix_cfg.switch_prob,
                mode=mix_cfg.mode,
                correct_lam=mix_cfg.correct_lam,
                label_smoothing=mix_cfg.label_smoothing,
                num_classes=mix_cfg.num_classes,
            )
        else:
            collate_fn = fast_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        # Shuffling is only requested for map-style training data without an explicit sampler.
        shuffle=not isinstance(dataset, torch.utils.data.IterableDataset) and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        worker_init_fn=_worker_init,
        # DataLoader raises ValueError for persistent_workers=True with
        # num_workers=0, so only request persistence when workers exist.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    fetcher_kwargs = dict(
        normalize=not normalize_in_transform,
        mean=pp_cfg.mean,
        std=pp_cfg.std,
    )
    if not normalize_in_transform and is_training and pp_cfg.aug is not None:
        # If normalization can be done in the prefetcher, random erasing is done there too
        # NOTE RandomErasing does not work well in XLA so normalize_in_transform will be True
        fetcher_kwargs.update(dict(
            re_prob=pp_cfg.aug.re_prob,
            re_mode=pp_cfg.aug.re_mode,
            re_count=pp_cfg.aug.re_count,
            num_aug_splits=pp_cfg.aug.num_aug_splits,
        ))
    if dev_env.type_cuda:
        loader = PrefetcherCuda(loader, **fetcher_kwargs)
    else:
        loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)

    return loader
|
|
|
|
|
|
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that keeps one underlying iterator alive across epochs.

    The batch sampler is wrapped in `_RepeatSampler` so it never runs dry, and
    a single iterator over the parent DataLoader is created once at
    construction. Each pass over this loader then pulls one epoch's worth of
    batches from that shared iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The name-mangled flag is lowered while batch_sampler is swapped and
        # raised again afterwards (presumably to get past DataLoader's guard
        # against reassigning attributes after construction -- confirm
        # against the installed PyTorch version).
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Created once and reused by every call to __iter__.
        self.iterator = super().__iter__()

    def __len__(self):
        # batch_sampler is the repeating wrapper; the wrapped original batch
        # sampler defines the number of batches in one epoch.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Hand out exactly one epoch of batches from the shared iterator.
        remaining = len(self)
        while remaining > 0:
            yield next(self.iterator)
            remaining -= 1
|
|
|
|
|
|
class _RepeatSampler:
    """ Sampler that repeats forever.

    Iterating restarts the wrapped sampler every time it is exhausted. If a
    complete pass over the wrapped sampler produces no items, iteration stops
    instead of spinning forever without ever yielding.

    Args:
        sampler (Sampler): the sampler (or any re-iterable) to repeat.
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            produced_any = False
            for item in self.sampler:
                produced_any = True
                yield item
            if not produced_any:
                # An empty pass would otherwise loop endlessly and hang the
                # consumer's next() call.
                return
|