You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/fetcher.py

88 lines
2.7 KiB

import torch
from .constants import *
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
class FetcherXla:
def __init__(self):
pass
class Fetcher:
def __init__(
self,
loader,
device: torch.device,
dtype=torch.float32,
normalize=True,
normalize_shape=(1, 3, 1, 1),
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
num_aug_splits=0,
use_mp_loader=False,
):
self.loader = loader
self.device = torch.device(device)
self.dtype = dtype
if normalize:
self.mean = torch.tensor(
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
self.std = torch.tensor(
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
else:
self.mean = None
self.std = None
if re_prob > 0.:
# NOTE RandomErasing shouldn't be used here w/ XLA devices
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
else:
self.random_erasing = None
self.use_mp_loader = use_mp_loader
if use_mp_loader:
# FIXME testing for TPU use
import torch_xla.distributed.parallel_loader as pl
self._loader = pl.MpDeviceLoader(loader, device)
else:
self._loader = loader
def __iter__(self):
for sample, target in self._loader:
if not self.use_mp_loader:
sample = sample.to(device=self.device)
target = target.to(device=self.device)
sample = sample.to(dtype=self.dtype)
if self.mean is not None:
sample.sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
sample = self.random_erasing(sample)
yield sample, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x