You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/data/dataset.py

179 lines
5.5 KiB

""" Quick n Simple Image Folder, Tarfile based DataSet
Hacked together by / Copyright 2020 Ross Wightman
"""
import csv
import torch.utils.data as data
import os
import torch
import logging
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from sklearn import preprocessing
from .parsers import create_parser
# Module-level logger; ImageDataset.__getitem__ uses it to warn about samples it skips.
_logger = logging.getLogger(__name__)
# Limit on consecutive sample-load failures in ImageDataset.__getitem__: once this many
# failures happen in a row, the last error is re-raised instead of skipping to the next index.
_ERROR_RETRY = 50
class ImageDataset(data.Dataset):
    """Map-style dataset that reads ``(source, target)`` pairs from a parser.

    The source of each item is either read as raw bytes (``load_bytes=True``) or
    opened with PIL and converted to RGB. A sample whose loading raises is logged
    and skipped by moving on to the next index (wrapping around at the end); after
    ``_ERROR_RETRY`` consecutive failures the error is re-raised. A ``None`` target
    is replaced by ``torch.tensor(-1, dtype=torch.long)``.

    Args:
        root: dataset root handed to ``create_parser`` when no parser object is given.
        parser: parser object, parser name, or ``None`` (default parser ``''``).
        class_map: forwarded to ``create_parser`` when a parser is created here.
        load_bytes: return raw bytes instead of a decoded RGB image.
        transform: optional callable applied to the loaded sample.
    """

    def __init__(
            self,
            root,
            parser=None,
            class_map='',
            load_bytes=False,
            transform=None,
    ):
        # Only build a parser when we were given a name (or nothing at all).
        build_parser = parser is None or isinstance(parser, str)
        self.parser = create_parser(parser or '', root=root, class_map=class_map) if build_parser else parser
        self.load_bytes = load_bytes
        self.transform = transform
        self._consecutive_errors = 0

    def __getitem__(self, index):
        """Return ``(sample, target)`` for ``index``, skipping unreadable samples."""
        while True:
            source, target = self.parser[index]
            try:
                if self.load_bytes:
                    img = source.read()
                else:
                    img = Image.open(source).convert('RGB')
            except Exception as e:
                _logger.warning(f'Skipped sample (index {index}, file {self.parser.filename(index)}). {str(e)}')
                self._consecutive_errors += 1
                if self._consecutive_errors >= _ERROR_RETRY:
                    raise
                # Retry with the following sample, wrapping around at the end.
                index = (index + 1) % len(self.parser)
            else:
                break

        self._consecutive_errors = 0
        if self.transform is not None:
            img = self.transform(img)
        if target is None:
            target = torch.tensor(-1, dtype=torch.long)
        return img, target

    def __len__(self):
        """Number of samples, as reported by the parser."""
        return len(self.parser)

    def filename(self, index, basename=False, absolute=False):
        """Delegate to ``parser.filename`` for the sample at ``index``."""
        return self.parser.filename(index, basename, absolute)

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``parser.filenames`` for all samples."""
        return self.parser.filenames(basename, absolute)
class IterableImageDataset(data.IterableDataset):
    """Iterable-style dataset that streams ``(img, target)`` pairs from a parser.

    Each item yielded by the parser is passed through ``transform`` (if set), and a
    ``None`` target is replaced by ``torch.tensor(-1, dtype=torch.long)``.

    Args:
        root: dataset root handed to ``create_parser`` when ``parser`` is a name.
        parser: parser name (str) or an already-built iterable parser; required.
        split, is_training, batch_size, repeats: forwarded to ``create_parser``
            when ``parser`` is a name.
        class_map, load_bytes: accepted for signature parity with ``ImageDataset``
            but not used by this class.
        transform: optional callable applied to each image.

    Raises:
        AssertionError: if ``parser`` is ``None``.
    """

    def __init__(
            self,
            root,
            parser=None,
            split='train',
            is_training=False,
            batch_size=None,
            class_map='',
            load_bytes=False,
            repeats=0,
            transform=None,
    ):
        # Explicit raise instead of ``assert`` so the check survives ``python -O``;
        # AssertionError is kept so callers see the same exception type as before.
        if parser is None:
            raise AssertionError('IterableImageDataset requires a parser name or parser instance.')
        if isinstance(parser, str):
            self.parser = create_parser(
                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
        else:
            self.parser = parser
        self.transform = transform
        # Not read anywhere in this class; kept for attribute parity with ImageDataset.
        self._consecutive_errors = 0

    def __iter__(self):
        for img, target in self.parser:
            if self.transform is not None:
                img = self.transform(img)
            if target is None:
                target = torch.tensor(-1, dtype=torch.long)
            yield img, target

    def __len__(self):
        """Parser length when the parser defines ``__len__``, otherwise 0."""
        if hasattr(self.parser, '__len__'):
            return len(self.parser)
        else:
            return 0

    def filename(self, index, basename=False, absolute=False):
        """Not supported for streaming datasets; always raises.

        Raises:
            AssertionError: always (raised explicitly so it also fires under ``python -O``,
                where the previous ``assert False`` silently returned ``None``).
        """
        raise AssertionError('Filename lookup by index not supported, use filenames().')

    def filenames(self, basename=False, absolute=False):
        """Delegate to ``parser.filenames``."""
        return self.parser.filenames(basename, absolute)
class AugMixDataset(torch.utils.data.Dataset):
    """Dataset wrapper to perform AugMix or other clean/augmentation mixes.

    The wrapped dataset's ``transform`` is expected to be a 3-item list/tuple
    ``(base, augmentation, normalize)``. ``base`` is installed back onto the wrapped
    dataset, so it runs inside ``dataset[i]``. This presumes the wrapped dataset
    applies its ``transform`` attribute; verify for custom datasets.

    ``self[i]`` returns ``(views, target)``, where ``views`` has ``num_splits`` entries:
    the first is only normalized (the "clean" split) and each remaining one is
    ``normalize(augmentation(x))``. ``augmentation`` is called once per extra split,
    so with a random augmentation each split is presumably different.

    Args:
        dataset: wrapped dataset exposing a ``transform`` attribute.
        num_splits: total number of views returned per sample.
    """

    def __init__(self, dataset, num_splits=2):
        self.augmentation = None
        self.normalize = None
        self.dataset = dataset
        if self.dataset.transform is not None:
            self._set_transforms(self.dataset.transform)
        self.num_splits = num_splits

    def _set_transforms(self, x):
        """Split ``x`` into base / augmentation / normalize stages.

        Raises:
            AssertionError: if ``x`` is not a list/tuple of exactly 3 transforms.
        """
        # Explicit raise instead of ``assert`` so validation is not stripped under
        # ``python -O``; AssertionError is kept for backward compatibility.
        if not (isinstance(x, (list, tuple)) and len(x) == 3):
            raise AssertionError('Expecting a tuple/list of 3 transforms')
        self.dataset.transform = x[0]
        self.augmentation = x[1]
        self.normalize = x[2]

    @property
    def transform(self):
        """The base transform currently installed on the wrapped dataset."""
        return self.dataset.transform

    @transform.setter
    def transform(self, x):
        self._set_transforms(x)

    def _normalize(self, x):
        return x if self.normalize is None else self.normalize(x)

    def __getitem__(self, i):
        x, y = self.dataset[i]  # all splits share the same dataset base transform
        x_list = [self._normalize(x)]  # first split only normalizes (this is the 'clean' split)
        # run the full augmentation on the remaining splits
        for _ in range(self.num_splits - 1):
            x_list.append(self._normalize(self.augmentation(x)))
        return tuple(x_list), y

    def __len__(self):
        return len(self.dataset)
class COAIImageClassDataset(Dataset):
    """Map-style classification dataset built from an ``{image_path: label}`` mapping.

    Labels are integer-encoded with scikit-learn's ``LabelEncoder`` fit on the unique
    labels. Each item is ``(image, encoded_label)``, where ``image`` is the PIL image
    converted to a NumPy array and then passed through ``transform`` if one is set.

    Args:
        dict: mapping of image path (appended to ``base_path``) to class label. The
            name shadows the builtin but is kept so keyword callers keep working.
        base_path: string prefix concatenated verbatim in front of every image path;
            no separator is inserted, so include a trailing ``/`` if needed.
        transform: optional callable applied to the NumPy image array.

    Attributes:
        df: DataFrame with columns ``image``, ``label`` and ``encoded_label``.
        label_encoder: the fitted ``LabelEncoder``; use ``inverse_transform`` to map
            encoded labels back to the original class labels.
    """

    def __init__(self, dict, base_path='', transform=None):
        self.transform = transform
        self.base_path = base_path
        self.dict = dict
        df = pd.DataFrame(list(dict.items()), columns=['image', 'label'])
        # Keep the fitted encoder (previously discarded) so labels can be decoded later.
        self.label_encoder = preprocessing.LabelEncoder()
        self.label_encoder.fit(df['label'].unique())
        df['encoded_label'] = self.label_encoder.transform(df['label'])
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        row = self.df.iloc[index]  # single row lookup instead of two
        # The context manager closes the underlying file once the pixels have been
        # copied into the NumPy array (Image.open otherwise leaves it to the GC).
        with Image.open(self.base_path + row['image']) as image:
            np_img = np.array(image)
        if self.transform:
            np_img = self.transform(np_img)
        return np_img, row['encoded_label']