Dataset work

* add support for several torchvision datasets via the `torch/` dataset prefix (usage sketch below)
* improvements to the TFDS wrapper: subsplit handling (fix #942) and shuffle seed handling
* add class-map support to the train script (fix #957)
more_datasets
Ross Wightman 3 years ago
parent ddc29da974
commit ba65dfe2c6
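The torchvision-backed path added in this commit is reached through the dataset factory. A minimal sketch, assuming the factory is importable as `timm.data.create_dataset` (the dataset choice and paths below are illustrative only):

```python
# Minimal sketch of the new torchvision support (dataset/paths are illustrative).
from timm.data import create_dataset

# the 'torch/' prefix routes to the torchvision datasets registered in this commit;
# download=True corresponds to the new --dataset-download flag in train/validate
ds_train = create_dataset('torch/cifar10', root='./data', split='train', download=True)
ds_eval = create_dataset('torch/cifar10', root='./data', split='validation')
print(len(ds_train), len(ds_eval))  # 50000 10000 for CIFAR-10
```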

@@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
self,
root,
parser=None,
class_map='',
class_map=None,
load_bytes=False,
transform=None,
target_transform=None,
):
if parser is None or isinstance(parser, str):
parser = create_parser(parser or '', root=root, class_map=class_map)
self.parser = parser
self.load_bytes = load_bytes
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __getitem__(self, index):
@@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.tensor(-1, dtype=torch.long)
target = -1
elif self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
@@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
split='train',
is_training=False,
batch_size=None,
class_map='',
load_bytes=False,
repeats=0,
download=False,
transform=None,
target_transform=None,
):
assert parser is not None
if isinstance(parser, str):
self.parser = create_parser(
parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
parser, root=root, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, download=download)
else:
self.parser = parser
self.transform = transform
self.target_transform = target_transform
self._consecutive_errors = 0
def __iter__(self):
for img, target in self.parser:
if self.transform is not None:
img = self.transform(img)
if target is None:
target = torch.tensor(-1, dtype=torch.long)
if self.target_transform is not None:
target = self.target_transform(target)
yield img, target
def __len__(self):

@@ -1,7 +1,26 @@
import os
from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
Places365, ImageNet, ImageFolder
try:
from torchvision.datasets import INaturalist
has_inaturalist = True
except ImportError:
has_inaturalist = False
from .dataset import IterableImageDataset, ImageDataset
_TORCH_BASIC_DS = dict(
cifar10=CIFAR10,
cifar100=CIFAR100,
mnist=MNIST,
qmnist=QMNIST,
kmnist=KMNIST,
fashion_mnist=FashionMNIST,
)
_TRAIN_SYNONYM = {'train', 'training'}
_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
def _search_split(root, split):
# look for sub-folder with name of split in root and use that if it exists
@@ -9,22 +28,107 @@ def _search_split(root, split):
try_root = os.path.join(root, split_name)
if os.path.exists(try_root):
return try_root
if split_name == 'validation':
try_root = os.path.join(root, 'val')
def _try(syn):
for s in syn:
try_root = os.path.join(root, s)
if os.path.exists(try_root):
return try_root
return root
if split_name in _TRAIN_SYNONYM:
root = _try(_TRAIN_SYNONYM)
elif split_name in _EVAL_SYNONYM:
root = _try(_EVAL_SYNONYM)
return root
def create_dataset(
name,
root,
split='validation',
search_split=True,
class_map=None,
load_bytes=False,
is_training=False,
download=False,
batch_size=None,
repeats=0,
**kwargs
):
""" Dataset factory method
def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
In parentheses after each arg is the type of dataset supported for that arg, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
* all - any of the above
Args:
name: dataset name, empty is okay for folder based datasets
root: root folder of dataset (all)
split: dataset split (all)
search_split: search for a split-specific child folder of root so one can specify
`imagenet/` instead of `imagenet/val`, etc. on cmd line / config. (folder, torch/folder)
class_map: specify class -> index mapping via text file or dict (folder)
load_bytes: load data, return images as undecoded bytes (folder)
download: download dataset if not present and supported (TFDS, torch)
is_training: create dataset in train mode; this is different from the split.
For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
batch_size: batch size hint (TFDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS)
**kwargs: other args to pass to dataset
Returns:
Dataset object
"""
name = name.lower()
if name.startswith('tfds'):
if name.startswith('torch/'):
name = name.split('/', 2)[-1]
torch_kwargs = dict(root=root, download=download, **kwargs)
if name in _TORCH_BASIC_DS:
ds_class = _TORCH_BASIC_DS[name]
use_train = split in _TRAIN_SYNONYM
ds = ds_class(train=use_train, **torch_kwargs)
elif name == 'inaturalist' or name == 'inat':
assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for iNaturalist'
target_type = 'full'
split_split = split.split('/')
if len(split_split) > 1:
target_type = split_split[0].split('_')
if len(target_type) == 1:
target_type = target_type[0]
split = split_split[-1]
if split in _TRAIN_SYNONYM:
split = '2021_train'
elif split in _EVAL_SYNONYM:
split = '2021_valid'
ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
elif name == 'places365':
if split in _TRAIN_SYNONYM:
split = 'train-standard'
elif split in _EVAL_SYNONYM:
split = 'val'
ds = Places365(split=split, **torch_kwargs)
elif name == 'imagenet':
if split in _EVAL_SYNONYM:
split = 'val'
ds = ImageNet(split=split, **torch_kwargs)
elif name == 'image_folder' or name == 'folder':
# in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageFolder(root, **kwargs)
else:
assert False, f"Unknown torchvision dataset {name}"
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
else:
# FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
kwargs.pop('repeats', 0)  # FIXME currently only Iterable datasets support the repeat multiplier
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
ds = ImageDataset(root, parser=name, **kwargs)
ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
return ds
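For the `tfds/` path, the factory now forwards `download`, `batch_size`, and `repeats` through to the iterable wrapper. A hedged sketch, assuming `tensorflow-datasets` is installed; the dataset name, root, and sizes below are illustrative:

```python
# Sketch of the TFDS path (dataset name, root and sizes are illustrative).
from timm.data import create_dataset

ds = create_dataset(
    'tfds/cifar10',           # 'tfds/' prefix -> IterableImageDataset + ParserTfds
    root='./tfds_data',       # TFDS data_dir
    split='train',
    is_training=True,         # enables file/sample shuffling in the TFDS pipeline
    download=True,            # new: triggers builder.download_and_prepare()
    batch_size=128,           # required hint when is_training=True
    repeats=2)                # wrap the stream twice per pass
img, target = next(iter(ds))  # samples come decoded from the parser; transforms are optional
```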

@@ -1,16 +1,19 @@
import os
def load_class_map(filename, root=''):
class_map_path = filename
def load_class_map(map_or_filename, root=''):
if isinstance(map_or_filename, dict):
assert map_or_filename, 'class_map dict must be non-empty'
return map_or_filename
class_map_path = map_or_filename
if not os.path.exists(class_map_path):
class_map_path = os.path.join(root, filename)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
class_map_ext = os.path.splitext(filename)[-1].lower()
class_map_path = os.path.join(root, class_map_path)
assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
if class_map_ext == '.txt':
with open(class_map_path) as f:
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
else:
assert False, 'Unsupported class map extension'
assert False, f'Unsupported class map file extension ({class_map_ext}).'
return class_to_idx
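`load_class_map` now accepts either a ready-made dict or a text file path. A small sketch, assuming the helper lives at `timm.data.parsers.class_map` (the class names and file name below are made up):

```python
# Sketch of the two class_map inputs now accepted (names/files are illustrative).
from timm.data.parsers.class_map import load_class_map

# 1) a dict mapping class name -> index is returned as-is
assert load_class_map({'cat': 0, 'dog': 1}) == {'cat': 0, 'dog': 1}

# 2) a .txt file with one class name per line; line order defines the index
with open('classes.txt', 'w') as f:
    f.write('cat\ndog\n')
assert load_class_map('classes.txt') == {'cat': 0, 'dog': 1}
```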

@@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
# explicitly select other options shortly
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
parser = ParserTfds(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@@ -57,23 +57,28 @@ class ParserTfds(Parser):
components.
"""
def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
def __init__(
self, root, name, split='train', is_training=False, batch_size=None,
download=False, repeats=0, seed=42):
super().__init__()
self.root = root
self.split = split
self.shuffle = shuffle
self.is_training = is_training
if self.is_training:
assert batch_size is not None,\
"Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
self.batch_size = batch_size
self.repeats = repeats
self.common_seed = seed # seed across all worker / dist nodes
self.worker_seed = 0 # seed specific to each worker instance
self.subsplit = None
self.builder = tfds.builder(name, data_dir=root)
# NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
# download_and_prepare() by default here as it's caused issues generating unwanted paths.
self.num_samples = self.builder.info.splits[split].num_examples
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
self.builder.download_and_prepare()
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
self.ds = None # initialized lazily on each dataloader worker process
self.worker_info = None
@@ -97,17 +102,18 @@ class ParserTfds(Parser):
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
split = self.split
num_workers = 1
global_num_workers = num_workers = 1
global_worker_id = 0
if worker_info is not None:
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
global_num_workers = self.dist_num_replicas * num_workers
worker_id = worker_info.id
global_worker_id = self.dist_rank * num_workers + worker_id
# FIXME I need to spend more time figuring out the best way to distribute/split data across
# combo of distributed replicas + dataloader worker processes
"""
# FIXME verify best sharding approach
""" Data sharding
InputContext will assign a subset of the underlying TFRecord files to each 'pipeline' if used.
My understanding is that when using splits, the underlying TFRecord files will shuffle (shuffle_files=True)
between the splits each iteration, but that understanding could be wrong.
@@ -116,44 +122,39 @@ class ParserTfds(Parser):
* InputContext for distributed and sub-splits for worker processes
* sub-splits for both
"""
# split_size = self.num_samples // num_workers
# start = worker_id * split_size
# if worker_id == num_workers - 1:
# split = split + '[{}:]'.format(start)
# else:
# split = split + '[{}:{}]'.format(start, start + split_size)
if not self.is_training and '[' not in self.split:
# If not training, and split doesn't define a subsplit, manually split the dataset
# for more even samples / worker
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
self.dist_rank * num_workers + worker_id]
if self.subsplit is None:
can_subsplit = '[' not in self.split # can't subsplit a subsplit
should_subsplit = global_num_workers > 1 and (
self.split_info.num_shards < global_num_workers or not self.is_training)
if can_subsplit and should_subsplit:
# manually split the dataset w/o sharding for more even samples / worker
self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
input_context = None
if global_num_workers > 1 and self.subsplit is None:
# set input context to divide shards among distributed replicas
input_context = tf.distribute.InputContext(
num_input_pipelines=self.dist_num_replicas * num_workers,
input_pipeline_id=self.dist_rank * num_workers + worker_id,
num_input_pipelines=global_num_workers,
input_pipeline_id=global_worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
else:
input_context = None
read_config = tfds.ReadConfig(
shuffle_seed=42,
shuffle_seed=self.common_seed,
shuffle_reshuffle_each_iteration=True,
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options()
options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
options.experimental_threading.max_intra_op_parallelism = 1
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
if self.is_training:
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds)
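The subsplit branch above carves the requested split into one contiguous index range per global worker when there are too few TFRecord shards, or when not training. A toy stand-in for the idea follows; this is not timm's `even_split_indices`, and the exact rounding may differ:

```python
# Toy re-implementation of the per-worker subsplit idea (not timm's helper).
def toy_even_split(split, num_workers, num_samples):
    subsplits = []
    for i in range(num_workers):
        start = round(i * num_samples / num_workers)
        end = round((i + 1) * num_samples / num_workers)
        subsplits.append(f'{split}[{start}:{end}]')
    return subsplits

# e.g. 4 global workers (2 replicas x 2 dataloader workers) over a 10000-sample split
print(toy_even_split('validation', 4, 10000))
# ['validation[0:2500]', 'validation[2500:5000]', 'validation[5000:7500]', 'validation[7500:10000]']
```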

@@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
# Dataset / Model parameters
# Dataset parameters
parser.add_argument('data_dir', metavar='DIR',
help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
# Model parameters
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50")')
parser.add_argument('--pretrained', action='store_true', default=False,
@@ -484,11 +490,16 @@ def main():
# create the train and eval datasets
dataset_train = create_dataset(
args.dataset,
root=args.data_dir, split=args.train_split, is_training=True,
batch_size=args.batch_size, repeats=args.epoch_repeats)
args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size,
repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
class_map=args.class_map,
download=args.dataset_download,
batch_size=args.batch_size)
# setup mixup / cutmix
collate_fn = None

@@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -175,7 +177,7 @@ def validate(args):
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
load_bytes=args.tf_preprocessing, class_map=args.class_map)
download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
if args.valid_labels:
with open(args.valid_labels, 'r') as f:
