diff --git a/timm/data/config.py b/timm/data/config.py
index dad8eb13..38f5689a 100644
--- a/timm/data/config.py
+++ b/timm/data/config.py
@@ -5,7 +5,7 @@ from .constants import *
 _logger = logging.getLogger(__name__)
 
 
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
diff --git a/timm/data/dataset.py b/timm/data/dataset.py
index a7c5ebed..e719f3f6 100644
--- a/timm/data/dataset.py
+++ b/timm/data/dataset.py
@@ -73,12 +73,13 @@ class IterableImageDataset(data.IterableDataset):
             batch_size=None,
             class_map='',
             load_bytes=False,
+            repeats=0,
             transform=None,
     ):
         assert parser is not None
         if isinstance(parser, str):
             self.parser = create_parser(
-                parser, root=root, split=split, is_training=is_training, batch_size=batch_size)
+                parser, root=root, split=split, is_training=is_training, batch_size=batch_size, repeats=repeats)
         else:
             self.parser = parser
         self.transform = transform
diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py
index b2c9688f..ccc99d5c 100644
--- a/timm/data/dataset_factory.py
+++ b/timm/data/dataset_factory.py
@@ -23,6 +23,7 @@ def create_dataset(name, root, split='validation', search_split=True, is_trainin
             root, parser=name, split=split, is_training=is_training, batch_size=batch_size, **kwargs)
     else:
         # FIXME support more advanced split cfg for ImageFolder/Tar datasets in the future
+        kwargs.pop('repeats', 0)  # FIXME currently only Iterable datasets support the repeat multiplier
        if search_split and os.path.isdir(root):
             root = _search_split(root, split)
         ds = ImageDataset(root, parser=name, **kwargs)
diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py
index 15361cb5..0c2e10c0 100644
--- a/timm/data/parsers/parser_tfds.py
+++ b/timm/data/parsers/parser_tfds.py
@@ -52,7 +52,7 @@ class ParserTfds(Parser):
         components.
 
     """
-    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None):
+    def __init__(self, root, name, split='train', shuffle=False, is_training=False, batch_size=None, repeats=0):
         super().__init__()
         self.root = root
         self.split = split
@@ -62,6 +62,7 @@ class ParserTfds(Parser):
             assert batch_size is not None,\
                 "Must specify batch_size in training mode for reasonable behaviour w/ TFDS wrapper"
             self.batch_size = batch_size
+        self.repeats = repeats
 
         self.builder = tfds.builder(name, data_dir=root)
         # NOTE: please use tfds command line app to download & prepare datasets, I don't want to call
@@ -126,7 +127,7 @@ class ParserTfds(Parser):
             # avoid overloading threading w/ combo of TF ds threads + PyTorch workers
             ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers)
             ds.options().experimental_threading.max_intra_op_parallelism = 1
-        if self.is_training:
+        if self.is_training or self.repeats > 1:
             # to prevent excessive drop_last batch behaviour w/ IterableDatasets
             # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
             ds = ds.repeat()  # allow wrap around and break iteration manually
@@ -143,7 +144,7 @@ class ParserTfds(Parser):
         #    This adds extra samples and will slightly alter validation results.
         # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
         #    batches are produced (underlying tfds iter wraps around)
-        target_sample_count = math.ceil(self.num_samples / self._num_pipelines)
+        target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self._num_pipelines)
         if self.is_training:
             # round up to nearest batch_size per worker-replica
             target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size
@@ -176,7 +177,7 @@ class ParserTfds(Parser):
     def __len__(self):
         # this is just an estimate and does not factor in extra samples added to pad batches based on
         # complete worker & replica info (not available until init in dataloader).
-        return math.ceil(self.num_samples / self.dist_num_replicas)
+        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
 
     def _filename(self, index, basename=False, absolute=False):
         assert False, "Not supported"  # no random access to samples
diff --git a/train.py b/train.py
index 9db5175b..2fdf68d8 100755
--- a/train.py
+++ b/train.py
@@ -141,6 +141,8 @@ parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
                     help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 200)')
+parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
+                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
@@ -450,7 +452,9 @@ def main():
 
     # create the train and eval datasets
     dataset_train = create_dataset(
-        args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.train_split, is_training=True,
+        batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
         args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
 
diff --git a/validate.py b/validate.py
index a311112d..6df71aab 100755
--- a/validate.py
+++ b/validate.py
@@ -152,7 +152,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     _logger.info('Model %s created, param count: %d' % (args.model, param_count))
 
-    data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
+    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
     test_time_pool = False
     if not args.no_test_pool:
         model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
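
For reviewers, a minimal self-contained sketch of the sample accounting the parser_tfds.py hunks implement; the function name and the example numbers are illustrative, only the arithmetic mirrors the diff:

import math

def per_pipeline_target(num_samples, num_pipelines, batch_size, repeats=0, is_training=True):
    # repeats <= 1 leaves the original behaviour intact: one pass over the data per epoch
    count = math.ceil(max(1, repeats) * num_samples / num_pipelines)
    if is_training:
        # round up to a whole number of batches per worker-replica so the
        # wrapped-around tfds iterator only ever emits full batch_size batches
        count = math.ceil(count / batch_size) * batch_size
    return count

# 10,000 samples, 4 workers x 2 replicas = 8 pipelines, batch 32, --epoch-repeats 3
print(per_pipeline_target(10_000, 8, 32, repeats=3))  # 3750, rounded up to 3776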
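Assuming the usual timm.data exports, the multiplier can also be passed to create_dataset directly; the tfds/ dataset name below is a placeholder, and repeats is silently popped for non-iterable (ImageFolder-style) datasets, as the dataset_factory.py hunk shows:

from timm.data import create_dataset

# repeats is forwarded to ParserTfds via IterableImageDataset for tfds/ datasets
dataset_train = create_dataset(
    'tfds/oxford_iiit_pet',  # placeholder TFDS dataset name
    root='/data/tfds', split='train',
    is_training=True, batch_size=256, repeats=4)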
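And the equivalent invocation of the new flag from the command line (paths, model, and hyper-parameters are placeholders; the TFDS dataset is assumed to already be prepared with the tfds CLI, per the NOTE in parser_tfds.py):

python train.py /data/tfds --dataset tfds/oxford_iiit_pet --model resnet50 \
    --batch-size 256 --epochs 100 --epoch-repeats 4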