From 4f338556d8f943c5df8f3b3333690cf14b7af5f1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 12 Nov 2021 13:40:26 -0800 Subject: [PATCH] Fixes and improvements for metrics, tfds parser, loader / transform handling * add back ability to create transform with loader * change 'samples' -> 'examples' for tfds wrapper to match tfds naming * add support for specifying feature names for input and target in tfds wrapper * add class_to_idx for image classification datasets in tfds wrapper * add accumulate_type to avg meters and metrics to allow float32 or float64 accumulation control with lower prec data * minor cleanup, log output rate prev and avg --- timm/bits/avg_scalar.py | 6 -- timm/bits/avg_tensor.py | 18 +++-- timm/bits/metric.py | 6 +- timm/bits/metric_accuracy.py | 67 ++++++------------ timm/bits/monitor.py | 9 ++- timm/data/__init__.py | 4 +- timm/data/loader.py | 34 +++++++--- timm/data/mixup.py | 3 +- timm/data/parsers/parser_tfds.py | 112 ++++++++++++++++++++----------- timm/data/transforms_factory.py | 29 +++++--- train.py | 22 ++---- validate.py | 15 ++--- 12 files changed, 179 insertions(+), 146 deletions(-) diff --git a/timm/bits/avg_scalar.py b/timm/bits/avg_scalar.py index 04d41c8e..6a6ce31b 100644 --- a/timm/bits/avg_scalar.py +++ b/timm/bits/avg_scalar.py @@ -2,12 +2,6 @@ class AvgMinMaxScalar: """Computes and stores the average and current value""" def __init__(self): - self.val = 0 - self.avg = 0 - self.min = None - self.max = None - self.sum = 0 - self.count = 0 self.reset() def reset(self): diff --git a/timm/bits/avg_tensor.py b/timm/bits/avg_tensor.py index 0aaf92e3..01219b56 100644 --- a/timm/bits/avg_tensor.py +++ b/timm/bits/avg_tensor.py @@ -4,7 +4,8 @@ import torch class AvgTensor: """Computes and stores the average and current value""" - def __init__(self): + def __init__(self, accumulate_dtype=torch.float32): + self.accumulate_dtype = accumulate_dtype self.sum = None self.count = None self.reset() @@ -16,7 +17,7 @@ class AvgTensor: def update(self, val: torch.Tensor, n=1): if self.sum is None: - self.sum = torch.zeros_like(val) + self.sum = torch.zeros_like(val, dtype=self.accumulate_dtype) self.count = torch.tensor(0, dtype=torch.long, device=val.device) self.sum += (val * n) self.count += n @@ -28,7 +29,13 @@ class AvgTensor: class TensorEma: """Computes and stores the average and current value""" - def __init__(self, smoothing_factor=0.9, init_zero=False): + def __init__( + self, + smoothing_factor=0.9, + init_zero=False, + accumulate_dtype=torch.float32 + ): + self.accumulate_dtype = accumulate_dtype self.smoothing_factor = smoothing_factor self.init_zero = init_zero self.val = None @@ -40,5 +47,8 @@ class TensorEma: def update(self, val): if self.val is None: - self.val = torch.zeros_like(val) if self.init_zero else val.clone() + if self.init_zero: + self.val = torch.zeros_like(val, dtype=self.accumulate_dtype) + else: + self.val = val.clone().to(dtype=self.accumulate_dtype) self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * self.val diff --git a/timm/bits/metric.py b/timm/bits/metric.py index b18282b8..0a0f6d8b 100644 --- a/timm/bits/metric.py +++ b/timm/bits/metric.py @@ -10,6 +10,7 @@ from .distributed import all_gather_sequence, all_reduce_sequence MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]] + @dataclass class ValueInfo: initial: Optional[MetricValueT] = 0. @@ -20,7 +21,10 @@ class ValueInfo: class Metric(abc.ABC): - def __init__(self, dev_env: DeviceEnv = None): + def __init__( + self, + dev_env: DeviceEnv = None + ): self._infos: Dict[str, ValueInfo] = {} self._values: Dict[str, Optional[MetricValueT]] = {} self._values_dist: Dict[str, Optional[MetricValueT]] = {} diff --git a/timm/bits/metric_accuracy.py b/timm/bits/metric_accuracy.py index 0db72c6d..1a3fbefc 100644 --- a/timm/bits/metric_accuracy.py +++ b/timm/bits/metric_accuracy.py @@ -7,15 +7,22 @@ from .metric import Metric, ValueInfo class Accuracy(Metric): - def __init__(self, threshold=0.5, multi_label=False, dev_env=None): + def __init__( + self, + threshold=0.5, + multi_label=False, + accumulate_dtype=torch.float32, + dev_env=None, + ): super().__init__(dev_env=dev_env) + self.accumulate_dtype = accumulate_dtype self.threshold = threshold self.eps = 1e-8 self.multi_label = multi_label # statistics / counts - self._register_value('correct') - self._register_value('total') + self._register_value('correct', ValueInfo(dtype=accumulate_dtype)) + self._register_value('total', ValueInfo(dtype=accumulate_dtype)) def _update(self, predictions, target): raise NotImplemented() @@ -24,65 +31,31 @@ class Accuracy(Metric): raise NotImplemented() -# class AccuracyTopK(torch.nn.Module): -# -# def __init__(self, topk=(1, 5), device=None): -# super().__init__() -# self.eps = 1e-8 -# self.device = device -# self.topk = topk -# self.maxk = max(topk) -# # FIXME handle distributed operation -# -# # statistics / counts -# self.reset() -# -# def update(self, predictions: torch.Tensor, target: torch.Tensor): -# sorted_indices = predictions.topk(self.maxk, dim=1)[1] -# sorted_indices.t_() -# correct = sorted_indices.eq(target.reshape(1, -1).expand_as(sorted_indices)) -# -# batch_size = target.shape[0] -# correct_k = {k: correct[:k].reshape(-1).float().sum(0) for k in self.topk} -# for k, v in correct_k.items(): -# attr = f'_correct_top{k}' -# old_v = getattr(self, attr) -# setattr(self, attr, old_v + v) -# self._total_sum += batch_size -# -# def reset(self): -# for k in self.topk: -# setattr(self, f'_correct_top{k}', torch.tensor(0, dtype=torch.float32)) -# self._total_sum = torch.tensor(0, dtype=torch.float32) -# -# @property -# def counts(self): -# pass -# -# def compute(self) -> Dict[str, torch.Tensor]: -# # FIXME handle distributed reduction -# return {f'top{k}': 100 * getattr(self, f'_correct_top{k}') / self._total_sum for k in self.topk} - - class AccuracyTopK(Metric): - def __init__(self, topk=(1, 5), dev_env: DeviceEnv = None): + def __init__( + self, + topk=(1, 5), + accumulate_dtype=torch.float32, + dev_env: DeviceEnv = None + ): super().__init__(dev_env=dev_env) + self.accumulate_dtype = accumulate_dtype self.eps = 1e-8 self.topk = topk self.maxk = max(topk) # statistics / counts for k in self.topk: - self._register_value(f'top{k}') - self._register_value('total') + self._register_value(f'top{k}', ValueInfo(dtype=accumulate_dtype)) + self._register_value('total', ValueInfo(dtype=accumulate_dtype)) self.reset() def _update(self, predictions: torch.Tensor, target: torch.Tensor): batch_size = predictions.shape[0] sorted_indices = predictions.topk(self.maxk, dim=1)[1] target_reshape = target.reshape(-1, 1).expand_as(sorted_indices) - correct = sorted_indices.eq(target_reshape).float().sum(0) + correct = sorted_indices.eq(target_reshape).to(dtype=self.accumulate_dtype).sum(0) for k in self.topk: attr_name = f'top{k}' correct_at_k = correct[:k].sum() diff --git a/timm/bits/monitor.py b/timm/bits/monitor.py index e4dd95f0..ca9c19be 100644 --- a/timm/bits/monitor.py +++ b/timm/bits/monitor.py @@ -156,7 +156,7 @@ class Monitor: step_end_idx: Optional[int] = None, epoch: Optional[int] = None, loss: Optional[float] = None, - rate: Optional[float] = None, + rate: Optional[Union[float, Tuple[float, float]]] = None, phase_suffix: str = '', **kwargs, ): @@ -168,12 +168,17 @@ class Monitor: step_end_idx = max(0, kwargs.pop('num_steps') - 1) phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:' progress = 100. * step_idx / step_end_idx if step_end_idx else 0. + rate_str = '' + if isinstance(rate, (tuple, list)): + rate_str = f'Rate: {rate[0]:.2f}/s ({rate[1]:.2f}/s)' + elif rate is not None: + rate_str = f'Rate: {rate:.2f}/s' text_update = [ phase_title, f'{epoch}' if epoch is not None else None, f'[{step_idx}]' if step_end_idx is None else None, f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]' if step_end_idx is not None else None, - f'Rate: {rate:.2f}/s' if rate is not None else None, + rate_str, f'Loss: {loss:.5f}' if loss is not None else None, ] _add_kwargs(text_update, **kwargs) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 163bcea7..de978419 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,10 +1,10 @@ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform -from .config import resolve_data_config +from .config import resolve_data_config, PreprocessCfg, AugCfg, MixupCfg from .constants import * from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset -from .loader import create_loader_v2, PreprocessCfg, AugCfg, MixupCfg +from .loader import create_loader_v2 from .mixup import Mixup, FastCollateMixup from .parsers import create_parser from .real_labels import RealLabelsImagenet diff --git a/timm/data/loader.py b/timm/data/loader.py index 67d30765..750067d4 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -6,18 +6,19 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2020 Ross Wightman """ -from typing import Tuple, Optional, Union, Callable +from typing import Optional, Callable -import torch.utils.data import numpy as np +import torch.utils.data from timm.bits import DeviceEnv from .collate import fast_collate -from .config import PreprocessCfg, AugCfg, MixupCfg +from .config import PreprocessCfg, MixupCfg from .distributed_sampler import OrderedDistributedSampler from .fetcher import Fetcher from .mixup import FastCollateMixup from .prefetcher_cuda import PrefetcherCuda +from .transforms_factory import create_transform_v2 def _worker_init(worker_id): @@ -31,9 +32,11 @@ def create_loader_v2( batch_size: int, is_training: bool = False, dev_env: Optional[DeviceEnv] = None, - normalize=True, pp_cfg: PreprocessCfg = PreprocessCfg(), mix_cfg: MixupCfg = None, + create_transform: bool = True, + normalize_in_transform: bool = True, + separate_transform: bool = False, num_workers: int = 1, collate_fn: Optional[Callable] = None, pin_memory: bool = False, @@ -46,10 +49,12 @@ def create_loader_v2( dataset: batch_size: is_training: - dev_env: - normalize: + dev_env: pp_cfg: - mix_cfg: + mix_cfg: + create_transform: + normalize_in_transform: + separate_transform: num_workers: collate_fn: pin_memory: @@ -62,6 +67,14 @@ def create_loader_v2( if dev_env is None: dev_env = DeviceEnv.instance() + if create_transform: + dataset.transform = create_transform_v2( + cfg=pp_cfg, + is_training=is_training, + normalize=normalize_in_transform, + separate=separate_transform, + ) + sampler = None if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset): if is_training: @@ -110,18 +123,19 @@ def create_loader_v2( loader = loader_class(dataset, **loader_args) fetcher_kwargs = dict( - normalize=normalize, + normalize=not normalize_in_transform, mean=pp_cfg.mean, std=pp_cfg.std, ) - if normalize and is_training and pp_cfg.aug is not None: + if not normalize_in_transform and is_training and pp_cfg.aug is not None: + # If normalization can be done in the prefetcher, random erasing is done there too + # NOTE RandomErasing does not work well in XLA so normalize_in_transform will be True fetcher_kwargs.update(dict( re_prob=pp_cfg.aug.re_prob, re_mode=pp_cfg.aug.re_mode, re_count=pp_cfg.aug.re_count, num_aug_splits=pp_cfg.aug.num_aug_splits, )) - if dev_env.type_cuda: loader = PrefetcherCuda(loader, **fetcher_kwargs) else: diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 074b6941..bf5d1b0e 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -103,6 +103,7 @@ class Mixup: """ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, mode='batch', correct_lam=True, label_smoothing=0., num_classes=0): + assert num_classes > 0, 'num_classes must be set for target generation' self.mixup_alpha = mixup_alpha self.cutmix_alpha = cutmix_alpha self.cutmix_minmax = cutmix_minmax @@ -113,8 +114,6 @@ class Mixup: self.mix_prob = prob self.switch_prob = switch_prob self.label_smoothing = label_smoothing - if label_smoothing > 0.: - assert num_classes > 0 self.num_classes = num_classes self.mode = mode self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index e96e827b..dd24c55c 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -27,36 +27,44 @@ except ImportError as e: exit(1) from .parser import Parser -from timm.bits import get_global_device +from timm.bits import get_global_device, is_global_device MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 16384 # samples to shuffle in DS queue -PREFETCH_SIZE = 2048 # samples to prefetch +SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue +PREFETCH_SIZE = 2048 # examples to prefetch -def even_split_indices(split, n, num_samples): - partitions = [round(i * num_samples / n) for i in range(n + 1)] +def even_split_indices(split, n, num_examples): + partitions = [round(i * num_examples / n) for i in range(n + 1)] return [f"{split}[{partitions[i]}:{partitions[i+1]}]" for i in range(n)] +def get_class_labels(info): + if 'label' not in info.features: + return {} + class_label = info.features['label'] + class_to_idx = {n: class_label.str2int(n) for n in class_label.names} + return class_to_idx + + class ParserTfds(Parser): """ Wrap Tensorflow Datasets for use in PyTorch There several things to be aware of: - * To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of + * To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last https://github.com/pytorch/pytorch/issues/33413 * With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch from each worker could be a different size. For training this is worked around by option above, for - validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced + validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced across replicas are of same size. This will slightly alter the results, distributed validation will not be 100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse - since there are up to N * J extra samples with IterableDatasets. + since there are up to N * J extra examples with IterableDatasets. * The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of replicas and dataloader workers you can use. For really small datasets that only contain a few shards you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the benefit of distributed training or fast dataloading should be much less for small datasets. - * This wrapper is currently configured to return individual, decompressed image samples from the TFDS + * This wrapper is currently configured to return individual, decompressed image examples from the TFDS dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream components. @@ -72,6 +80,10 @@ class ParserTfds(Parser): download=False, repeats=0, seed=42, + input_name='image', + input_image='RGB', + target_name='label', + target_image='', prefetch_size=None, shuffle_size=None, max_threadpool_size=None @@ -83,10 +95,13 @@ class ParserTfds(Parser): name: tfds dataset name (eg `imagenet2012`) split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`) is_training: training mode, shuffle enabled, dataset len rounded by batch_size - batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes + batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes download: download and build TFDS dataset if set, otherwise must use tfds CLI repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1) seed: common seed for shard shuffle across all distributed/worker instances + input_image: image mode if input is an image (currently PIL mode string) + target_name: name of Feature to return as target (label) + target_image: image mode if target is an image (currently PIL mode string) prefetch_size: override default tf.data prefetch buffer size shuffle_size: override default tf.data shuffle buffer size max_threadpool_size: override default threadpool size for tf.data @@ -101,25 +116,39 @@ class ParserTfds(Parser): self.batch_size = batch_size self.repeats = repeats self.common_seed = seed # a seed that's fixed across all worker / distributed instances + + # Performance settings self.prefetch_size = prefetch_size or PREFETCH_SIZE self.shuffle_size = shuffle_size or SHUFFLE_SIZE self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE # TFDS builder and split information + self.input_name = input_name # FIXME support tuples / lists of inputs and targets and full range of Feature + self.input_image = input_image + self.target_name = target_name + self.target_image = target_image self.builder = tfds.builder(name, data_dir=root) # NOTE: the tfds command line app can be used download & prepare datasets if you don't enable download flag if download: self.builder.download_and_prepare() + self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {} self.split_info = self.builder.info.splits[split] - self.num_samples = self.split_info.num_examples + self.num_examples = self.split_info.num_examples # Distributed world state self.dist_rank = 0 self.dist_num_replicas = 1 - dev_env = get_global_device() # FIXME allow to work without devenv usage? - if dev_env.distributed and dev_env.world_size > 1: - self.dist_rank = dev_env.global_rank - self.dist_num_replicas = dev_env.world_size + if is_global_device(): + dev_env = get_global_device() + if dev_env.distributed and dev_env.world_size > 1: + self.dist_rank = dev_env.global_rank + self.dist_num_replicas = dev_env.world_size + else: + # FIXME warn if we fallback to torch distributed? + import torch.distributed as dist + if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1: + self.dist_rank = dist.get_rank() + self.dist_num_replicas = dist.get_world_size() # Attributes that are updated in _lazy_init, including the tf.data pipeline itself self.global_num_workers = 1 @@ -159,17 +188,17 @@ class ParserTfds(Parser): I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing the data across workers. For training InputContext is used to assign shards to nodes unless num_shards in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or - for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding. + for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding. """ should_subsplit = self.global_num_workers > 1 and ( self.split_info.num_shards < self.global_num_workers or not self.is_training) if should_subsplit: - # split the dataset w/o using sharding for more even samples / worker, can result in less optimal + # split the dataset w/o using sharding for more even examples / worker, can result in less optimal # read patterns for distributed training (overlap across shards) so better to use InputContext there if has_buggy_even_splits: # my even_split workaround doesn't work on subsplits, upgrade tfds! if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo): - subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples) + subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples) self.subsplit = subsplits[global_worker_id] else: subsplits = tfds.even_splits(self.split, self.global_num_workers) @@ -200,8 +229,8 @@ class ParserTfds(Parser): # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.is_training: - ds = ds.shuffle(min(self.num_samples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) - ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size)) + ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed) + ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size)) self.ds = tfds.as_numpy(ds) def __iter__(self): @@ -210,44 +239,49 @@ class ParserTfds(Parser): # Compute a rounded up sample count that is used to: # 1. make batches even cross workers & replicas in distributed validation. - # This adds extra samples and will slightly alter validation results. + # This adds extra examples and will slightly alter validation results. # 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size # batches are produced (underlying tfds iter wraps around) - target_sample_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers) + target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers) if self.is_training: # round up to nearest batch_size per worker-replica - target_sample_count = math.ceil(target_sample_count / self.batch_size) * self.batch_size + target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size # Iterate until exhausted or sample count hits target when training (ds.repeat enabled) - sample_count = 0 - for sample in self.ds: - img = Image.fromarray(sample['image'], mode='RGB') - yield img, sample['label'] - sample_count += 1 - if self.is_training and sample_count >= target_sample_count: + example_count = 0 + for example in self.ds: + input_data = example[self.input_name] + if self.input_image: + input_data = Image.fromarray(input_data, mode=self.input_image) + target_data = example[self.target_name] + if self.target_image: + target_data = Image.fromarray(target_data, mode=self.target_image) + yield input_data, target_data + example_count += 1 + if self.is_training and example_count >= target_example_count: # Need to break out of loop when repeat() is enabled for training w/ oversampling - # this results in extra samples per epoch but seems more desirable than dropping + # this results in extra examples per epoch but seems more desirable than dropping # up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes) break - # Pad across distributed nodes (make counts equal by adding samples) + # Pad across distributed nodes (make counts equal by adding examples) if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \ - 0 < sample_count < target_sample_count: + 0 < example_count < target_example_count: # Validation batch padding only done for distributed training where results are reduced across nodes. # For single process case, it won't matter if workers return different batch sizes. # If using input_context or % based splits, sample count can vary significantly across workers and this # approach should not be used (hence disabled if self.subsplit isn't set). - while sample_count < target_sample_count: - yield img, sample['label'] # yield prev sample again - sample_count += 1 + while example_count < target_example_count: + yield input_data, target_data # yield prev sample again + example_count += 1 def __len__(self): - # this is just an estimate and does not factor in extra samples added to pad batches based on + # this is just an estimate and does not factor in extra examples added to pad batches based on # complete worker & replica info (not available until init in dataloader). - return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas) + return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas) def _filename(self, index, basename=False, absolute=False): - assert False, "Not supported" # no random access to samples + assert False, "Not supported" # no random access to examples def filenames(self, basename=False, absolute=False): """ Return all filenames in dataset, overrides base""" @@ -255,7 +289,7 @@ class ParserTfds(Parser): self._lazy_init() names = [] for sample in self.ds: - if len(names) > self.num_samples: + if len(names) > self.num_examples: break # safety for ds.repeat() case if 'file_name' in sample: name = sample['file_name'] diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 1c8d15e2..24c89ce3 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -22,6 +22,7 @@ def transforms_noaug_train( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, normalize=False, + compose=True, ): if interpolation == 'random': # random interpolation not supported with no-aug @@ -38,7 +39,7 @@ def transforms_noaug_train( else: # (pre)fetcher and collate will handle tensor conversion and normalize tfl += [ToNumpy()] - return transforms.Compose(tfl) + return transforms.Compose(tfl) if compose else tfl def transforms_imagenet_train( @@ -49,6 +50,7 @@ def transforms_imagenet_train( aug_cfg=AugCfg(), normalize=False, separate=False, + compose=True, ): """ If separate==True, the transforms are returned as a tuple of 3 separate transforms @@ -122,9 +124,13 @@ def transforms_imagenet_train( if separate: # return each transform stage separately - return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + if compose: + return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) + else: + return primary_tfl, secondary_tfl, final_tfl else: - return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) + tfl = primary_tfl + secondary_tfl + final_tfl + return transforms.Compose(tfl) if compose else tfl def transforms_imagenet_eval( @@ -134,6 +140,7 @@ def transforms_imagenet_eval( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, normalize=False, + compose=True, ): crop_pct = crop_pct or DEFAULT_CROP_PCT @@ -160,7 +167,7 @@ def transforms_imagenet_eval( # (pre)fetcher and collate will handle tensor conversion and normalize tfl += [ToNumpy()] - return transforms.Compose(tfl) + return transforms.Compose(tfl) if compose else tfl def create_transform_v2( @@ -168,6 +175,7 @@ def create_transform_v2( is_training=False, normalize=False, separate=False, + compose=True, tf_preprocessing=False, ): """ @@ -175,10 +183,10 @@ def create_transform_v2( Args: cfg: Pre-processing configuration is_training (bool): Create transform for training pre-processing - tf_preprocessing (bool): Use Tensorflow pre-processing (for validation) normalize (bool): Enable normalization in transforms (otherwise handled by fetcher/pre-fetcher) separate (bool): Return transforms separated into stages (for train) - + compose (bool): Wrap transforms in transform.Compose(), returns list otherwise + tf_preprocessing (bool): Use Tensorflow pre-processing (for validation) Returns: """ @@ -202,7 +210,9 @@ def create_transform_v2( interpolation=cfg.interpolation, normalize=normalize, mean=cfg.mean, - std=cfg.std) + std=cfg.std, + compose=compose, + ) elif is_training: transform = transforms_imagenet_train( img_size, @@ -211,7 +221,9 @@ def create_transform_v2( std=cfg.std, aug_cfg=cfg.aug, normalize=normalize, - separate=separate) + separate=separate, + compose=compose, + ) else: assert not separate, "Separate transforms not supported for validation preprocessing" transform = transforms_imagenet_eval( @@ -221,6 +233,7 @@ def create_transform_v2( mean=cfg.mean, std=cfg.std, normalize=normalize, + compose=compose, ) return transform diff --git a/train.py b/train.py index fb6b4319..0440f551 100755 --- a/train.py +++ b/train.py @@ -563,17 +563,13 @@ def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool): # if using PyTorch XLA and RandomErasing is enabled, we must normalize and do RE in transforms on CPU normalize_in_transform = dev_env.type_xla and args.reprob > 0 - - dataset_train.transform = create_transform_v2( - cfg=train_pp_cfg, is_training=True, normalize=normalize_in_transform) - loader_train = create_loader_v2( dataset_train, batch_size=args.batch_size, is_training=True, - normalize=not normalize_in_transform, pp_cfg=train_pp_cfg, mix_cfg=mixup_cfg, + normalize_in_transform=normalize_in_transform, num_workers=args.workers, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader @@ -587,19 +583,17 @@ def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool): std=data_config['std'], ) - dataset_eval.transform = create_transform_v2( - cfg=eval_pp_cfg, is_training=False, normalize=normalize_in_transform) - eval_workers = args.workers if 'tfds' in args.dataset: - # FIXME reduce validation issues when using TFDS w/ workers and distributed training + # FIXME reduces validation padding issues when using TFDS w/ workers and distributed training eval_workers = min(2, args.workers) + loader_eval = create_loader_v2( dataset_eval, batch_size=args.validation_batch_size or args.batch_size, is_training=False, - normalize=not normalize_in_transform, pp_cfg=eval_pp_cfg, + normalize_in_transform=normalize_in_transform, num_workers=eval_workers, pin_memory=args.pin_mem, ) @@ -708,7 +702,7 @@ def after_train_step( step_end_idx=step_end_idx, epoch=state.epoch, loss=loss_avg.item(), - rate=tracker.get_avg_iter_rate(global_batch_size), + rate=(tracker.get_last_iter_rate(global_batch_size), tracker.get_avg_iter_rate(global_batch_size)), lr=lr_avg, ) @@ -756,16 +750,14 @@ def evaluate( dev_env.mark_step() elif dev_env.type_cuda: dev_env.synchronize() - - # FIXME uncommenting this fixes race btw model `output`/`loss` and loss_m/accuracy_m meter input + # FIXME uncommenting this fixes race btw model `output` / `loss` and loss_m / accuracy_m meter input # for PyTorch XLA GPU use. # This issue does not exist for normal PyTorch w/ GPU (CUDA) or PyTorch XLA w/ TPU. # loss.item() - tracker.mark_iter_step_end() + losses_m.update(loss, output.size(0)) accuracy_m.update(output, target) - if last_step or step_idx % log_interval == 0: top1, top5 = accuracy_m.compute().values() loss_avg = losses_m.compute() diff --git a/validate.py b/validate.py index 03a90dc0..d2eca03e 100755 --- a/validate.py +++ b/validate.py @@ -154,11 +154,10 @@ def validate(args): std=data_config['std'], ) - dataset.transform = create_transform_v2(cfg=eval_pp_cfg, is_training=False) - loader = create_loader_v2( dataset, batch_size=args.batch_size, + is_training=False, pp_cfg=eval_pp_cfg, num_workers=args.workers, pin_memory=args.pin_mem) @@ -176,24 +175,20 @@ def validate(args): last_step = step_idx == num_steps - 1 tracker.mark_iter_data_end() - # compute output with dev_env.autocast(): output = model(sample) - if valid_labels is not None: output = output[:, valid_labels] loss = criterion(output, target) - if dev_env.type_cuda: - dev_env.synchronize() - tracker.mark_iter_step_end() - if dev_env.type_xla: dev_env.mark_step() + elif dev_env.type_cuda: + dev_env.synchronize() + tracker.mark_iter_step_end() if real_labels is not None: real_labels.add_result(output) - losses.update(loss.detach(), sample.size(0)) accuracy.update(output.detach(), target) @@ -205,7 +200,7 @@ def validate(args): phase='eval', step_idx=step_idx, num_steps=num_steps, - rate=args.batch_size / tracker.iter_time.avg, + rate=(tracker.get_last_iter_rate(output.shape[0]), tracker.get_avg_iter_rate(args.batch_size)), loss=loss_avg.item(), top1=top1.item(), top5=top5.item(),