diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 0599c78a..6a71db41 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -89,6 +89,7 @@ class IterableImageDataset(data.IterableDataset): split='train', is_training=False, batch_size=None, + seed=42, repeats=0, download=False, transform=None, @@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset): split=split, is_training=is_training, batch_size=batch_size, + seed=seed, repeats=repeats, download=download, ) @@ -125,6 +127,11 @@ class IterableImageDataset(data.IterableDataset): else: return 0 + def set_epoch(self, count): + # TFDS and WDS need external epoch count for deterministic cross process shuffle + if hasattr(self.parser, 'set_epoch'): + self.parser.set_epoch(count) + def filename(self, index, basename=False, absolute=False): assert False, 'Filename lookup by index not supported, use filenames().' diff --git a/timm/data/dataset_factory.py b/timm/data/dataset_factory.py index c2be63ad..2c2bb0bf 100644 --- a/timm/data/dataset_factory.py +++ b/timm/data/dataset_factory.py @@ -60,6 +60,7 @@ def create_dataset( is_training=False, download=False, batch_size=None, + seed=42, repeats=0, **kwargs ): @@ -68,7 +69,9 @@ def create_dataset( In parenthesis after each arg are the type of dataset supported for each arg, one of: * folder - default, timm folder (or tar) based ImageDataset * torch - torchvision based datasets + * HFDS - Hugging Face Datasets * TFDS - Tensorflow-datasets wrapper in IterableDataset interface via IterableImageDataset + * WDS - Webdataset * all - any of the above Args: @@ -79,11 +82,12 @@ def create_dataset( `imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder) class_map: specify class -> index mapping via text file or dict (folder) load_bytes: load data, return images as undecoded bytes (folder) - download: download dataset if not present and supported (TFDS, torch) + download: download dataset if not present and supported (HFDS, TFDS, torch) is_training: create dataset in train mode, this is different from the split. - For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS) - batch_size: batch size hint for (TFDS) - repeats: dataset repeats per iteration i.e. epoch (TFDS) + For Iterable / TFDS it enables shuffle, ignored for other datasets. (TFDS, WDS) + batch_size: batch size hint for (TFDS, WDS) + seed: seed for iterable datasets (TFDS, WDS) + repeats: dataset repeats per iteration i.e. 
epoch (TFDS, WDS) **kwargs: other args to pass to dataset Returns: @@ -130,14 +134,33 @@ def create_dataset( ds = ImageFolder(root, **kwargs) else: assert False, f"Unknown torchvision dataset {name}" - elif name.startswith('tfds/'): - ds = IterableImageDataset( - root, parser=name, split=split, is_training=is_training, - download=download, batch_size=batch_size, repeats=repeats, **kwargs) elif name.startswith('hfds/'): # NOTE right now, HF datasets default arrow format is a random-access Dataset, # There will be a IterableDataset variant too, TBD ds = ImageDataset(root, parser=name, split=split, **kwargs) + elif name.startswith('tfds/'): + ds = IterableImageDataset( + root, + parser=name, + split=split, + is_training=is_training, + download=download, + batch_size=batch_size, + repeats=repeats, + seed=seed, + **kwargs + ) + elif name.startswith('wds/'): + ds = IterableImageDataset( + root, + parser=name, + split=split, + is_training=is_training, + batch_size=batch_size, + repeats=repeats, + seed=seed, + **kwargs + ) else: # FIXME support more advance split cfg for ImageFolder/Tar datasets in the future if search_split and os.path.isdir(root): diff --git a/timm/data/loader.py b/timm/data/loader.py index a77e0a4c..35ccd503 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -5,6 +5,7 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2019, Ross Wightman """ +import logging import random from contextlib import suppress from functools import partial @@ -22,6 +23,9 @@ from .random_erasing import RandomErasing from .mixup import FastCollateMixup +_logger = logging.getLogger(__name__) + + def fast_collate(batch): """ A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)""" assert isinstance(batch[0], tuple) @@ -57,11 +61,13 @@ def fast_collate(batch): assert False -def expand_to_chs(x, n): +def adapt_to_chs(x, n): if not isinstance(x, (tuple, list)): x = tuple(repeat(x, n)) - elif len(x) == 1: - x = x * n + elif len(x) != n: + x_mean = np.mean(x).item() + x = (x_mean,) * n + _logger.warning(f'Pretrained mean/std different shape than model, using avg value {x}.') else: assert len(x) == n, 'normalization stats must match image channels' return x @@ -83,8 +89,8 @@ class PrefetchLoader: re_count=1, re_num_splits=0): - mean = expand_to_chs(mean, channels) - std = expand_to_chs(std, channels) + mean = adapt_to_chs(mean, channels) + std = adapt_to_chs(std, channels) normalization_shape = (1, channels, 1, 1) self.loader = loader diff --git a/timm/data/parsers/parser_factory.py b/timm/data/parsers/parser_factory.py index a204bf6a..f5133433 100644 --- a/timm/data/parsers/parser_factory.py +++ b/timm/data/parsers/parser_factory.py @@ -14,12 +14,16 @@ def create_parser(name, root, split='train', **kwargs): # FIXME improve the selection right now just tfds prefix or fallback path, will need options to # explicitly select other options shortly - if prefix == 'tfds': - from .parser_tfds import ParserTfds # defer tensorflow import - parser = ParserTfds(root, name, split=split, **kwargs) - elif prefix == 'hfds': + if prefix == 'hfds': from .parser_hfds import ParserHfds # defer tensorflow import parser = ParserHfds(root, name, split=split, **kwargs) + elif prefix == 'tfds': + from .parser_tfds import ParserTfds # defer tensorflow import + parser = ParserTfds(root, name, split=split, **kwargs) + elif prefix == 'wds': + from .parser_wds import ParserWds + kwargs.pop('download', False) + parser = 
ParserWds(root, name, split=split, **kwargs) else: assert os.path.exists(root) # default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index c0128a5b..dd16b87a 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -7,6 +7,8 @@ https://www.tensorflow.org/datasets/catalog/overview#image_classification Hacked together by / Copyright 2020 Ross Wightman """ import math +import os + import torch import torch.distributed as dist from PIL import Image @@ -30,12 +32,14 @@ except ImportError as e: print(e) print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.") exit(1) + from .parser import Parser +from .shared_count import SharedCount -MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue -PREFETCH_SIZE = 2048 # examples to prefetch +MAX_TP_SIZE = int(os.environ.get('TFDS_TP_SIZE', 8)) # maximum TF threadpool size, for jpeg decodes and queuing activities +SHUFFLE_SIZE = int(os.environ.get('TFDS_SHUFFLE_SIZE', 8192)) # examples to shuffle in DS queue +PREFETCH_SIZE = int(os.environ.get('TFDS_PREFETCH_SIZE', 2048)) # examples to prefetch def even_split_indices(split, n, num_examples): @@ -154,6 +158,14 @@ class ParserTfds(Parser): self.worker_seed = 0 # seed unique to each worker instance self.subsplit = None # set when data is distributed across workers using sub-splits self.ds = None # initialized lazily on each dataloader worker process + self.init_count = 0 # number of ds TF data pipeline initializations + self.epoch_count = SharedCount() + # FIXME need to determine if reinit_each_iter is necessary. I don't completely trust the behaviour + # of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs + self.reinit_each_iter = self.is_training + + def set_epoch(self, count): + self.epoch_count.value = count def _lazy_init(self): """ Lazily initialize the dataset. @@ -211,11 +223,15 @@ class ParserTfds(Parser): num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact? 
) read_config = tfds.ReadConfig( - shuffle_seed=self.common_seed, + shuffle_seed=self.common_seed + self.epoch_count.value, shuffle_reshuffle_each_iteration=True, - input_context=input_context) + input_context=input_context, + ) ds = self.builder.as_dataset( - split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config) + split=self.subsplit or self.split, + shuffle_files=self.is_training, + read_config=read_config, + ) # avoid overloading threading w/ combo of TF ds threads + PyTorch workers options = tf.data.Options() thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading' @@ -230,9 +246,10 @@ class ParserTfds(Parser): ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) + self.init_count += 1 def __iter__(self): - if self.ds is None: + if self.ds is None or self.reinit_each_iter: self._lazy_init() # Compute a rounded up sample count that is used to: diff --git a/timm/data/parsers/parser_wds.py b/timm/data/parsers/parser_wds.py new file mode 100644 index 00000000..92429009 --- /dev/null +++ b/timm/data/parsers/parser_wds.py @@ -0,0 +1,448 @@ +""" Dataset parser interface for webdataset + +Hacked together by / Copyright 2022 Ross Wightman +""" +import io +import json +import logging +import math +import os +import random +import sys +from dataclasses import dataclass +from functools import partial +from itertools import islice +from typing import Dict, Tuple + +import torch +import torch.distributed as dist +import yaml +from PIL import Image +from torch.utils.data import Dataset, IterableDataset, get_worker_info + +try: + import webdataset as wds + from webdataset.filters import _shuffle + from webdataset.shardlists import expand_urls + from webdataset.tariterators import base_plus_ext, url_opener, tar_file_expander, valid_sample +except ImportError: + wds = None + expand_urls = None + +from .parser import Parser +from .shared_count import SharedCount + +_logger = logging.getLogger(__name__) + +SHUFFLE_SIZE = int(os.environ.get('WDS_SHUFFLE_SIZE', 8192)) + + +def _load_info(root, basename='info'): + info_json = os.path.join(root, basename + '.json') + info_yaml = os.path.join(root, basename + '.yaml') + err_str = '' + try: + with wds.gopen.gopen(info_json) as f: + info_dict = json.load(f) + return info_dict + except Exception as e: + err_str = str(e) + try: + with wds.gopen.gopen(info_yaml) as f: + info_dict = yaml.safe_load(f) + return info_dict + except Exception: + pass + _logger.warning( + f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. ' + 'Falling back to provided split and size arg.') + return {} + + +@dataclass +class SplitInfo: + num_samples: int + filenames: Tuple[str] + shard_lengths: Tuple[int] = () + alt_label: str = '' + name: str = '' + + +def _parse_split_info(split: str, info: Dict): + def _info_convert(dict_info): + return SplitInfo( + num_samples=dict_info['num_samples'], + filenames=tuple(dict_info['filenames']), + shard_lengths=tuple(dict_info['shard_lengths']), + alt_label=dict_info.get('alt_label', ''), + name=dict_info['name'], + ) + + if 'tar' in split or '..' 
in split: + # split in WDS string braceexpand format, sample count can be included with a | separator + # ex: `dataset-split-{0000..9999}.tar|100000` for 10000 shards, covering 100,000 samples + split = split.split('|') + num_samples = 0 + split_name = '' + if len(split) > 1: + num_samples = int(split[1]) + split = split[0] + if '::' not in split: + split_parts = split.split('-', 3) + split_idx = len(split_parts) - 1 + if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']: + split_name = split_parts[split_idx] + + split_filenames = expand_urls(split) + if split_name: + split_info = info['splits'][split_name] + if not num_samples: + _fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])} + num_samples = sum(_fc[f] for f in split_filenames) + split_info['filenames'] = tuple(_fc.keys()) + split_info['shard_lengths'] = tuple(_fc.values()) + split_info['num_samples'] = num_samples + split_info = _info_convert(split_info) + else: + split_info = SplitInfo( + name=split_name, + num_samples=num_samples, + filenames=split_filenames, + ) + else: + if split not in info['splits']: + raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})") + split = split + split_info = info['splits'][split] + split_info = _info_convert(split_info) + + return split_info + + +def log_and_continue(exn): + """Call in an exception handler to ignore any exception, issue a warning, and continue.""" + _logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.') + return True + + +def _decode( + sample, + image_key='jpg', + image_format='RGB', + target_key='cls', + alt_label='' +): + """ Custom sample decode + * decode and convert PIL Image + * cls byte string label to int + * pass through JSON byte string (if it exists) without parse + """ + # decode class label, skip if alternate label not valid + if alt_label: + # alternative labels are encoded in json metadata + meta = json.loads(sample['json']) + class_label = int(meta[alt_label]) + if class_label < 0: + # skipped labels currently encoded as -1, may change to a null/None value + return None + else: + class_label = int(sample[target_key]) + + # decode image + with io.BytesIO(sample[image_key]) as b: + img = Image.open(b) + img.load() + if image_format: + img = img.convert(image_format) + + # json passed through in undecoded state + decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None)) + return decoded + + +def _decode_samples( + data, + image_key='jpg', + image_format='RGB', + target_key='cls', + alt_label='', + handler=log_and_continue): + """Decode samples with skip.""" + for sample in data: + try: + result = _decode( + sample, + image_key=image_key, + image_format=image_format, + target_key=target_key, + alt_label=alt_label + ) + except Exception as exn: + if handler(exn): + continue + else: + break + + # null results are skipped + if result is not None: + if isinstance(sample, dict) and isinstance(result, dict): + result["__key__"] = sample.get("__key__") + yield result + + +def pytorch_worker_seed(): + """get dataloader worker seed from pytorch""" + worker_info = get_worker_info() + if worker_info is not None: + # favour the seed already created for pytorch dataloader workers if it exists + return worker_info.seed + # fallback to wds rank based seed + return wds.utils.pytorch_worker_seed() + + +if wds is not None: + # conditional to avoid mandatory wds import (via inheritance of wds.PipelineStage) + class detshuffle2(wds.PipelineStage): + def __init__( + self, 
bufsize=1000, + initial=100, + seed=0, + epoch=-1, + ): + self.bufsize = bufsize + self.initial = initial + self.seed = seed + self.epoch = epoch + + def run(self, src): + if isinstance(self.epoch, SharedCount): + epoch = self.epoch.value + else: + # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + + if self.seed < 0: + seed = pytorch_worker_seed() + epoch + else: + seed = self.seed + epoch + _logger.info('shuffle: seed %d, epoch %d, derived seed %d', self.seed, epoch, seed) # FIXME temporary + rng = random.Random(seed) + return _shuffle(src, self.bufsize, self.initial, rng) + +else: + detshuffle2 = None + + +class ResampledShards2(IterableDataset): + """An iterable dataset yielding a list of urls.""" + + def __init__( + self, + urls, + nshards=sys.maxsize, + worker_seed=None, + deterministic=True, + epoch=-1, + ): + """Sample shards from the shard list with replacement. + + :param urls: a list of URLs as a Python list or brace notation string + """ + super().__init__() + urls = wds.shardlists.expand_urls(urls) + self.urls = urls + assert isinstance(self.urls[0], str) + self.nshards = nshards + self.rng = random.Random() + self.worker_seed = pytorch_worker_seed if worker_seed is None else worker_seed + self.deterministic = deterministic + self.epoch = epoch + + def __iter__(self): + """Return an iterator over the shards.""" + if isinstance(self.epoch, SharedCount): + epoch = self.epoch.value + else: + # NOTE: this epoch tracking is problematic in a multiprocess (dataloader workers or train) + # situation as different workers may wrap at different times (or not at all). + self.epoch += 1 + epoch = self.epoch + + if self.deterministic: + # reset seed w/ epoch if deterministic, worker seed should be deterministic due to arg.seed + self.rng = random.Random(self.worker_seed() + epoch) + + for _ in range(self.nshards): + index = self.rng.randint(0, len(self.urls) - 1) + yield dict(url=self.urls[index]) + + +class ParserWds(Parser): + def __init__( + self, + root, + name, + split, + is_training=False, + batch_size=None, + repeats=0, + seed=42, + input_name='jpg', + input_image='RGB', + target_name='cls', + target_image='', + prefetch_size=None, + shuffle_size=None, + ): + super().__init__() + if wds is None: + raise RuntimeError( + 'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.') + self.root = root + self.is_training = is_training + self.batch_size = batch_size + self.repeats = repeats + self.common_seed = seed # a seed that's fixed across all worker / distributed instances + self.shard_shuffle_size = 500 + self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE + + self.image_key = input_name + self.image_format = input_image + self.target_key = target_name + self.filename_key = 'filename' + self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet) + + self.info = _load_info(self.root) + self.split_info = _parse_split_info(split, self.info) + self.num_samples = self.split_info.num_samples + if not self.num_samples: + raise RuntimeError('Invalid split definition, no samples found.') + + # Distributed world state + self.dist_rank = 0 + self.dist_num_replicas = 1 + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() + + # Attributes that are updated in 
_lazy_init + self.worker_info = None + self.worker_id = 0 + self.worker_seed = seed # seed unique to each worker instance + self.num_workers = 1 + self.global_worker_id = 0 + self.global_num_workers = 1 + self.init_count = 0 + self.epoch_count = SharedCount() + + # DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed + # is not handled in manner where it can be deterministic for each worker AND initialized up front + self.ds = None + + def set_epoch(self, count): + self.epoch_count.value = count + + def _lazy_init(self): + """ Lazily initialize worker (in worker processes) + """ + if self.worker_info is None: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + self.worker_info = worker_info + self.worker_id = worker_info.id + self.worker_seed = worker_info.seed + self.num_workers = worker_info.num_workers + self.global_num_workers = self.dist_num_replicas * self.num_workers + self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id + + # init data pipeline + abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames] + pipeline = [wds.SimpleShardList(abs_shard_filenames)] + # at this point we have an iterator over all the shards + if self.is_training: + pipeline.extend([ + detshuffle2(self.shard_shuffle_size, seed=self.common_seed, epoch=self.epoch_count), + self._split_by_node_and_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + self.sample_shuffle_size, + rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline + ]) + else: + pipeline.extend([ + self._split_by_node_and_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ]) + pipeline.extend([ + partial( + _decode_samples, + image_key=self.image_key, + image_format=self.image_format, + alt_label=self.split_info.alt_label + ) + ]) + self.ds = wds.DataPipeline(*pipeline) + + def _split_by_node_and_worker(self, src): + if self.global_num_workers > 1: + for s in islice(src, self.global_worker_id, None, self.global_num_workers): + yield s + else: + for s in src: + yield s + + def __iter__(self): + if self.ds is None: + self._lazy_init() + + if self.is_training: + num_worker_samples = math.floor(self.num_samples / self.global_num_workers) + if self.batch_size is not None: + num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size + ds = self.ds.with_epoch(num_worker_samples) + else: + if self.dist_num_replicas > 1: + # doing distributed validation w/ WDS is messy, hard to meet constraints that + # same # of batches needed across all replicas w/ seeing each sample once. + # with_epoch() is simple but could miss a shard's worth of samples in some workers, + # and duplicate in others. Best to keep num DL workers low and a divisor of #val shards. 
+ num_worker_samples = math.ceil(self.num_samples / self.global_num_workers) + ds = self.ds.with_epoch(num_worker_samples) + else: + ds = self.ds + + i = 0 + _logger.info('start: %d, worker %d', i, self.worker_id) # FIXME temporary debug + for sample in ds: + yield sample[self.image_key], sample[self.target_key] + i += 1 + _logger.info('end: %d, worker %d', i, self.worker_id) # FIXME temporary debug + + def __len__(self): + return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) + + def _filename(self, index, basename=False, absolute=False): + assert False, "Not supported" # no random access to examples + + def filenames(self, basename=False, absolute=False): + """ Return all filenames in dataset, overrides base""" + if self.ds is None: + self._lazy_init() + + names = [] + for sample in self.ds: + if self.filename_key in sample: + name = sample[self.filename_key] + elif '__key__' in sample: + name = sample['__key__'] + self.key_ext + else: + assert False, "No supported name field present" + names.append(name) + if len(names) >= self.num_samples: + break # safety for ds.repeat() case + return names diff --git a/timm/data/parsers/shared_count.py b/timm/data/parsers/shared_count.py new file mode 100644 index 00000000..fe4e85b4 --- /dev/null +++ b/timm/data/parsers/shared_count.py @@ -0,0 +1,14 @@ +from multiprocessing import Value + + +class SharedCount: + def __init__(self, epoch: int = 0): + self.shared_epoch = Value('i', epoch) + + @property + def value(self): + return self.shared_epoch.value + + @value.setter + def value(self, epoch): + self.shared_epoch.value = epoch diff --git a/train.py b/train.py index 91980cb6..aeb31d6a 100755 --- a/train.py +++ b/train.py @@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N', group.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') group.add_argument('--img-size', type=int, default=None, metavar='N', - help='Image patch size (default: None => model default)') + help='Image size (default: None => model default)') +group.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') group.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. 
--input-size 3 224 224), uses model default if empty') group.add_argument('--crop-pct', default=None, type=float, @@ -394,9 +396,16 @@ def main(): if args.fast_norm: set_fast_norm() + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, pretrained=args.pretrained, + in_chans=in_chans, num_classes=args.num_classes, drop_rate=args.drop, drop_path_rate=args.drop_path, @@ -537,7 +546,8 @@ def main(): class_map=args.class_map, download=args.dataset_download, batch_size=args.batch_size, - repeats=args.epoch_repeats + seed=args.seed, + repeats=args.epoch_repeats, ) dataset_eval = create_dataset( @@ -547,7 +557,7 @@ def main(): is_training=False, class_map=args.class_map, download=args.dataset_download, - batch_size=args.batch_size + batch_size=args.batch_size, ) # setup mixup / cutmix @@ -610,6 +620,10 @@ def main(): worker_seeding=args.worker_seeding, ) + eval_workers = args.workers + if args.distributed and ('tfds' in args.dataset or 'wds' in args.dataset): + # FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training + eval_workers = min(2, args.workers) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -619,7 +633,7 @@ def main(): interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], - num_workers=args.workers, + num_workers=eval_workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, @@ -679,7 +693,9 @@ def main(): try: for epoch in range(start_epoch, num_epochs): - if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): + if hasattr(dataset_train, 'set_epoch'): + dataset_train.set_epoch(epoch) + elif args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) train_metrics = train_one_epoch(
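For reference, a rough usage sketch (not part of the patch) of how the new pieces fit together: the wds/ dataset prefix, the seed argument, and the per-epoch set_epoch() call are what this change adds, while the dataset name, root path, shard pattern, and sample count below are made-up placeholders; webdataset must be installed.

    from timm.data import create_dataset

    dataset_train = create_dataset(
        'wds/my_dataset',  # hypothetical name, the wds/ prefix routes to the new ParserWds
        root='/data/my-dataset-wds',  # hypothetical directory holding the .tar shards (and optional info.json)
        split='my_dataset-train-{000000..000099}.tar|100000',  # braceexpand shard pattern | total sample count
        is_training=True,
        batch_size=256,
        seed=42,
    )

    for epoch in range(2):
        # propagate the epoch so TFDS/WDS shard shuffling stays deterministic across workers / replicas
        dataset_train.set_epoch(epoch)
        for img, target in dataset_train:
            pass  # in practice, wrap with create_loader / feed train_one_epoch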