You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/data/loader.py

158 lines
4.6 KiB

""" Loader Factory, Fast Collate, CUDA Prefetcher
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch.utils.data
from timm.bits import get_device
from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
from .collate import fast_collate
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler
def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        dev_env=None,
        no_aug=False,
        re_prob=0.,
        re_mode='const',
        re_count=1,
        re_split=False,
        scale=None,
        ratio=None,
        hflip=0.5,
        vflip=0.,
        color_jitter=0.4,
        auto_augment=None,
        num_aug_splits=0,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        crop_pct=None,
        collate_fn=None,
        pin_memory=False,
        tf_preprocessing=False,
        use_multi_epochs_loader=False,
        persistent_workers=True,
):
    """Attach a transform to `dataset` and build a device-aware batch loader around it.

    The function (1) assigns ``dataset.transform`` from ``create_transform``, (2) picks a
    distributed sampler when the device environment is distributed and the dataset is
    map-style, (3) builds a ``DataLoader`` (or ``MultiEpochsDataLoader``), and (4) wraps it
    in ``PrefetcherCuda`` or ``Fetcher``.

    Args:
        dataset: Dataset to load from. Its ``transform`` attribute is overwritten.
        input_size: Forwarded to ``create_transform``.
        batch_size: Batch size handed to the DataLoader.
        is_training: Selects the training sampler, enables shuffling (when no sampler is
            used and the dataset is map-style) and ``drop_last``.
        dev_env: Device environment; ``get_device()`` is used when None. Must expose
            ``is_distributed``, ``world_size``, ``global_rank``, ``type`` and ``device``.
        no_aug: Forwarded to ``create_transform``; also forces the fetcher's ``re_prob`` to 0.
        re_prob, re_mode, re_count: RE (presumably random erasing) settings, forwarded to
            both ``create_transform`` and the fetcher.
        re_split: When true, RE is applied per split: ``num_aug_splits`` splits, or 2 if
            ``num_aug_splits`` is 0.
        scale, ratio, hflip, vflip, color_jitter, auto_augment, interpolation, crop_pct,
            tf_preprocessing: Forwarded unchanged to ``create_transform``.
        num_aug_splits: ``create_transform`` receives ``separate=True`` when this is > 0.
        mean, std: Normalisation statistics, forwarded to both ``create_transform`` and
            the fetcher.
        num_workers: Number of DataLoader worker processes.
        collate_fn: Batch collation function; ``fast_collate`` when None.
        pin_memory: Forwarded to the DataLoader.
        use_multi_epochs_loader: Use ``MultiEpochsDataLoader`` instead of the stock DataLoader.
        persistent_workers: Keep workers alive between epochs. Ignored when
            ``num_workers`` is 0, and dropped entirely on PyTorch versions whose DataLoader
            does not accept the argument.

    Returns:
        ``PrefetcherCuda`` when ``dev_env.type == 'cuda'``, otherwise a ``Fetcher`` bound to
        ``dev_env.device``; either one wraps the constructed DataLoader.
    """
    re_num_splits = 0
    if re_split:
        # apply RE to second half of batch if no aug split otherwise line up with aug split
        re_num_splits = num_aug_splits or 2

    dataset.transform = create_transform(
        input_size,
        is_training=is_training,
        use_fetcher=True,
        no_aug=no_aug,
        scale=scale,
        ratio=ratio,
        hflip=hflip,
        vflip=vflip,
        color_jitter=color_jitter,
        auto_augment=auto_augment,
        interpolation=interpolation,
        mean=mean,
        std=std,
        crop_pct=crop_pct,
        tf_preprocessing=tf_preprocessing,
        re_prob=re_prob,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
        separate=num_aug_splits > 0,
    )

    if dev_env is None:
        dev_env = get_device()

    is_iterable = isinstance(dataset, torch.utils.data.IterableDataset)

    sampler = None
    if dev_env.is_distributed and not is_iterable:
        if is_training:
            sampler = torch.utils.data.distributed.DistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
        else:
            # This will add extra duplicate entries to result in equal num
            # of samples per-process, will slightly alter validation results
            sampler = OrderedDistributedSampler(
                dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)

    if collate_fn is None:
        collate_fn = fast_collate

    loader_class = torch.utils.data.DataLoader
    if use_multi_epochs_loader:
        loader_class = MultiEpochsDataLoader

    loader_args = dict(
        batch_size=batch_size,
        # shuffle is mutually exclusive with an explicit sampler and meaningless for iterable datasets
        shuffle=not is_iterable and sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=is_training,
        # DataLoader raises ValueError for persistent_workers=True with num_workers == 0,
        # so only request persistence when there are workers to persist.
        persistent_workers=persistent_workers and num_workers > 0,
    )
    try:
        loader = loader_class(dataset, **loader_args)
    except TypeError:
        loader_args.pop('persistent_workers')  # only in Pytorch 1.7+
        loader = loader_class(dataset, **loader_args)

    fetcher_kwargs = dict(
        mean=mean,
        std=std,
        # RE is only applied by the fetcher for augmented training batches
        re_prob=re_prob if is_training and not no_aug else 0.,
        re_mode=re_mode,
        re_count=re_count,
        re_num_splits=re_num_splits,
    )
    if dev_env.type == 'cuda':
        loader = PrefetcherCuda(loader, **fetcher_kwargs)
    else:
        loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)

    return loader
class MultiEpochsDataLoader(torch.utils.data.DataLoader):
    """DataLoader that reuses one underlying iterator across epochs.

    The index sampler is wrapped in ``_RepeatSampler`` so the iterator created in
    ``__init__`` never runs out; every call to ``__iter__`` then draws exactly
    ``len(self)`` items from that same iterator.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The name-mangled ``__initialized`` flag is cleared around the assignment because
        # PyTorch's DataLoader rejects rebinding sampler attributes once initialised.
        self._DataLoader__initialized = False
        if self.batch_sampler is None:
            # Automatic batching is disabled (e.g. batch_size=None): indices come straight
            # from ``sampler``, so that is the object that has to repeat.
            self.sampler = _RepeatSampler(self.sampler)
        else:
            self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Single long-lived iterator shared by every epoch.
        self.iterator = super().__iter__()

    def __len__(self):
        """Return the length of the wrapped (non-repeating) sampler, i.e. items per epoch."""
        if self.batch_sampler is None:
            return len(self.sampler.sampler)
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        """Yield one epoch's worth of items from the persistent iterator."""
        for _ in range(len(self)):
            yield next(self.iterator)
class _RepeatSampler:
    """ Endless wrapper around another sampler.

    Each time the wrapped sampler is exhausted it is iterated again from the start,
    so iteration over this object never terminates on its own.

    Args:
        sampler (Sampler): the sampler to cycle through; kept as ``self.sampler``
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            # Start a fresh pass over the wrapped sampler on every cycle.
            for item in self.sampler:
                yield item