diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index e719f3f6..d3603a23 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
             self,
             root,
             parser=None,
-            class_map='',
+            class_map=None,
             load_bytes=False,
             transform=None,
+            target_transform=None,
     ):
         if parser is None or isinstance(parser, str):
             parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
@@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.tensor(-1, dtype=torch.long)
+            target = -1
+        elif self.target_transform is not None:
+            target = self.target_transform(target)
         return img, target
 
     def __len__(self):
@@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             is_training=False,
             batch_size=None,
-            class_map='',
-            load_bytes=False,
             repeats=0,
+            download=False,
             transform=None,
+            target_transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
+                parser, root=root, split=split, is_training=is_training,
+                batch_size=batch_size, repeats=repeats, download=download)
         else:
             self.parser = parser
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __iter__(self):
         for img, target in self.parser:
             if self.transform is not None:
                 img = self.transform(img)
-            if target is None:
-                target = torch.tensor(-1, dtype=torch.long)
+            if self.target_transform is not None:
+                target = self.target_transform(target)
             yield img, target
 
     def __len__(self):
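
Taken together, the dataset.py changes above add an optional target_transform hook to both dataset classes and return a plain -1 int for missing labels. A minimal usage sketch, with a hypothetical root path and label-shift function that are not part of this patch:

from timm.data.dataset import ImageDataset

def shift_label(target):
    # hypothetical target_transform: offset the parser's class indices
    return target + 1000

ds = ImageDataset(
    root='/data/imagenet/val',      # hypothetical path
    transform=None,                 # normally a timm/torchvision transform pipeline
    target_transform=shift_label,
)
img, target = ds[0]                 # target is shift_label(idx), or -1 if the parser has no label
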
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index ccc99d5c..03b03cf5 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -1,7 +1,26 @@
 import os
 
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
+    Places365, ImageNet, ImageFolder
+try:
+    from torchvision.datasets import INaturalist
+    has_inaturalist = True
+except ImportError:
+    has_inaturalist = False
+
 from .dataset import IterableImageDataset, ImageDataset
 
+_TORCH_BASIC_DS = dict(
+    cifar10=CIFAR10,
+    cifar100=CIFAR100,
+    mnist=MNIST,
+    qmnist=QMNIST,
+    kmnist=KMNIST,
+    fashion_mnist=FashionMNIST,
+)
+_TRAIN_SYNONYM = {'train', 'training'}
+_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
+
 
 def _search_split(root, split):
     # look for sub-folder with name of split in root and use that if it exists
@@ -9,22 +28,107 @@ def _search_split(root, split):
     try_root = os.path.join(root, split_name)
     if os.path.exists(try_root):
         return try_root
-    if split_name == 'validation':
-        try_root = os.path.join(root, 'val')
-        if os.path.exists(try_root):
-            return try_root
+
+    def _try(syn):
+        for s in syn:
+            try_root = os.path.join(root, s)
+            if os.path.exists(try_root):
+                return try_root
+        return root
+    if split_name in _TRAIN_SYNONYM:
+        root = _try(_TRAIN_SYNONYM)
+    elif split_name in _EVAL_SYNONYM:
+        root = _try(_EVAL_SYNONYM)
     return root
 
 
-def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
+def create_dataset(
+        name,
+        root,
+        split='validation',
+        search_split=True,
+        class_map=None,
+        load_bytes=False,
+        is_training=False,
+        download=False,
+        batch_size=None,
+        repeats=0,
+        **kwargs
+):
+    """ Dataset factory method
+
+    In parentheses after each arg is the type of dataset supported for that arg, one of:
+      * folder - default, timm folder (or tar) based ImageDataset
+      * torch - torchvision based datasets
+      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
+      * all - any of the above
+
+    Args:
+        name: dataset name, empty is okay for folder based datasets
+        root: root folder of dataset (all)
+        split: dataset split (all)
+        search_split: search for split specific child folder from root so one can specify
+            `imagenet/` instead of `imagenet/val`, etc on cmd line / config. (folder, torch/folder)
+        class_map: specify class -> index mapping via text file or dict (folder)
+        load_bytes: load data, return images as undecoded bytes (folder)
+        download: download dataset if not present and supported (TFDS, torch)
+        is_training: create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
+        batch_size: batch size hint (TFDS)
+        repeats: dataset repeats per iteration i.e. epoch (TFDS)
+        **kwargs: other args to pass to dataset
+
+    Returns:
+        Dataset object
+    """
     name = name.lower()
-    if name.startswith('tfds'):
+    if name.startswith('torch/'):
+        name = name.split('/', 2)[-1]
+        torch_kwargs = dict(root=root, download=download, **kwargs)
+        if name in _TORCH_BASIC_DS:
+            ds_class = _TORCH_BASIC_DS[name]
+            use_train = split in _TRAIN_SYNONYM
+            ds = ds_class(train=use_train, **torch_kwargs)
+        elif name == 'inaturalist' or name == 'inat':
+            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for INaturalist'
+            target_type = 'full'
+            split_split = split.split('/')
+            if len(split_split) > 1:
+                target_type = split_split[0].split('_')
+                if len(target_type) == 1:
+                    target_type = target_type[0]
+                split = split_split[-1]
+            if split in _TRAIN_SYNONYM:
+                split = '2021_train'
+            elif split in _EVAL_SYNONYM:
+                split = '2021_valid'
+            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
+        elif name == 'places365':
+            if split in _TRAIN_SYNONYM:
+                split = 'train-standard'
+            elif split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = Places365(split=split, **torch_kwargs)
+        elif name == 'imagenet':
+            if split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = ImageNet(split=split, **torch_kwargs)
+        elif name == 'image_folder' or name == 'folder':
+            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
+            if search_split and os.path.isdir(root):
+                # look for split specific sub-folder in root
+                root = _search_split(root, split)
+            ds = ImageFolder(root, **kwargs)
+        else:
+            assert False, f"Unknown torchvision dataset {name}"
+    elif name.startswith('tfds/'):
         ds = IterableImageDataset(
-            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+            root, parser=name, split=split, is_training=is_training,
+            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
-        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
+            # look for split specific sub-folder in root
            root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, **kwargs)
+        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
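
With the factory extended as above, torchvision and TFDS datasets are selected by a name prefix. A sketch of the intended call patterns, using placeholder paths:

from timm.data import create_dataset

# torchvision CIFAR-10, downloaded into root if not already present
train_ds = create_dataset('torch/cifar10', root='/data/cifar10', split='train', download=True)

# torchvision iNaturalist 2021, target type encoded ahead of the split name
inat_ds = create_dataset('torch/inaturalist', root='/data/inat', split='kingdom/train', download=True)

# default folder/tar path: empty name, split sub-folder discovered via _search_split
val_ds = create_dataset('', root='/data/imagenet', split='validation')
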
diff --git a/timm/data/parsers/class_map.py b/timm/data/parsers/class_map.py
index 9ef4d1fa..6b6fe453 100644
--- a/timm/data/parsers/class_map.py
+++ b/timm/data/parsers/class_map.py
@@ -1,16 +1,19 @@
 import os
 
 
-def load_class_map(filename, root=''):
-    class_map_path = filename
+def load_class_map(map_or_filename, root=''):
+    if isinstance(map_or_filename, dict):
+        assert map_or_filename, 'class_map dict must be non-empty'
+        return map_or_filename
+    class_map_path = map_or_filename
     if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
+        class_map_path = os.path.join(root, class_map_path)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
     if class_map_ext == '.txt':
         with open(class_map_path) as f:
             class_to_idx = {v.strip(): k for k, v in enumerate(f)}
     else:
-        assert False, 'Unsupported class map extension'
+        assert False, f'Unsupported class map file extension ({class_map_ext}).'
     return class_to_idx
diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py
index 419ffe89..892090ad 100644
--- a/timm/data/parsers/parser_factory.py
+++ b/timm/data/parsers/parser_factory.py
@@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
     # explicitly select other options shortly
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
+        parser = ParserTfds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder
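
load_class_map now accepts a ready-made dict as well as a file path, so class_map can be passed either way from the factory. A small sketch, with a hypothetical file name:

from timm.data.parsers.class_map import load_class_map

# a non-empty dict is returned as-is
class_to_idx = load_class_map({'cat': 0, 'dog': 1})

# a .txt file is read as one class name per line, mapped to its line index
# class_to_idx = load_class_map('my_classes.txt', root='/data/my_dataset')
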
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 2ff90b09..2b0cd731 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -57,23 +57,28 @@ class ParserTfds(Parser):
         components.
 
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
+    def __init__(
+            self, root, name, split='train', is_training=False, batch_size=None,
+            download=False, repeats=0, seed=42):
         super().__init__()
         self.root = root
         self.split = split
-        self.shuffle = shuffle
         self.is_training = is_training
         if self.is_training:
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
+        self.common_seed = seed  # seed across all workers / dist nodes
+        self.worker_seed = 0  # seed specific to each worker instance
         self.subsplit = None
 
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
-        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
-        self.num_samples = self.builder.info.splits[split].num_examples
+        # NOTE: the tfds command line app can be used to download & prepare datasets if the download flag is not set
+        if download:
+            self.builder.download_and_prepare()
+        self.split_info = self.builder.info.splits[split]
+        self.num_samples = self.split_info.num_examples
         self.ds = None  # initialized lazily on each dataloader worker process
 
         self.worker_info = None
@@ -97,17 +102,18 @@ class ParserTfds(Parser):
         worker_info = torch.utils.data.get_worker_info()
 
         # setup input context to split dataset across distributed processes
-        split = self.split
-        num_workers = 1
+        global_num_workers = num_workers = 1
+        global_worker_id = 0
         if worker_info is not None:
             self.worker_info = worker_info
+            self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
             global_num_workers = self.dist_num_replicas * num_workers
             worker_id = worker_info.id
+            global_worker_id = self.dist_rank * num_workers + worker_id
 
-            # FIXME I need to spend more time figuring out the best way to distribute/split data across
-            # combo of distributed replicas + dataloader worker processes
-            """
+            # FIXME verify best sharding approach
+            """ Data sharding
             InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
             My understanding is that using split, the underling TFRecord files will shuffle (shuffle_files=True)
             between the splits each iteration, but that understanding could be wrong.
@@ -116,44 +122,39 @@ class ParserTfds(Parser):
             * InputContext for distributed and sub-splits for worker processes
             * sub-splits for both
             """
-            # split_size = self.num_samples // num_workers
-            # start = worker_id * split_size
-            # if worker_id == num_workers - 1:
-            #     split = split + '[{}:]'.format(start)
-            # else:
-            #     split = split + '[{}:{}]'.format(start, start + split_size)
-            if not self.is_training and '[' not in self.split:
-                # If not training, and split doesn't define a subsplit, manually split the dataset
-                # for more even samples / worker
-                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
-                    self.dist_rank * num_workers + worker_id]
-
-        if self.subsplit is None:
+        can_subsplit = '[' not in self.split  # can't subsplit a subsplit
+        should_subsplit = global_num_workers > 1 and (
+            self.split_info.num_shards < global_num_workers or not self.is_training)
+        if can_subsplit and should_subsplit:
+            # manually split the dataset w/o sharding for more even samples / worker
+            self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[global_worker_id]
+
+        input_context = None
+        if global_num_workers > 1 and self.subsplit is None:
+            # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_input_pipelines=global_num_workers,
+                input_pipeline_id=global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
-        else:
-            input_context = None
-
         read_config = tfds.ReadConfig(
-            shuffle_seed=42,
+            shuffle_seed=self.common_seed,
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
+            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
         # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
-        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        options.experimental_threading.max_intra_op_parallelism = 1
+        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+        getattr(options, thread_member).private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
+        getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
-        if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
+        if self.is_training:
+            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=self.worker_seed)
         ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
         self.ds = tfds.as_numpy(ds)
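
The sharding bookkeeping above reduces to simple arithmetic over replica and worker counts; a standalone sketch with made-up values, not timm code:

# e.g. 2 distributed replicas, 4 dataloader workers per replica, rank 1 / worker 3
dist_num_replicas, dist_rank = 2, 1
num_workers, worker_id = 4, 3
is_training = True
num_shards = 16  # hypothetical TFRecord shard count for the split

global_num_workers = dist_num_replicas * num_workers    # 8 input pipelines in total
global_worker_id = dist_rank * num_workers + worker_id  # this worker reads pipeline 7

# sub-split only when shards can't be spread evenly or when evaluating;
# here 8 > 1 and (16 < 8 or False) -> False, so shard via tf.distribute.InputContext instead
should_subsplit = global_num_workers > 1 and (num_shards < global_num_workers or not is_training)
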
diff --git a/train.py b/train.py
index 332dec0c..10d839be 100755
--- a/train.py
+++ b/train.py
@@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
-# Dataset / Model parameters
+# Dataset parameters
 parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
                     help='dataset train split (default: train)')
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
+
+# Model parameters
 parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet50"')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -484,11 +490,16 @@ def main():
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset,
-        root=args.data_dir, split=args.train_split, is_training=True,
-        batch_size=args.batch_size, repeats=args.epoch_repeats)
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size,
+        repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
diff --git a/validate.py b/validate.py
index 2e18841f..a99e5b5c 100755
--- a/validate.py
+++ b/validate.py
@@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -175,7 +177,7 @@ def validate(args):
 
     dataset = create_dataset(
         root=args.data, name=args.dataset, split=args.split,
-        load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f: