Merge pull request #964 from rwightman/more_datasets

Dataset additions
pull/968/head
Ross Wightman 3 years ago committed by GitHub
commit 65419f60cc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
self,
root,
parser=None,
class_map='',
class_map=None,
load_bytes=False,
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map)
self.parser = parser
self.load_bytes = load_bytes
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __getitem__(self, index):
@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.tensor(-1, dtype=torch.long)
target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
split='train',
is_training=False,
batch_size=None,
class_map='',
load_bytes=False,
repeats=0,
download=False,
transform=None,
target_transform=None,
):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
else:
self.parser = parser
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __iter__(self):
for img, target in self.parser:
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.tensor(-1, dtype=torch.long)
if self.target_transform is not None:
target = self.target_transform(target)
yield img, target
def __len__(self):

@ -1,7 +1,26 @@
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
Places365, ImageNet, ImageFolder
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
except ImportError:
has_inaturalist = False
from .dataset import IterableImageDataset, ImageDataset
_TORCH_BASIC_DS = dict(
cifar10=CIFAR10,
cifar100=CIFAR100,
mnist=MNIST,
qmist=QMNIST,
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
def _search_split(root, split):
# look for sub-folder with name of split in root and use that if it exists
@ -9,22 +28,107 @@ def _search_split(root, split):
try_root = os.path.join(root, split_name)
if os.path.exists(try_root):
return try_root
if split_name == 'validation':
try_root = os.path.join(root, 'val')
if os.path.exists(try_root):
return try_root
def _try(syn):
for s in syn:
try_root = os.path.join(root, s)
if os.path.exists(try_root):
return try_root
return root
if split_name in _TRAIN_SYNONYM:
root = _try(_TRAIN_SYNONYM)
elif split_name in _EVAL_SYNONYM:
root = _try(_EVAL_SYNONYM)
return root
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
def create_dataset(
name,
root,
split='validation',
search_split=True,
class_map=None,
load_bytes=False,
is_training=False,
download=False,
batch_size=None,
repeats=0,
**kwargs
):
""" Dataset factory method
In parenthesis after each arg are the type of dataset supported for each arg, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
* all - any of the above
Args:
name: dataset name, empty is okay for folder based datasets
root: root folder of dataset (all)
split: dataset split (all)
search_split: search for split specific child fold from root so one can specify
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
class_map: specify class -> index mapping via text file or dict (folder)
load_bytes: load data, return images as undecoded bytes (folder)
download: download dataset if not present and supported (TFDS, torch)
is_training: create dataset in train mode, this is different from the split.
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS)
batch_size: batch size hint for (TFDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS)
**kwargs: other args to pass to dataset
Returns:
Dataset object
"""
name = name.lower()
if name.startswith('tfds'):
if name.startswith('torch/'):
name = name.split('/', 2)[-1]
torch_kwargs = dict(root=root, download=download, **kwargs)
if name in _TORCH_BASIC_DS:
ds_class = _TORCH_BASIC_DS[name]
use_train = split in _TRAIN_SYNONYM
ds = ds_class(train=use_train, **torch_kwargs)
elif name == 'inaturalist' or name == 'inat':
assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
target_type = 'full'
split_split = split.split('/')
if len(split_split) > 1:
target_type = split_split[0].split('_')
if len(target_type) == 1:
target_type = target_type[0]
split = split_split[-1]
if split in _TRAIN_SYNONYM:
split = '2021_train'
elif split in _EVAL_SYNONYM:
split = '2021_valid'
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
elif name == 'places365':
if split in _TRAIN_SYNONYM:
split = 'train-standard'
elif split in _EVAL_SYNONYM:
split = 'val'
ds = Places365(split=split, **torch_kwargs)
elif name == 'imagenet':
if split in _EVAL_SYNONYM:
split = 'val'
ds = ImageNet(split=split, **torch_kwargs)
elif name == 'image_folder' or name == 'folder':
# in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageFolder(root, **kwargs)
else:
assert False, f"Unknown torchvision dataset {name}"
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
else:
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
kwargs.pop('repeats', 0) # FIXME currently only Iterable dataset support the repeat multiplier
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageDataset(root, parser=name, **kwargs)
ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
return ds

@ -1,16 +1,19 @@
import os
def load_class_map(filename, root=''):
class_map_path = filename
def load_class_map(map_or_filename, root=''):
if isinstance(map_or_filename, dict):
assert dict, 'class_map dict must be non-empty'
return map_or_filename
class_map_path = map_or_filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
class_map_path = os.path.join(root, class_map_path)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
if class_map_ext == '.txt':
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
assert False, 'Unsupported class map extension'
assert False, f'Unsupported class map file extension ({class_map_ext}).'
return class_to_idx

@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
# explicitly select other options shortly
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
parser = ParserTfds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
Hacked together by / Copyright 2020 Ross Wightman
"""
import os
import io
import math
import torch
import torch.distributed as dist
@ -17,6 +15,13 @@ try:
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU') # Hands off my GPU! (or pip install tensorflow-cpu)
import tensorflow_datasets as tfds
try:
tfds.even_splits('', 1, drop_remainder=False) # non-buggy even_splits has drop_remainder arg
has_buggy_even_splits = False
except TypeError:
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
"Please update or use tfds-nightly for better fine-grained split behaviour.")
has_buggy_even_splits = True
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
@ -25,7 +30,7 @@ from .parser import Parser
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue
PREFETCH_SIZE = 2048 # samples to prefetch
@ -57,32 +62,71 @@ class ParserTfds(Parser):
components.
"""
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
def __init__(
self,
root,
name,
split='train',
is_training=False,
batch_size=None,
download=False,
repeats=0,
seed=42,
prefetch_size=None,
shuffle_size=None,
max_threadpool_size=None
):
""" Tensorflow-datasets Wrapper
Args:
root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
name: tfds dataset name (eg `imagenet2012`)
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
prefetch_size: override default tf.data prefetch buffer size
shuffle_size: override default tf.data shuffle buffer size
max_threadpool_size: override default threadpool size for tf.data
"""
super().__init__()
self.root = root
self.split = split
self.shuffle = shuffle
self.is_training = is_training
if self.is_training:
assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
self.subsplit = None
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.prefetch_size = prefetch_size or PREFETCH_SIZE
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
# TFDS builder and split information
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
self.num_samples = self.builder.info.splits[split].num_examples
self.ds = None # initialized lazily on each dataloader worker process
# NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag
if download:
self.builder.download_and_prepare()
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
self.worker_info = None
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
self.global_num_workers = 1
self.worker_info = None
self.worker_seed = 0 # seed unique to each work instance
self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process
def _lazy_init(self):
""" Lazily initialize the dataset.
@ -97,78 +141,83 @@ class ParserTfds(Parser):
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
split = self.split
num_workers = 1
global_worker_id = 0
if worker_info is not None:
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
global_num_workers = self.dist_num_replicas * num_workers
worker_id = worker_info.id
self.global_num_workers = self.dist_num_replicas * num_workers
global_worker_id = self.dist_rank * num_workers + worker_info.id
# FIXME I need to spend more time figuring out the best way to distribute/split data across
# combo of distributed replicas + dataloader worker processes
"""
""" Data sharding
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
Possible split options include:
* InputContext for both distributed & worker processes (current)
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding.
"""
# split_size = self.num_samples // num_workers
# start = worker_id * split_size
# if worker_id == num_workers - 1:
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
if not self.is_training and '[' not in self.split:
# If not training, and split doesn't define a subsplit, manually split the dataset
# for more even samples / worker
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
self.dist_rank * num_workers + worker_id]
if self.subsplit is None:
should_subsplit = self.global_num_workers > 1 and (
self.split_info.num_shards < self.global_num_workers or not self.is_training)
if should_subsplit:
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
# read patterns for distributed training (overlap across shards) so better to use InputContext there
if has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
self.subsplit = subsplits[global_worker_id]
else:
subsplits = tfds.even_splits(self.split, self.global_num_workers)
self.subsplit = subsplits[global_worker_id]
input_context = None
if self.global_num_workers > 1 and self.subsplit is None:
# set input context to divide shards among distributed replicas
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_input_pipelines=self.global_num_workers,
input_pipeline_id=global_worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
else:
input_context = None
read_config = tfds.ReadConfig(
shuffle_seed=42,
shuffle_seed=self.common_seed,
shuffle_reshuffle_each_iteration=True,
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
# avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
if self.is_training:
ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds)
def __iter__(self):
if self.ds is None:
self._lazy_init()
# compute a rounded up sample count that is used to:
# Compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
# Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
sample_count = 0
for sample in self.ds:
img = Image.fromarray(sample['image'], mode='RGB')
@ -179,21 +228,17 @@ class ParserTfds(Parser):
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
# Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < sample_count < target_sample_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
# For single process case, it won't matter if workers return different batch sizes.
# FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
# approach is not optimal
yield img, sample['label'] # yield prev sample again
sample_count += 1
@property
def _num_workers(self):
return 1 if self.worker_info is None else self.worker_info.num_workers
@property
def _num_pipelines(self):
return self._num_workers * self.dist_num_replicas
# If using input_context or % based splits, sample count can vary significantly across workers and this
# approach should not be used (hence disabled if self.subsplit isn't set).
while sample_count < target_sample_count:
yield img, sample['label'] # yield prev sample again
sample_count += 1
def __len__(self):
# this is just an estimate and does not factor in extra samples added to pad batches based on
@ -201,7 +246,7 @@ class ParserTfds(Parser):
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to samples
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""

@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters
# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50"')
parser.add_argument('--pretrained', action='store_true', default=False,
@ -484,11 +490,16 @@ def main():
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset,
root=args.data_dir, split=args.train_split, is_training=True,
batch_size=args.batch_size, repeats=args.epoch_repeats)
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
# setup mixup / cutmix
collate_fn = None

@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@ -175,7 +177,7 @@ def validate(args):
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
load_bytes=args.tf_preprocessing, class_map=args.class_map)
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:

Loading…
Cancel
Save