""" Quick n Simple Image Folder, Tarfile based DataSet
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
import torch.utils.data as data
import os
import torch
import logging

from PIL import Image

from .parsers import create_parser

_logger = logging.getLogger(__name__)


# number of consecutive unreadable samples tolerated before the error is re-raised
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        if parser is None or isinstance(parser, str):
            parser = create_parser(parser or '', root=root, class_map=class_map)
        self.parser = parser
        self.load_bytes = load_bytes
        self.transform = transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        img, target = self.parser[index]
        try:
            img = img.read() if self.load_bytes else Image.open(img).convert('RGB')
        except Exception as e:
            _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
            self._consecutive_errors += 1
            if self._consecutive_errors < _ERROR_RETRY:
                # skip the unreadable sample and retry with the next index, wrapping around
                return self.__getitem__((index + 1) % len(self.parser))
            else:
                raise e
        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            # unlabelled samples get a placeholder target of -1
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)
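# A minimal usage sketch (illustrative, not part of the module; the path and
# transform below are assumptions). With the default parser, a folder layout
# like root/class_name/image.jpg yields (image, class_index) pairs:
#
#   from torchvision import transforms
#   train_tf = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
#   dataset = ImageDataset('/path/to/train', transform=train_tf)
#   img, target = dataset[0]  # transformed image tensor, class index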
class IterableImageDataset(data.IterableDataset):

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            repeats=0,
            transform=None,
    ):
        # an iterable dataset needs an explicit parser (name string or instance)
        assert parser is not None
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
        else:
            self.parser = parser
        self.transform = transform
        self._consecutive_errors = 0

    def __iter__(self):
        # stream decoded (img, target) samples from the parser; no random access by index
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if target is None:
                target = torch.tensor(-1, dtype=torch.long)
            yield img, target

    def __len__(self):
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            # length unknown for purely streaming parsers
            return 0

    def filename(self, index, basename=False, absolute=False):
        assert False, 'Filename lookup by index not supported, use filenames().'

    def filenames(self, basename=False, absolute=False):
        return self.parser.filenames(basename, absolute)
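# A minimal usage sketch (illustrative; '<parser_name>' is a placeholder for
# whatever streaming parser names your installed version supports):
#
#   dataset = IterableImageDataset(
#       '/path/to/data', parser='<parser_name>', split='train', is_training=True, batch_size=64)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64)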
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes"""

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        assert isinstance(x, (list, tuple)) and len(x) == 3, 'Expecting a tuple/list of 3 transforms'
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
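# A minimal usage sketch (illustrative; the transform names are assumptions).
# The wrapped dataset's transform must be a 3-tuple of (base, augmentation,
# normalize) stages, in the spirit of AugMix (Hendrycks et al., 2020); the
# clean split only normalizes, the others also run the augmentation:
#
#   base_ds = ImageDataset('/path/to/train')
#   base_ds.transform = (base_tf, augmix_tf, normalize_tf)  # hypothetical transforms
#   augmix_ds = AugMixDataset(base_ds, num_splits=3)
#   (clean, aug1, aug2), target = augmix_ds[0]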