|
|
|
import torch
|
|
|
|
|
|
|
|
from .constants import *
|
|
|
|
from .random_erasing import RandomErasing
|
|
|
|
from .mixup import FastCollateMixup
|
|
|
|
|
|
|
|
|
|
|
|
class FetcherXla:
|
|
|
|
def __init__(self):
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
class Fetcher:
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
loader,
|
|
|
|
device: torch.device,
|
|
|
|
dtype=torch.float32,
|
|
|
|
normalize=True,
|
|
|
|
normalize_shape=(1, 3, 1, 1),
|
|
|
|
mean=IMAGENET_DEFAULT_MEAN,
|
|
|
|
std=IMAGENET_DEFAULT_STD,
|
|
|
|
re_prob=0.,
|
|
|
|
re_mode='const',
|
|
|
|
re_count=1,
|
|
|
|
num_aug_splits=0,
|
|
|
|
use_mp_loader=False,
|
|
|
|
):
|
|
|
|
self.loader = loader
|
|
|
|
self.device = torch.device(device)
|
|
|
|
self.dtype = dtype
|
|
|
|
if normalize:
|
|
|
|
self.mean = torch.tensor(
|
|
|
|
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
|
|
|
|
self.std = torch.tensor(
|
|
|
|
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
|
|
|
|
else:
|
|
|
|
self.mean = None
|
|
|
|
self.std = None
|
|
|
|
if re_prob > 0.:
|
|
|
|
# NOTE RandomErasing shouldn't be used here w/ XLA devices
|
|
|
|
self.random_erasing = RandomErasing(
|
|
|
|
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
|
|
|
|
else:
|
|
|
|
self.random_erasing = None
|
|
|
|
self.use_mp_loader = use_mp_loader
|
|
|
|
if use_mp_loader:
|
|
|
|
# FIXME testing for TPU use
|
|
|
|
import torch_xla.distributed.parallel_loader as pl
|
|
|
|
self._loader = pl.MpDeviceLoader(loader, device)
|
|
|
|
else:
|
|
|
|
self._loader = loader
|
|
|
|
|
|
|
|
def __iter__(self):
|
|
|
|
for sample, target in self._loader:
|
|
|
|
if not self.use_mp_loader:
|
|
|
|
sample = sample.to(device=self.device)
|
|
|
|
target = target.to(device=self.device)
|
|
|
|
sample = sample.to(dtype=self.dtype)
|
|
|
|
if self.mean is not None:
|
|
|
|
sample.sub_(self.mean).div_(self.std)
|
|
|
|
if self.random_erasing is not None:
|
|
|
|
sample = self.random_erasing(sample)
|
|
|
|
yield sample, target
|
|
|
|
|
|
|
|
def __len__(self):
|
|
|
|
return len(self.loader)
|
|
|
|
|
|
|
|
@property
|
|
|
|
def sampler(self):
|
|
|
|
return self.loader.sampler
|
|
|
|
|
|
|
|
@property
|
|
|
|
def dataset(self):
|
|
|
|
return self.loader.dataset
|
|
|
|
|
|
|
|
@property
|
|
|
|
def mixup_enabled(self):
|
|
|
|
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
|
|
|
return self.loader.collate_fn.mixup_enabled
|
|
|
|
else:
|
|
|
|
return False
|
|
|
|
|
|
|
|
@mixup_enabled.setter
|
|
|
|
def mixup_enabled(self, x):
|
|
|
|
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
|
|
|
self.loader.collate_fn.mixup_enabled = x
|