Merge pull request #964 from rwightman/more_datasets

Dataset additions
Ross Wightman authored 3 years ago, committed by GitHub
commit 65419f60cc

@@ -23,15 +23,17 @@ class ImageDataset(data.Dataset):
             self,
             root,
             parser=None,
-            class_map='',
+            class_map=None,
             load_bytes=False,
             transform=None,
+            target_transform=None,
     ):
         if parser is None or isinstance(parser, str):
             parser = create_parser(parser or '', root=root, class_map=class_map)
         self.parser = parser
         self.load_bytes = load_bytes
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __getitem__(self, index):
@@ -49,7 +51,9 @@ class ImageDataset(data.Dataset):
         if self.transform is not None:
             img = self.transform(img)
         if target is None:
-            target = torch.tensor(-1, dtype=torch.long)
+            target = -1
+        elif self.target_transform is not None:
+            target = self.target_transform(target)
         return img, target
 
     def __len__(self):
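With the change above, class_map can be a dict as well as a file path, and an optional target_transform is applied to labels when one is present. A minimal usage sketch; the folder layout and label mapping below are hypothetical and not part of this PR:

    # Hypothetical sketch: class_map as a dict plus a target_transform callable.
    from timm.data.dataset import ImageDataset

    remap = {'cat': 0, 'dog': 1}                # assumed sub-folder names -> indices
    ds = ImageDataset(
        '/data/my_pets/train',                  # placeholder folder root
        class_map=remap,                        # dict accepted via load_class_map()
        target_transform=lambda t: int(t),      # applied only when a label exists
    )
    img, target = ds[0]                         # target falls back to -1 when the parser has no label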
@@ -71,26 +75,28 @@ class IterableImageDataset(data.IterableDataset):
             split='train',
             is_training=False,
             batch_size=None,
-            class_map='',
-            load_bytes=False,
             repeats=0,
+            download=False,
             transform=None,
+            target_transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
+                parser, root=root, split=split, is_training=is_training,
+                batch_size=batch_size, repeats=repeats, download=download)
         else:
             self.parser = parser
         self.transform = transform
+        self.target_transform = target_transform
         self._consecutive_errors = 0
 
     def __iter__(self):
         for img, target in self.parser:
             if self.transform is not None:
                 img = self.transform(img)
-            if target is None:
-                target = torch.tensor(-1, dtype=torch.long)
+            if self.target_transform is not None:
+                target = self.target_transform(target)
             yield img, target
 
     def __len__(self):
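For the iterable path, download is now forwarded to the TFDS parser and target_transform post-processes each yielded label. A hedged sketch, assuming tensorflow-datasets is installed; the root and dataset name are placeholders:

    # Hypothetical sketch of IterableImageDataset over a TFDS-backed parser.
    from timm.data.dataset import IterableImageDataset

    ds = IterableImageDataset(
        root='/data/tfds',                      # placeholder TFDS data dir
        parser='tfds/cifar10',                  # 'tfds/' prefix selects ParserTfds
        split='train',
        is_training=True,                       # enables shuffling inside the parser
        batch_size=256,                         # required in training mode
        download=True,                          # builds the dataset if not prepared yet
        target_transform=lambda t: int(t),
    )
    img, target = next(iter(ds))                # samples stream from tf.data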

@@ -1,7 +1,26 @@
 import os
 
+from torchvision.datasets import CIFAR100, CIFAR10, MNIST, QMNIST, KMNIST, FashionMNIST,\
+    Places365, ImageNet, ImageFolder
+try:
+    from torchvision.datasets import INaturalist
+    has_inaturalist = True
+except ImportError:
+    has_inaturalist = False
+
 from .dataset import IterableImageDataset, ImageDataset
 
+_TORCH_BASIC_DS = dict(
+    cifar10=CIFAR10,
+    cifar100=CIFAR100,
+    mnist=MNIST,
+    qmist=QMNIST,
+    kmnist=KMNIST,
+    fashion_mnist=FashionMNIST,
+)
+_TRAIN_SYNONYM = {'train', 'training'}
+_EVAL_SYNONYM = {'val', 'valid', 'validation', 'eval', 'evaluation'}
+
 
 def _search_split(root, split):
     # look for sub-folder with name of split in root and use that if it exists
@@ -9,22 +28,107 @@ def _search_split(root, split):
     try_root = os.path.join(root, split_name)
     if os.path.exists(try_root):
         return try_root
-    if split_name == 'validation':
-        try_root = os.path.join(root, 'val')
-        if os.path.exists(try_root):
-            return try_root
+
+    def _try(syn):
+        for s in syn:
+            try_root = os.path.join(root, s)
+            if os.path.exists(try_root):
+                return try_root
+        return root
+    if split_name in _TRAIN_SYNONYM:
+        root = _try(_TRAIN_SYNONYM)
+    elif split_name in _EVAL_SYNONYM:
+        root = _try(_EVAL_SYNONYM)
     return root
 
 
-def create_dataset(name, root, split='validation', search_split=True, is_training=False, batch_size=None, **kwargs):
+def create_dataset(
+        name,
+        root,
+        split='validation',
+        search_split=True,
+        class_map=None,
+        load_bytes=False,
+        is_training=False,
+        download=False,
+        batch_size=None,
+        repeats=0,
+        **kwargs
+):
+    """ Dataset factory method
+
+    In parenthesis after each arg are the type of dataset supported for each arg, one of:
+      * folder - default, timm folder (or tar) based ImageDataset
+      * torch - torchvision based datasets
+      * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset
+      * all - any of the above
+
+    Args:
+        name: dataset name, empty is okay for folder based datasets
+        root: root folder of dataset (all)
+        split: dataset split (all)
+        search_split: search for split specific child folder from root so one can specify
+            `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
+        class_map: specify class -> index mapping via text file or dict (folder)
+        load_bytes: load data, return images as undecoded bytes (folder)
+        download: download dataset if not present and supported (TFDS, torch)
+        is_training: create dataset in train mode, this is different from the split.
+            For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS)
+        batch_size: batch size hint (TFDS)
+        repeats: dataset repeats per iteration i.e. epoch (TFDS)
+        **kwargs: other args to pass to dataset
+
+    Returns:
+        Dataset object
+    """
     name = name.lower()
-    if name.startswith('tfds'):
+    if name.startswith('torch/'):
+        name = name.split('/', 2)[-1]
+        torch_kwargs = dict(root=root, download=download, **kwargs)
+        if name in _TORCH_BASIC_DS:
+            ds_class = _TORCH_BASIC_DS[name]
+            use_train = split in _TRAIN_SYNONYM
+            ds = ds_class(train=use_train, **torch_kwargs)
+        elif name == 'inaturalist' or name == 'inat':
+            assert has_inaturalist, 'Please update to PyTorch 1.10, torchvision 0.11+ for Inaturalist'
+            target_type = 'full'
+            split_split = split.split('/')
+            if len(split_split) > 1:
+                target_type = split_split[0].split('_')
+                if len(target_type) == 1:
+                    target_type = target_type[0]
+                split = split_split[-1]
+            if split in _TRAIN_SYNONYM:
+                split = '2021_train'
+            elif split in _EVAL_SYNONYM:
+                split = '2021_valid'
+            ds = INaturalist(version=split, target_type=target_type, **torch_kwargs)
+        elif name == 'places365':
+            if split in _TRAIN_SYNONYM:
+                split = 'train-standard'
+            elif split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = Places365(split=split, **torch_kwargs)
+        elif name == 'imagenet':
+            if split in _EVAL_SYNONYM:
+                split = 'val'
+            ds = ImageNet(split=split, **torch_kwargs)
+        elif name == 'image_folder' or name == 'folder':
+            # in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
+            if search_split and os.path.isdir(root):
+                # look for split specific sub-folder in root
+                root = _search_split(root, split)
+            ds = ImageFolder(root, **kwargs)
+        else:
+            assert False, f"Unknown torchvision dataset {name}"
+    elif name.startswith('tfds/'):
         ds = IterableImageDataset(
-            root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
+            root, parser=name, split=split, is_training=is_training,
+            download=download, batch_size=batch_size, repeats=repeats, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
-        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
+            # look for split specific sub-folder in root
             root = _search_split(root, split)
-        ds = ImageDataset(root, parser=name, **kwargs)
+        ds = ImageDataset(root, parser=name, class_map=class_map, load_bytes=load_bytes, **kwargs)
     return ds
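The factory now dispatches on a backend prefix in the dataset name: no prefix keeps the timm folder/tar ImageDataset, 'torch/' selects torchvision datasets, and 'tfds/' selects the IterableImageDataset wrapper. A hedged usage sketch; roots, dataset names and the class map file below are placeholders:

    from timm.data import create_dataset

    # torchvision CIFAR-10, fetched on first use via download=True
    ds_torch = create_dataset('torch/cifar10', root='/data/cifar10', split='train', download=True)

    # TFDS-backed iterable dataset; batch_size is required when is_training=True
    ds_tfds = create_dataset(
        'tfds/oxford_flowers102', root='/data/tfds', split='train',
        is_training=True, batch_size=256, download=True)

    # default folder dataset; split synonyms ('val', 'valid', 'validation', ...) let
    # _search_split locate the matching sub-folder under root
    ds_folder = create_dataset('', root='/data/my_images', split='validation', class_map='class_map.txt')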

@@ -1,16 +1,19 @@
 import os
 
 
-def load_class_map(filename, root=''):
-    class_map_path = filename
+def load_class_map(map_or_filename, root=''):
+    if isinstance(map_or_filename, dict):
+        assert dict, 'class_map dict must be non-empty'
+        return map_or_filename
+    class_map_path = map_or_filename
     if not os.path.exists(class_map_path):
-        class_map_path = os.path.join(root, filename)
-        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % filename
-    class_map_ext = os.path.splitext(filename)[-1].lower()
+        class_map_path = os.path.join(root, class_map_path)
+        assert os.path.exists(class_map_path), 'Cannot locate specified class map file (%s)' % map_or_filename
+    class_map_ext = os.path.splitext(map_or_filename)[-1].lower()
     if class_map_ext == '.txt':
         with open(class_map_path) as f:
             class_to_idx = {v.strip(): k for k, v in enumerate(f)}
     else:
-        assert False, 'Unsupported class map extension'
+        assert False, f'Unsupported class map file extension ({class_map_ext}).'
     return class_to_idx
 
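load_class_map now accepts either a prepared dict or a .txt file path. A small sketch; the file name, its contents and the root path are hypothetical:

    from timm.data.parsers.class_map import load_class_map

    # 1) a ready-made mapping passes straight through
    class_to_idx = load_class_map({'cat': 0, 'dog': 1})

    # 2) a .txt file with one class name per line; line order defines the index,
    #    e.g. a file containing "cat" then "dog" yields {'cat': 0, 'dog': 1}
    class_to_idx = load_class_map('class_map.txt', root='/data/my_images')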

@@ -17,7 +17,7 @@ def create_parser(name, root, split='train', **kwargs):
     # explicitly select other options shortly
     if prefix == 'tfds':
         from .parser_tfds import ParserTfds  # defer tensorflow import
-        parser = ParserTfds(root, name, split=split, shuffle=kwargs.pop('shuffle', False), **kwargs)
+        parser = ParserTfds(root, name, split=split, **kwargs)
     else:
         assert os.path.exists(root)
         # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@@ -6,8 +6,6 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import os
-import io
 import math
 import torch
 import torch.distributed as dist
@@ -17,6 +15,13 @@ try:
     import tensorflow as tf
     tf.config.set_visible_devices([], 'GPU')  # Hands off my GPU! (or pip install tensorflow-cpu)
     import tensorflow_datasets as tfds
+    try:
+        tfds.even_splits('', 1, drop_remainder=False)  # non-buggy even_splits has drop_remainder arg
+        has_buggy_even_splits = False
+    except TypeError:
+        print("Warning: This version of tfds doesn't have the latest even_splits impl. "
+              "Please update or use tfds-nightly for better fine-grained split behaviour.")
+        has_buggy_even_splits = True
 except ImportError as e:
     print(e)
     print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
@@ -25,7 +30,7 @@ from .parser import Parser
 MAX_TP_SIZE = 8  # maximum TF threadpool size, only doing jpeg decodes and queuing activities
-SHUFFLE_SIZE = 20480  # samples to shuffle in DS queue
+SHUFFLE_SIZE = 16384  # samples to shuffle in DS queue
 PREFETCH_SIZE = 2048  # samples to prefetch
@@ -57,32 +62,71 @@ class ParserTfds(Parser):
     components.
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
+    def __init__(
+            self,
+            root,
+            name,
+            split='train',
+            is_training=False,
+            batch_size=None,
+            download=False,
+            repeats=0,
+            seed=42,
+            prefetch_size=None,
+            shuffle_size=None,
+            max_threadpool_size=None
+    ):
+        """ Tensorflow-datasets Wrapper
+
+        Args:
+            root: root data dir (ie your TFDS_DATA_DIR. not dataset specific sub-dir)
+            name: tfds dataset name (eg `imagenet2012`)
+            split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
+            is_training: training mode, shuffle enabled, dataset len rounded by batch_size
+            batch_size: batch_size to use to ensure total samples % batch_size == 0 in training across all distributed nodes
+            download: download and build TFDS dataset if set, otherwise must use tfds CLI
+            repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
+            seed: common seed for shard shuffle across all distributed/worker instances
+            prefetch_size: override default tf.data prefetch buffer size
+            shuffle_size: override default tf.data shuffle buffer size
+            max_threadpool_size: override default threadpool size for tf.data
+        """
         super().__init__()
         self.root = root
         self.split = split
-        self.shuffle = shuffle
         self.is_training = is_training
         if self.is_training:
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
         self.repeats = repeats
-        self.subsplit = None
+        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
+        self.prefetch_size = prefetch_size or PREFETCH_SIZE
+        self.shuffle_size = shuffle_size or SHUFFLE_SIZE
+        self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
 
+        # TFDS builder and split information
         self.builder = tfds.builder(name, data_dir=root)
-        # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
-        # download_and_prepare() by default here as it's caused issues generating unwanted paths.
-        self.num_samples = self.builder.info.splits[split].num_examples
-        self.ds = None  # initialized lazily on each dataloader worker process
+        # NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
+        if download:
+            self.builder.download_and_prepare()
+        self.split_info = self.builder.info.splits[split]
+        self.num_samples = self.split_info.num_examples
 
-        self.worker_info = None
+        # Distributed world state
         self.dist_rank = 0
         self.dist_num_replicas = 1
         if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
             self.dist_rank = dist.get_rank()
             self.dist_num_replicas = dist.get_world_size()
+
+        # Attributes that are updated in _lazy_init, including the tf.data pipeline itself
+        self.global_num_workers = 1
+        self.worker_info = None
+        self.worker_seed = 0  # seed unique to each worker instance
+        self.subsplit = None  # set when data is distributed across workers using sub-splits
+        self.ds = None  # initialized lazily on each dataloader worker process
 
     def _lazy_init(self):
         """ Lazily initialize the dataset.
@@ -97,78 +141,83 @@ class ParserTfds(Parser):
         worker_info = torch.utils.data.get_worker_info()
 
         # setup input context to split dataset across distributed processes
-        split = self.split
         num_workers = 1
+        global_worker_id = 0
         if worker_info is not None:
             self.worker_info = worker_info
+            self.worker_seed = worker_info.seed
             num_workers = worker_info.num_workers
-            global_num_workers = self.dist_num_replicas * num_workers
-            worker_id = worker_info.id
-
-            # FIXME I need to spend more time figuring out the best way to distribute/split data across
-            # combo of distributed replicas + dataloader worker processes
-            """
+            self.global_num_workers = self.dist_num_replicas * num_workers
+            global_worker_id = self.dist_rank * num_workers + worker_info.id
+
+            """ Data sharding
             InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
             My understanding is that using split, the underlying TFRecord files will shuffle (shuffle_files=True)
             between the splits each iteration, but that understanding could be wrong.
-            Possible split options include:
-            * InputContext for both distributed & worker processes (current)
-            * InputContext for distributed and sub-splits for worker processes
-            * sub-splits for both
+
+            I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
+            the data across workers. For training, InputContext is used to assign shards to nodes unless num_shards
+            in dataset < total number of workers. Otherwise the sub-split API is used for datasets without enough shards or
+            for validation where we can't drop samples and need to minimize uneven splits to avoid padding.
             """
-            # split_size = self.num_samples // num_workers
-            # start = worker_id * split_size
-            # if worker_id == num_workers - 1:
-            #     split = split + '[{}:]'.format(start)
-            # else:
-            #     split = split + '[{}:{}]'.format(start, start + split_size)
-            if not self.is_training and '[' not in self.split:
-                # If not training, and split doesn't define a subsplit, manually split the dataset
-                # for more even samples / worker
-                self.subsplit = even_split_indices(self.split, global_num_workers, self.num_samples)[
-                    self.dist_rank * num_workers + worker_id]
-
-        if self.subsplit is None:
+            should_subsplit = self.global_num_workers > 1 and (
+                self.split_info.num_shards < self.global_num_workers or not self.is_training)
+            if should_subsplit:
+                # split the dataset w/o using sharding for more even samples / worker, can result in less optimal
+                # read patterns for distributed training (overlap across shards) so better to use InputContext there
+                if has_buggy_even_splits:
+                    # my even_split workaround doesn't work on subsplits, upgrade tfds!
+                    if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
+                        subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
+                        self.subsplit = subsplits[global_worker_id]
+                else:
+                    subsplits = tfds.even_splits(self.split, self.global_num_workers)
+                    self.subsplit = subsplits[global_worker_id]
+
+        input_context = None
+        if self.global_num_workers > 1 and self.subsplit is None:
+            # set input context to divide shards among distributed replicas
             input_context = tf.distribute.InputContext(
-                num_input_pipelines=self.dist_num_replicas * num_workers,
-                input_pipeline_id=self.dist_rank * num_workers + worker_id,
+                num_input_pipelines=self.global_num_workers,
+                input_pipeline_id=global_worker_id,
                 num_replicas_in_sync=self.dist_num_replicas  # FIXME does this arg have any impact?
             )
-        else:
-            input_context = None
 
         read_config = tfds.ReadConfig(
-            shuffle_seed=42,
+            shuffle_seed=self.common_seed,
             shuffle_reshuffle_each_iteration=True,
             input_context=input_context)
         ds = self.builder.as_dataset(
-            split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config)
-        # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
+            split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
+        # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
         options = tf.data.Options()
-        options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
-        options.experimental_threading.max_intra_op_parallelism = 1
+        thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
+        getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
+        getattr(options, thread_member).max_intra_op_parallelism = 1
         ds = ds.with_options(options)
         if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
-        if self.shuffle:
-            ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
-        ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
+        if self.is_training:
+            ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
+        ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
         self.ds = tfds.as_numpy(ds)
 
     def __iter__(self):
         if self.ds is None:
             self._lazy_init()
-        # compute a rounded up sample count that is used to:
+
+        # Compute a rounded up sample count that is used to:
         # 1. make batches even cross workers & replicas in distributed validation.
         #    This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
+
+        # Iterate until exhausted or sample count hits target when training (ds.repeat enabled)
         sample_count = 0
         for sample in self.ds:
             img = Image.fromarray(sample['image'], mode='RGB')
@@ -179,21 +228,17 @@ class ParserTfds(Parser):
                 # this results in extra samples per epoch but seems more desirable than dropping
                 # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
                 break
-        if not self.is_training and self.dist_num_replicas and 0 < sample_count < target_sample_count:
+
+        # Pad across distributed nodes (make counts equal by adding samples)
+        if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
+                0 < sample_count < target_sample_count:
             # Validation batch padding only done for distributed training where results are reduced across nodes.
             # For single process case, it won't matter if workers return different batch sizes.
-            # FIXME if using input_context or % based subsplits, sample count can vary by more than +/- 1 and this
-            # approach is not optimal
-            yield img, sample['label']  # yield prev sample again
-            sample_count += 1
-
-    @property
-    def _num_workers(self):
-        return 1 if self.worker_info is None else self.worker_info.num_workers
-
-    @property
-    def _num_pipelines(self):
-        return self._num_workers * self.dist_num_replicas
+            # If using input_context or % based splits, sample count can vary significantly across workers and this
+            # approach should not be used (hence disabled if self.subsplit isn't set).
+            while sample_count < target_sample_count:
+                yield img, sample['label']  # yield prev sample again
+                sample_count += 1
 
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
@@ -201,7 +246,7 @@ class ParserTfds(Parser):
         return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
 
     def filenames(self, basename=False, absolute=False):
         """ Return all filenames in dataset, overrides base"""

@@ -70,7 +70,7 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
 
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
 
-# Dataset / Model parameters
+# Dataset parameters
 parser.add_argument('data_dir', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--dataset', '-d', metavar='NAME', default='',
@@ -79,6 +79,12 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
                     help='dataset train split (default: train)')
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
+parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
+                    help='path to class to idx mapping file (default: "")')
+
+# Model parameters
 parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet50"')
 parser.add_argument('--pretrained', action='store_true', default=False,
@@ -484,11 +490,16 @@ def main():
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset,
-        root=args.data_dir, split=args.train_split, is_training=True,
-        batch_size=args.batch_size, repeats=args.epoch_repeats)
+        args.dataset, root=args.data_dir, split=args.train_split, is_training=True,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size,
+        repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
+        args.dataset, root=args.data_dir, split=args.val_split, is_training=False,
+        class_map=args.class_map,
+        download=args.dataset_download,
+        batch_size=args.batch_size)
 
     # setup mixup / cutmix
     collate_fn = None
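With these flags in place, a torchvision or TFDS dataset can be selected and fetched straight from the training script; an illustrative (not prescriptive) invocation with placeholder paths:

    python train.py /data/cifar10 --dataset torch/cifar10 --dataset-download --model resnet50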

@@ -48,6 +48,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                     help='dataset type (default: ImageFolder/ImageTar if empty)')
 parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
+parser.add_argument('--dataset-download', action='store_true', default=False,
+                    help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
@@ -175,7 +177,7 @@ def validate(args):
 
     dataset = create_dataset(
         root=args.data, name=args.dataset, split=args.split,
-        load_bytes=args.tf_preprocessing, class_map=args.class_map)
+        download=args.dataset_download, load_bytes=args.tf_preprocessing, class_map=args.class_map)
 
     if args.valid_labels:
         with open(args.valid_labels, 'r') as f:
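The same flag applies at evaluation time, for example (illustrative, placeholder path):

    python validate.py /data/cifar10 --dataset torch/cifar10 --split validation --dataset-download --model resnet50 --pretrained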
