""" Quick n Simple Image Folder, Tarfile based DataSet Hacked together by / Copyright 2019, Ross Wightman """ import io import logging import torch import torch.utils.data as data from PIL import Image from .parsers import create_parser _logger = logging.getLogger(__name__) _ERROR_RETRY = 50 class ImageDataset(data.Dataset): def __init__( self, root, parser=None, split='train', class_map=None, load_bytes=False, img_mode='RGB', transform=None, target_transform=None, ): if parser is None or isinstance(parser, str): parser = create_parser( parser or '', root=root, split=split, class_map=class_map ) self.parser = parser self.load_bytes = load_bytes self.img_mode = img_mode self.transform = transform self.target_transform = target_transform self._consecutive_errors = 0 def __getitem__(self, index): img, target = self.parser[index] try: img = img.read() if self.load_bytes else Image.open(img) except Exception as e: _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}') self._consecutive_errors += 1 if self._consecutive_errors < _ERROR_RETRY: return self.__getitem__((index + 1) % len(self.parser)) else: raise e self._consecutive_errors = 0 if self.img_mode and not self.load_bytes: img = img.convert(self.img_mode) if self.transform is not None: img = self.transform(img) if target is None: target = -1 elif self.target_transform is not None: target = self.target_transform(target) return img, target def __len__(self): return len(self.parser) def filename(self, index, basename=False, absolute=False): return self.parser.filename(index, basename, absolute) def filenames(self, basename=False, absolute=False): return self.parser.filenames(basename, absolute) class IterableImageDataset(data.IterableDataset): def __init__( self, root, parser=None, split='train', is_training=False, batch_size=None, repeats=0, download=False, transform=None, target_transform=None, ): assert parser is not None if isinstance(parser, str): self.parser = create_parser( parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats, download=download, ) else: self.parser = parser self.transform = transform self.target_transform = target_transform self._consecutive_errors = 0 def __iter__(self): for img, target in self.parser: if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) yield img, target def __len__(self): if hasattr(self.parser, '__len__'): return len(self.parser) else: return 0 def filename(self, index, basename=False, absolute=False): assert False, 'Filename lookup by index not supported, use filenames().' 
class IterableImageDataset(data.IterableDataset):

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            repeats=0,
            download=False,
            transform=None,
            target_transform=None,
    ):
        assert parser is not None
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training,
                batch_size=batch_size, repeats=repeats, download=download)
        else:
            self.parser = parser
        self.transform = transform
        self.target_transform = target_transform
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                target = self.target_transform(target)
            yield img, target

    def __len__(self):
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)


class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
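
# Usage sketch (illustrative): AugMixDataset expects the wrapped dataset's
# transform to be a 3-element tuple/list of (base, augmentation, normalize)
# stages; it keeps the base transform on the dataset and applies the other two
# itself. `RandAugment` below is a stand-in for an actual AugMix policy, and
# the path and parameters are assumptions for the example only.
#
#   from torchvision import transforms
#   three_stage = [
#       transforms.RandomResizedCrop(224),   # base: shared by all splits
#       transforms.RandAugment(),            # augmentation: applied to the non-clean splits
#       transforms.ToTensor(),               # normalize: applied to every split
#   ]
#   dataset = AugMixDataset(ImageDataset('./data/train', transform=three_stage), num_splits=3)
#   (x_clean, x_aug1, x_aug2), y = dataset[0]  # one clean split + two augmented splits
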