Dataset work

* support some torchvision datasets
* improvements to TFDS wrapper for subsplit handling (fix #942), shuffle seed
* add class-map support to train (fix #957)
more_datasets
Ross Wightman 3 years ago
parent ddc29da974
commit ba65dfe2c6

@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
self, self,
root, root,
parser=None, parser=None,
class_map='', class_map=None,
load_bytes=False, load_bytes=False,
transform=None, transform=None,
target_transform=None,
): ):
if parser is None or isinstance(parser, str): if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map) parser = create_parser(parser or '', root=root, class_map=class_map)
self.parser = parser self.parser = parser
self.load_bytes = load_bytes self.load_bytes = load_bytes
self.transform = transform self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0 self._consecutive_errors = 0
def __getitem__(self, index): def __getitem__(self, index):
@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
if target is None: if target is None:
target = torch.tensor(-1, dtype=torch.long) target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target return img, target
def __len__(self): def __len__(self):
@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
split='train', split='train',
is_training=False, is_training=False,
batch_size=None, batch_size=None,
class_map='',
load_bytes=False,
repeats=0, repeats=0,
download=False,
transform=None, transform=None,
target_transform=None,
): ):
assert parser is not None assert parser is not None
if isinstance(parser, str): if isinstance(parser, str):
self.parser = create_parser( self.parser = create_parser(
parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats) parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
else: else:
self.parser = parser self.parser = parser
self.transform = transform self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0 self._consecutive_errors = 0
def __iter__(self): def __iter__(self):
for img, target in self.parser: for img, target in self.parser:
if self.transform is not None: if self.transform is not None:
img = self.transform(img) img = self.transform(img)
if target is None: if self.target_transform is not None:
target = torch.tensor(-1, dtype=torch.long) target = self.target_transform(target)
yield img, target yield img, target
def __len__(self): def __len__(self):

@ -1,7 +1,26 @@
import os import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
Places365, ImageNet, ImageFolder
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
except ImportError:
has_inaturalist = False
from .dataset import IterableImageDataset, ImageDataset from .dataset import IterableImageDataset, ImageDataset
_TORCH_BASIC_DS = dict(
cifar10=CIFAR10,
cifar100=CIFAR100,
mnist=MNIST,
qmist=QMNIST,
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
def _search_split(root, split): def _search_split(root, split):
# look for sub-folder with name of split in root and use that if it exists # look for sub-folder with name of split in root and use that if it exists
@ -9,22 +28,107 @@ def _search_split(root, split):
try_root = os.path.join(root, split_name) try_root = os.path.join(root, split_name)
if os.path.exists(try_root): if os.path.exists(try_root):
return try_root return try_root
if split_name == 'validation':
try_root = os.path.join(root, 'val') def _try(syn):
if os.path.exists(try_root): for s in syn:
return try_root try_root = os.path.join(root, s)
if os.path.exists(try_root):
return try_root
return root
if split_name in _TRAIN_SYNONYM:
root = _try(_TRAIN_SYNONYM)
elif split_name in _EVAL_SYNONYM:
root = _try(_EVAL_SYNONYM)
return root return root
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs): def create_dataset(
name,
root,
split='validation',
search_split=True,
class_map=None,
load_bytes=False,
is_training=False,
download=False,
batch_size=None,
repeats=0,
**kwargs
):
""" Dataset factory method
In parenthesis after each arg are the type of dataset supported for each arg, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
* all - any of the above
Args:
name: dataset name, empty is okay for folder based datasets
root: root folder of dataset (all)
split: dataset split (all)
search_split: search for split specific child fold from root so one can specify
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
class_map: specify class -> index mapping via text file or dict (folder)
load_bytes: load data, return images as undecoded bytes (folder)
download: download dataset if not present and supported (TFDS, torch)
is_training: create dataset in train mode, this is different from the split.
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
batch_size: batch size hint for (TFDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS)
**kwargs: other args to pass to dataset
Returns:
Dataset object
"""
name = name.lower() name = name.lower()
if name.startswith('tfds'): if name.startswith('torch/'):
name = name.split('/', 2)[-1]
torch_kwargs = dict(root=root, download=download, **kwargs)
if name in _TORCH_BASIC_DS:
ds_class = _TORCH_BASIC_DS[name]
use_train = split in _TRAIN_SYNONYM
ds = ds_class(train=use_train, **torch_kwargs)
elif name == 'inaturalist' or name == 'inat':
assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
target_type = 'full'
split_split = split.split('/')
if len(split_split) > 1:
target_type = split_split[0].split('_')
if len(target_type) == 1:
target_type = target_type[0]
split = split_split[-1]
if split in _TRAIN_SYNONYM:
split = '2021_train'
elif split in _EVAL_SYNONYM:
split = '2021_valid'
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
elif name == 'places365':
if split in _TRAIN_SYNONYM:
split = 'train-standard'
elif split in _EVAL_SYNONYM:
split = 'val'
ds = Places365(split=split, **torch_kwargs)
elif name == 'imagenet':
if split in _EVAL_SYNONYM:
split = 'val'
ds = ImageNet(split=split, **torch_kwargs)
elif name == 'image_folder' or name == 'folder':
# in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageFolder(root, **kwargs)
else:
assert False, f"Unknown torchvision dataset {name}"
elif name.startswith('tfds/'):
ds = IterableImageDataset( ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs) root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
else: else:
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier
if search_split and os.path.isdir(root): if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split) root = _search_split(root, split)
ds = ImageDataset(root, parser=name, **kwargs) ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
return ds return ds

@ -1,16 +1,19 @@
import os import os
def load_class_map(filename, root=''): def load_class_map(map_or_filename, root=''):
class_map_path = filename if isinstance(map_or_filename, dict):
assert dict, 'class_map dict must be non-empty'
return map_or_filename
class_map_path = map_or_filename
if not os.path.exists(class_map_path): if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename) class_map_path = os.path.join(root, class_map_path)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
class_map_ext = os.path.splitext(filename)[-1].lower() class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
if class_map_ext == '.txt': if class_map_ext == '.txt':
with open(class_map_path) as f: with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)} class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else: else:
assert False, 'Unsupported class map extension' assert False, f'Unsupported class map file extension ({class_map_ext}).'
return class_to_idx return class_to_idx

@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
# explicitly select other options shortly # explicitly select other options shortly
if prefix == 'tfds': if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs) parser = ParserTfds(root, name, split=split, **kwargs)
else: else:
assert os.path.exists(root) assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@ -57,23 +57,28 @@ class ParserTfds(Parser):
components. components.
""" """
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0): def __init__(
self, root, name, split='train', is_training=False, batch_size=None,
download=False, repeats=0, seed=42):
super().__init__() super().__init__()
self.root = root self.root = root
self.split = split self.split = split
self.shuffle = shuffle
self.is_training = is_training self.is_training = is_training
if self.is_training: if self.is_training:
assert batch_size is not None,\ assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper" "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size self.batch_size = batch_size
self.repeats = repeats self.repeats = repeats
self.common_seed = seed # seed across all worker / dist nodes
self.worker_seed = 0 # seed specific to each work instance
self.subsplit = None self.subsplit = None
self.builder = tfds.builder(name, data_dir=root) self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
# download_and_prepare() by default here as it's caused issues generating unwanted paths. if download:
self.num_samples = self.builder.info.splits[split].num_examples self.builder.download_and_prepare()
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
self.ds = None # initialized lazily on each dataloader worker process self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None self.worker_info = None
@ -97,17 +102,18 @@ class ParserTfds(Parser):
worker_info = torch.utils.data.get_worker_info() worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes # setup input context to split dataset across distributed processes
split = self.split global_num_workers = num_workers = 1
num_workers = 1 global_worker_id = 1
if worker_info is not None: if worker_info is not None:
self.worker_info = worker_info self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers num_workers = worker_info.num_workers
global_num_workers = self.dist_num_replicas * num_workers global_num_workers = self.dist_num_replicas * num_workers
worker_id = worker_info.id worker_id = worker_info.id
global_worker_id = self.dist_rank * num_workers + worker_id
# FIXME I need to spend more time figuring out the best way to distribute/split data across # FIXME verify best sharding approach
# combo of distributed replicas + dataloader worker processes """ Data sharding
"""
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used. InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True) My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong. between the splits each iteration, but that understanding could be wrong.
@ -116,44 +122,39 @@ class ParserTfds(Parser):
* InputContext for distributed and sub-splits for worker processes * InputContext for distributed and sub-splits for worker processes
* sub-splits for both * sub-splits for both
""" """
# split_size = self.num_samples // num_workers can_subsplit = '[' not in self.split # can't subsplit a subsplit
# start = worker_id * split_size should_subsplit = global_num_workers > 1 and (
# if worker_id == num_workers - 1: self.split_info.num_shards < global_num_workers or not self.is_training)
# split = split + '[{}:]'.format(start) if can_subsplit and should_subsplit:
# else: # manually split the dataset w/o sharding for more even samples / worker
# split = split + '[{}:{}]'.format(start, start + split_size) self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
if not self.is_training and '[' not in self.split:
# If not training, and split doesn't define a subsplit, manually split the dataset input_context = None
# for more even samples / worker if global_num_workers > 1 and self.subsplit is None:
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[ # set input context to divide shards among distributed replicas
self.dist_rank * num_workers + worker_id]
if self.subsplit is None:
input_context = tf.distribute.InputContext( input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers, num_input_pipelines=global_num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id, input_pipeline_id=global_worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
) )
else:
input_context = None
read_config = tfds.ReadConfig( read_config = tfds.ReadConfig(
shuffle_seed=42, shuffle_seed=self.common_seed,
shuffle_reshuffle_each_iteration=True, shuffle_reshuffle_each_iteration=True,
input_context=input_context) input_context=input_context)
ds = self.builder.as_dataset( ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
options = tf.data.Options() options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
options.experimental_threading.max_intra_op_parallelism = 1 getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options) ds = ds.with_options(options)
if self.is_training or self.repeats > 1: if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets # to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle: if self.is_training:
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds) self.ds = tfds.as_numpy(ds)

@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters # Dataset parameters
parser.add_argument('data_dir', metavar='DIR', parser.add_argument('data_dir', metavar='DIR',
help='path to dataset') help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='', parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)') help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation', parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)') help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50"') help='Name of model to train (default: "resnet50"')
parser.add_argument('--pretrained', action='store_true', default=False, parser.add_argument('--pretrained', action='store_true', default=False,
@ -484,11 +490,16 @@ def main():
# create the train and eval datasets # create the train and eval datasets
dataset_train = create_dataset( dataset_train = create_dataset(
args.dataset, args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
root=args.data_dir, split=args.train_split, is_training=True, class_map=args.class_map,
batch_size=args.batch_size, repeats=args.epoch_repeats) download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset( dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
# setup mixup / cutmix # setup mixup / cutmix
collate_fn = None collate_fn = None

@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)') help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation', parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)') help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)') help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@ -175,7 +177,7 @@ def validate(args):
dataset = create_dataset( dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split, root=args.data, name=args.dataset, split=args.split,
load_bytes=args.tf_preprocessing, class_map=args.class_map) download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels: if args.valid_labels:
with open(args.valid_labels, 'r') as f: with open(args.valid_labels, 'r') as f:

Loading…
Cancel
Save