""" Quick n Simple Image Folder, Tarfile based DataSet

Hacked together by / Copyright 2020 Ross Wightman
"""
import os

import torch
import torch.utils.data as data

from .parsers import ParserImageFolder, ParserImageTar


class ImageDataset(data.Dataset):
    """Map-style dataset over an image folder or a ``.tar`` archive.

    Sample loading is delegated to a parser object; this class only applies the
    optional ``transform`` and substitutes a placeholder target when the parser
    yields none.

    Args:
        img_root: Path to an image folder or a ``.tar`` file.
        parser: Pre-built parser to use. When ``None``, a ``ParserImageTar`` is
            created if ``img_root`` is an existing file whose extension is
            ``.tar`` (case-insensitive), otherwise a ``ParserImageFolder``.
        class_map: Forwarded to the auto-created parser (ignored when ``parser``
            is given).
        load_bytes: Forwarded to the auto-created parser and stored on the
            dataset as ``self.load_bytes``.
        transform: Optional callable applied to each image returned by the parser.
    """

    def __init__(
            self,
            img_root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        self.img_root = img_root
        if parser is None:
            # Case-insensitive extension check so e.g. 'train.TAR' is not handed
            # to the folder parser by mistake.
            is_tar = os.path.isfile(img_root) and os.path.splitext(img_root)[1].lower() == '.tar'
            if is_tar:
                parser = ParserImageTar(img_root, load_bytes=load_bytes, class_map=class_map)
            else:
                parser = ParserImageFolder(img_root, load_bytes=load_bytes, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(img, target)`` for sample ``index``.

        ``img`` is whatever the parser yields (NOTE(review): presumably raw bytes
        when ``load_bytes`` is set, a decoded image otherwise - confirm in the
        parser), passed through ``transform`` if one is set. When the parser
        reports no label (``None``), ``target`` becomes a scalar ``torch.long``
        tensor holding ``-1``.
        """
        img, target = self.parser[index]
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # Placeholder label; presumably keeps unlabeled samples collatable.
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        """Return the number of samples reported by the parser."""
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Return the filename of sample ``index`` as resolved by the parser."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Return the filenames of all samples as resolved by the parser."""
        return self.parser.filenames(basename, absolute)


class AugMixDataset(data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    Each sample is returned as a tuple of ``num_splits`` views of the same base
    sample: the first is the 'clean' view (base transform + normalize), the
    remaining ones additionally run the augmentation before normalization.

    Transforms are supplied as a 3-element tuple/list
    ``(base, augmentation, normalize)``. They come either from the wrapped
    dataset's ``transform`` at construction time or later through this
    wrapper's ``transform`` setter. The base transform is installed on the
    wrapped dataset, so it runs once inside ``dataset[i]`` and is shared by
    every split.

    Args:
        dataset: Wrapped dataset; must expose a settable ``transform`` attribute
            and return ``(x, y)`` pairs.
        num_splits: Total number of views per sample (clean view included).
    """

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split a ``(base, augmentation, normalize)`` 3-tuple/list across the wrapper.

        Raises:
            ValueError: If ``x`` is not a list/tuple of exactly 3 items.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not isinstance(x, (list, tuple)) or len(x) != 3:
            raise ValueError('Expecting a tuple/list of 3 transforms')
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        """The base transform currently installed on the wrapped dataset."""
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        # Setter takes the full 3-tuple; the getter only returns the base part.
        self._set_transforms(x)

    def _normalize(self, x):
        """Apply ``normalize`` if one is set, else return ``x`` unchanged."""
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        """Return ``(views, y)`` where ``views`` is a tuple of ``num_splits`` samples.

        Raises:
            RuntimeError: If more than one split is requested but no augmentation
                transform has been installed.
        """
        if self.num_splits > 1 and self.augmentation is None:
            # Fail with a clear message instead of "'NoneType' object is not callable".
            raise RuntimeError(
                'AugMixDataset has no augmentation transform; set `transform` to a '
                '(base, augmentation, normalize) tuple before indexing')
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        x_list.extend(self._normalize(self.augmentation(x)) for _ in range(self.num_splits - 1))
        return tuple(x_list), y

    def __len__(self):
        """Return the length of the wrapped dataset."""
        return len(self.dataset)