Add epoch-repeats arg to multiply the number of dataset passes per epoch. Currently supported for iterable datasets (i.e. the TFDS wrapper) only.

pull/450/head
Ross Wightman 3 years ago
parent de97be9146
commit 2db2d87ff7
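
In practice, the new plumbing lets a caller request multiple dataset passes per epoch directly through the dataset factory. A minimal sketch of the intended usage (the TFDS dataset name and root path below are placeholders, not part of this commit):

    from timm.data import create_dataset

    # fold two full passes over the underlying TFDS split into each training epoch;
    # non-iterable datasets ignore the repeats argument (see dataset_factory hunk below)
    dataset_train = create_dataset(
        'tfds/food101', root='/data/tfds', split='train', is_training=True,
        batch_size=256, repeats=2)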

@@ -5,7 +5,7 @@ from .constants import *
 _logger = logging.getLogger(__name__)
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
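
With the default flipped to verbose=False, callers that still want the resolved data config logged now opt in explicitly, as the validate.py hunk at the end of this commit does:

    data_config = resolve_data_config(vars(args), model=model, verbose=True)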

@@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset):
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform

@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable dataset support the repeat multiplier
         if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
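
Note the kwargs.pop() guard: on the ImageFolder/Tar path any repeats value is dropped rather than forwarded, so passing it for a non-TFDS dataset is a no-op instead of a TypeError inside ImageDataset. A hedged illustration (the empty dataset name and root path are placeholders):

    # repeats is popped before ImageDataset is constructed, so this behaves the same
    # as calling create_dataset without it
    ds = create_dataset('', root='/data/imagenet/train', split='train',
                        is_training=True, batch_size=256, repeats=2)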

@@ -52,7 +52,7 @@ class ParserTfds(Parser):
     components.
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +62,7 @@ class ParserTfds(Parser):
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
         self.batch_size = batch_size
+        self.repeats = repeats
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -126,7 +127,7 @@ class ParserTfds(Parser):
         # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers
         ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
         ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
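
The repeat() call makes the tf.data pipeline effectively infinite; iteration is then stopped manually once the per-worker sample target is reached. A simplified sketch of that wrap-around-and-break pattern (assumes ds and target_sample_count as built above; this is not the parser's actual decode path):

    # iterate the now-infinite pipeline and stop once the per-worker target is reached
    sample_count = 0
    for sample in ds.as_numpy_iterator():
        # ... decode / transform sample['image'] and sample['label'] here ...
        sample_count += 1
        if sample_count >= target_sample_count:
            break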
@@ -143,7 +144,7 @@ class ParserTfds(Parser):
         # This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
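
A worked example of the new target with hypothetical numbers: 50,000 samples, repeats=2, 4 dataloader workers x 2 distributed replicas (8 pipelines), batch_size=32 in training mode:

    import math
    num_samples, repeats, num_pipelines, batch_size = 50000, 2, 8, 32
    target = math.ceil(max(1, repeats) * num_samples / num_pipelines)   # 12500
    target = math.ceil(target / batch_size) * batch_size                # 12512

The same max(1, repeats) factor scales the __len__ estimate in the next hunk, so the DataLoader sees roughly repeats x num_samples per replica per epoch.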
@@ -176,7 +177,7 @@ class ParserTfds(Parser):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples

@@ -141,6 +141,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
+parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
+                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
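
Because the flag is a float multiplier defaulting to 0. (treated the same as 1, i.e. no extra repeats), a typical invocation just adds it to an existing TFDS training command, e.g. (data path and dataset name are placeholders):

    python train.py /data/tfds --dataset tfds/food101 --epoch-repeats 2 --epochs 100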
@@ -450,7 +452,9 @@ def main():
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.train_split, is_training=True,
+        batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
         args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

@@ -152,7 +152,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
+    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
     test_time_pool = False
     if not args.no_test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
