From 40457e569179642e21f8120c8854532b395ce1ca Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 Aug 2021 12:45:43 -0700 Subject: [PATCH] Transforms, augmentation work for bits, add RandomErasing support for XLA (pushing into transforms), revamp of transform/preproc config, etc ongoing... --- clean_checkpoint.py | 15 +- inference.py | 4 +- timm/bits/device_env.py | 3 + timm/bits/device_env_cuda.py | 5 +- timm/bits/device_env_xla.py | 5 + timm/bits/train_setup.py | 7 +- timm/data/__init__.py | 6 +- timm/data/auto_augment.py | 18 ++- timm/data/collate.py | 2 +- timm/data/config.py | 43 ++++++ timm/data/fetcher.py | 60 +++++--- timm/data/loader.py | 131 +++++++++-------- timm/data/mixup.py | 70 +++++++--- timm/data/prefetcher_cuda.py | 54 ++++--- timm/data/random_erasing.py | 87 ++++++++++-- timm/data/transforms.py | 57 +++++--- timm/data/transforms_factory.py | 241 ++++++++++++++++++++------------ timm/models/helpers.py | 13 +- train.py | 106 +++++++++----- validate.py | 34 +++-- 20 files changed, 628 insertions(+), 333 deletions(-) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index a8edcc91..1553fc4b 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -14,6 +14,9 @@ import hashlib import shutil from collections import OrderedDict +from timm.models.helpers import load_state_dict +from timm.utils import setup_default_logging + parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') @@ -29,6 +32,7 @@ _TEMP_NAME = './_checkpoint.pth' def main(): args = parser.parse_args() + setup_default_logging() if os.path.exists(args.output): print("Error: Output filename ({}) already exists.".format(args.output)) @@ -37,17 +41,8 @@ def main(): # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save if args.checkpoint and os.path.isfile(args.checkpoint): print("=> Loading checkpoint '{}'".format(args.checkpoint)) - checkpoint = torch.load(args.checkpoint, map_location='cpu') - + state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema) new_state_dict = OrderedDict() - if isinstance(checkpoint, dict): - state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' - if state_dict_key in checkpoint: - state_dict = checkpoint[state_dict_key] - else: - state_dict = checkpoint - else: - assert False for k, v in state_dict.items(): if args.clean_aux_bn and 'aux_bn' in k: # If all aux_bn keys are removed, the SplitBN layers will end up as normal and diff --git a/inference.py b/inference.py index 5fcf1e60..1f248dc7 100755 --- a/inference.py +++ b/inference.py @@ -13,7 +13,7 @@ import numpy as np import torch from timm.models import create_model, apply_test_time_pool -from timm.data import ImageDataset, create_loader, resolve_data_config +from timm.data import ImageDataset, create_loader_v2, resolve_data_config from timm.utils import AverageMeter, setup_default_logging torch.backends.cudnn.benchmark = True @@ -82,7 +82,7 @@ def main(): else: model = model.cuda() - loader = create_loader( + loader = create_loader_v2( ImageDataset(args.data), input_size=config['input_size'], batch_size=args.batch_size, diff --git a/timm/bits/device_env.py b/timm/bits/device_env.py index 0a926e69..e992ee80 100644 --- a/timm/bits/device_env.py +++ b/timm/bits/device_env.py @@ -128,6 +128,9 @@ class DeviceEnv: def mark_step(self): pass # NO-OP for non-XLA devices + def synchronize(self, tensors: Optional[TensorList] = None): + pass + def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False): dist.all_reduce(tensor, op=op) if average: diff --git a/timm/bits/device_env_cuda.py b/timm/bits/device_env_cuda.py index c57dfda5..33760d97 100644 --- a/timm/bits/device_env_cuda.py +++ b/timm/bits/device_env_cuda.py @@ -6,7 +6,7 @@ from typing import Optional import torch from torch.nn.parallel import DistributedDataParallel, DataParallel -from .device_env import DeviceEnv, DeviceEnvType +from .device_env import DeviceEnv, DeviceEnvType, TensorList def is_cuda_available(): @@ -63,3 +63,6 @@ class DeviceEnvCuda(DeviceEnv): assert not self.distributed wrapped = [DataParallel(m, **kwargs) for m in modules] return wrapped[0] if len(wrapped) == 1 else wrapped + + def synchronize(self, tensors: Optional[TensorList] = None): + torch.cuda.synchronize(self.device) diff --git a/timm/bits/device_env_xla.py b/timm/bits/device_env_xla.py index 46517f7a..2dad9273 100644 --- a/timm/bits/device_env_xla.py +++ b/timm/bits/device_env_xla.py @@ -8,9 +8,11 @@ from torch.distributed import ReduceOp try: import torch_xla.core.xla_model as xm + import torch_xla _HAS_XLA = True except ImportError as e: xm = None + torch_xla = None _HAS_XLA = False try: @@ -81,6 +83,9 @@ class DeviceEnvXla(DeviceEnv): def mark_step(self): xm.mark_step() + def synchronize(self, tensors: Optional[TensorList] = None): + torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True) + def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False): assert isinstance(tensor, torch.Tensor) # unlike in-place variant, lists/tuples not allowed op = _PT_TO_XM_OP[op] diff --git a/timm/bits/train_setup.py b/timm/bits/train_setup.py index 1480de63..5aca908f 100644 --- a/timm/bits/train_setup.py +++ b/timm/bits/train_setup.py @@ -89,7 +89,6 @@ def setup_model_and_optimizer( train_state = TrainState(model=model, updater=updater, model_ema=model_ema) if resume_path: - # FIXME this is not implemented yet, do a hack job before proper TrainState serialization? load_train_state( train_state, resume_path, @@ -141,11 +140,7 @@ def setup_model_and_optimizer_deepspeed( if resume_path: # FIXME deepspeed resumes differently - load_legacy_checkpoint( - train_state, - resume_path, - load_opt=resume_opt, - log_info=dev_env.primary) + assert False if dev_env.distributed: train_state = dataclasses.replace( diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7d3cb2b4..163bcea7 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -4,9 +4,9 @@ from .config import resolve_data_config from .constants import * from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset -from .loader import create_loader +from .loader import create_loader_v2, PreprocessCfg, AugCfg, MixupCfg from .mixup import Mixup, FastCollateMixup from .parsers import create_parser from .real_labels import RealLabelsImagenet -from .transforms import * -from .transforms_factory import create_transform \ No newline at end of file +from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy +from .transforms_factory import create_transform_v2, create_transform diff --git a/timm/data/auto_augment.py b/timm/data/auto_augment.py index 7cbd2dee..46c36531 100644 --- a/timm/data/auto_augment.py +++ b/timm/data/auto_augment.py @@ -41,6 +41,22 @@ _HPARAMS_DEFAULT = dict( _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +def _pil_interp(method): + def _convert(m): + if method == 'bicubic': + return Image.BICUBIC + elif method == 'lanczos': + return Image.LANCZOS + elif method == 'hamming': + return Image.HAMMING + else: + return Image.BILINEAR + if isinstance(method, (list, tuple)): + return [_convert(m) if isinstance(m, str) else m for m in method] + else: + return _convert(method) if isinstance(method, str) else method + + def _interpolation(kwargs): interpolation = kwargs.pop('resample', Image.BILINEAR) if isinstance(interpolation, (list, tuple)): @@ -325,7 +341,7 @@ class AugmentOp: self.hparams = hparams.copy() self.kwargs = dict( fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, - resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + resample=_pil_interp(hparams['interpolation']) if 'interpolation' in hparams else _RANDOM_INTERPOLATION, ) # If magnitude_std is > 0, we introduce some randomness diff --git a/timm/data/collate.py b/timm/data/collate.py index a1e37e1f..28f2af2a 100644 --- a/timm/data/collate.py +++ b/timm/data/collate.py @@ -30,7 +30,7 @@ def fast_collate(batch): elif isinstance(batch[0][0], torch.Tensor): targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) assert len(targets) == batch_size - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=batch[0][0].dtype) for i in range(batch_size): tensor[i].copy_(batch[i][0]) return tensor, targets diff --git a/timm/data/config.py b/timm/data/config.py index 06920d7d..f9ed7b6c 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -1,10 +1,53 @@ import logging +from dataclasses import dataclass +from typing import Tuple, Optional, Union + from .constants import * _logger = logging.getLogger(__name__) +@dataclass +class AugCfg: + scale_range: Tuple[float, float] = (0.08, 1.0) + ratio_range: Tuple[float, float] = (3 / 4, 4 / 3) + hflip_prob: float = 0.5 + vflip_prob: float = 0. + + color_jitter: float = 0.4 + auto_augment: Optional[str] = None + + re_prob: float = 0. + re_mode: str = 'const' + re_count: int = 1 + + num_aug_splits: int = 0 + + +@dataclass +class PreprocessCfg: + input_size: Tuple[int, int, int] = (3, 224, 224) + mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN + std: Tuple[float, ...] = IMAGENET_DEFAULT_STD + interpolation: str = 'bilinear' + crop_pct: float = 0.875 + aug: AugCfg = None + + +@dataclass +class MixupCfg: + prob: float = 1.0 + switch_prob: float = 0.5 + mixup_alpha: float = 1. + cutmix_alpha: float = 0. + cutmix_minmax: Optional[Tuple[float, float]] = None + mode: str = 'batch' + correct_lam: bool = True + label_smoothing: float = 0.1 + num_classes: int = 0 + + def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False): new_config = {} default_cfg = default_cfg diff --git a/timm/data/fetcher.py b/timm/data/fetcher.py index ec5afe8a..c833b596 100644 --- a/timm/data/fetcher.py +++ b/timm/data/fetcher.py @@ -2,7 +2,7 @@ import torch from .constants import * from .random_erasing import RandomErasing -from. mixup import FastCollateMixup +from .mixup import FastCollateMixup class FetcherXla: @@ -12,31 +12,55 @@ class FetcherXla: class Fetcher: - def __init__(self, - loader, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - device=None, - dtype=None, - re_prob=0., - re_mode='const', - re_count=1, - re_num_splits=0): + def __init__( + self, + loader, + device: torch.device, + dtype=torch.float32, + normalize=True, + normalize_shape=(1, 3, 1, 1), + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + re_prob=0., + re_mode='const', + re_count=1, + num_aug_splits=0, + use_mp_loader=False, + ): self.loader = loader self.device = torch.device(device) - self.dtype = dtype or torch.float32 - self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1) - self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1) + self.dtype = dtype + if normalize: + self.mean = torch.tensor( + [x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape) + self.std = torch.tensor( + [x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape) + else: + self.mean = None + self.std = None if re_prob > 0.: + # NOTE RandomErasing shouldn't be used here w/ XLA devices self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device) + probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits) else: self.random_erasing = None + self.use_mp_loader = use_mp_loader + if use_mp_loader: + # FIXME testing for TPU use + import torch_xla.distributed.parallel_loader as pl + self._loader = pl.MpDeviceLoader(loader, device) + else: + self._loader = loader + print('re', self.random_erasing, self.mean, self.std) def __iter__(self): - for sample, target in self.loader: - sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std) - target = target.to(device=self.device) + for sample, target in self._loader: + if not self.use_mp_loader: + sample = sample.to(device=self.device) + target = target.to(device=self.device) + sample = sample.to(dtype=self.dtype) + if self.mean is not None: + sample.sub_(self.mean).div_(self.std) if self.random_erasing is not None: sample = self.random_erasing(sample) yield sample, target diff --git a/timm/data/loader.py b/timm/data/loader.py index e8722b29..9d60cd59 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -6,74 +6,52 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Tuple, Optional, Union, Callable + import torch.utils.data from timm.bits import DeviceEnv - -from .fetcher import Fetcher -from .prefetcher_cuda import PrefetcherCuda from .collate import fast_collate -from .transforms_factory import create_transform -from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .config import PreprocessCfg, AugCfg, MixupCfg from .distributed_sampler import OrderedDistributedSampler +from .fetcher import Fetcher +from .mixup import FastCollateMixup +from .prefetcher_cuda import PrefetcherCuda -def create_loader( - dataset, - input_size, - batch_size, - is_training=False, - dev_env=None, - no_aug=False, - re_prob=0., - re_mode='const', - re_count=1, - re_split=False, - scale=None, - ratio=None, - hflip=0.5, - vflip=0., - color_jitter=0.4, - auto_augment=None, - num_aug_splits=0, - interpolation='bilinear', - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - num_workers=1, - crop_pct=None, - collate_fn=None, - pin_memory=False, - tf_preprocessing=False, - use_multi_epochs_loader=False, - persistent_workers=True, +def create_loader_v2( + dataset: torch.utils.data.Dataset, + batch_size: int, + is_training: bool = False, + dev_env: Optional[DeviceEnv] = None, + normalize=True, + pp_cfg: PreprocessCfg = PreprocessCfg(), + mix_cfg: MixupCfg = None, + num_workers: int = 1, + collate_fn: Optional[Callable] = None, + pin_memory: bool = False, + use_multi_epochs_loader: bool = False, + persistent_workers: bool = True, ): - re_num_splits = 0 - if re_split: - # apply RE to second half of batch if no aug split otherwise line up with aug split - re_num_splits = num_aug_splits or 2 - dataset.transform = create_transform( - input_size, - is_training=is_training, - use_fetcher=True, - no_aug=no_aug, - scale=scale, - ratio=ratio, - hflip=hflip, - vflip=vflip, - color_jitter=color_jitter, - auto_augment=auto_augment, - interpolation=interpolation, - mean=mean, - std=std, - crop_pct=crop_pct, - tf_preprocessing=tf_preprocessing, - re_prob=re_prob, - re_mode=re_mode, - re_count=re_count, - re_num_splits=re_num_splits, - separate=num_aug_splits > 0, - ) + """ + + Args: + dataset: + batch_size: + is_training: + dev_env: + normalize: + pp_cfg: + mix_cfg: + num_workers: + collate_fn: + pin_memory: + use_multi_epochs_loader: + persistent_workers: + + Returns: + """ if dev_env is None: dev_env = DeviceEnv.instance() @@ -85,10 +63,24 @@ def create_loader( else: # This will add extra duplicate entries to result in equal num # of samples per-process, will slightly alter validation results - sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank) + sampler = OrderedDistributedSampler( + dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank) if collate_fn is None: - collate_fn = fast_collate + if mix_cfg is not None and mix_cfg.prob > 0: + collate_fn = FastCollateMixup( + mixup_alpha=mix_cfg.mixup_alpha, + cutmix_alpha=mix_cfg.cutmix_alpha, + cutmix_minmax=mix_cfg.cutmix_minmax, + prob=mix_cfg.prob, + switch_prob=mix_cfg.switch_prob, + mode=mix_cfg.mode, + correct_lam=mix_cfg.correct_lam, + label_smoothing=mix_cfg.label_smoothing, + num_classes=mix_cfg.num_classes, + ) + else: + collate_fn = fast_collate loader_class = torch.utils.data.DataLoader if use_multi_epochs_loader: @@ -110,13 +102,18 @@ def create_loader( loader = loader_class(dataset, **loader_args) fetcher_kwargs = dict( - mean=mean, - std=std, - re_prob=re_prob if is_training and not no_aug else 0., - re_mode=re_mode, - re_count=re_count, - re_num_splits=re_num_splits + normalize=normalize, + mean=pp_cfg.mean, + std=pp_cfg.std, ) + if normalize and is_training and pp_cfg.aug is not None: + fetcher_kwargs.update(dict( + re_prob=pp_cfg.aug.re_prob, + re_mode=pp_cfg.aug.re_mode, + re_count=pp_cfg.aug.re_count, + num_aug_splits=pp_cfg.aug.num_aug_splits, + )) + if dev_env.type_cuda: loader = PrefetcherCuda(loader, **fetcher_kwargs) else: diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 38477548..b618bb7c 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -102,7 +102,7 @@ class Mixup: num_classes (int): number of classes for target """ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, - mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000): + mode='batch', correct_lam=True, label_smoothing=0., num_classes=0): self.mixup_alpha = mixup_alpha self.cutmix_alpha = cutmix_alpha self.cutmix_minmax = cutmix_minmax @@ -113,6 +113,8 @@ class Mixup: self.mix_prob = prob self.switch_prob = switch_prob self.label_smoothing = label_smoothing + if label_smoothing > 0.: + assert num_classes > 0 self.num_classes = num_classes self.mode = mode self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix @@ -218,17 +220,30 @@ class Mixup: return x, target +def blend(a, b, lam, is_tensor=False, round_output=True): + if is_tensor: + blend = a.to(dtype=torch.float32) * lam + b.to(dtype=torch.float32) * (1 - lam) + if round_output: + torch.round(blend, out=blend) + else: + blend = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam) + if round_output: + np.rint(blend, out=blend) + return blend + + class FastCollateMixup(Mixup): """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch A Mixup impl that's performed while collating the batches. """ - def _mix_elem_collate(self, output, batch, half=False): + def _mix_elem_collate(self, output, batch, half=False, is_tensor=False): batch_size = len(batch) num_elem = batch_size // 2 if half else batch_size assert len(output) == num_elem lam_batch, use_cutmix = self._params_per_elem(num_elem) + round_output = output.dtype == torch.uint8 for i in range(num_elem): j = batch_size - i - 1 lam = lam_batch[i] @@ -236,22 +251,23 @@ class FastCollateMixup(Mixup): if lam != 1.: if use_cutmix[i]: if not half: - mixed = mixed.copy() + mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] lam_batch[i] = lam else: - mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.rint(mixed, out=mixed) - output[i] += torch.from_numpy(mixed.astype(np.uint8)) + mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output) + mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8)) + output[i].copy_(mixed) if half: lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) return torch.tensor(lam_batch).unsqueeze(1) - def _mix_pair_collate(self, output, batch): + def _mix_pair_collate(self, output, batch, is_tensor=False): batch_size = len(batch) lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + round_output = output.dtype == torch.uint8 for i in range(batch_size // 2): j = batch_size - i - 1 lam = lam_batch[i] @@ -262,24 +278,30 @@ class FastCollateMixup(Mixup): if use_cutmix[i]: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - patch_i = mixed_i[:, yl:yh, xl:xh].copy() + patch_i = mixed_i[:, yl:yh, xl:xh] + patch_i = patch_i.clone() if is_tensor else patch_i.copy() # don't want to modify while iterating mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] mixed_j[:, yl:yh, xl:xh] = patch_i lam_batch[i] = lam else: - mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) - mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_temp = blend(mixed_i, mixed_j, lam, is_tensor, round_output) + mixed_j = blend(mixed_j, mixed_i, lam, is_tensor, round_output) mixed_i = mixed_temp - np.rint(mixed_j, out=mixed_j) - np.rint(mixed_i, out=mixed_i) - output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) - output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + if is_tensor: + mixed_i = mixed_i.to(dtype=output.dtype) + mixed_j = mixed_j.to(dtype=output.dtype) + else: + mixed_i = torch.from_numpy(mixed_i.astype(np.uint8)) + mixed_j = torch.from_numpy(mixed_j.astype(np.uint8)) + output[i].copy_(mixed_i) + output[j].copy_(mixed_j) lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) return torch.tensor(lam_batch).unsqueeze(1) - def _mix_batch_collate(self, output, batch): + def _mix_batch_collate(self, output, batch, is_tensor=False): batch_size = len(batch) lam, use_cutmix = self._params_per_batch() + round_output = output.dtype == torch.uint8 if use_cutmix: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) @@ -288,12 +310,12 @@ class FastCollateMixup(Mixup): mixed = batch[i][0] if lam != 1.: if use_cutmix: - mixed = mixed.copy() # don't want to modify the original while iterating + mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: - mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.rint(mixed, out=mixed) - output[i] += torch.from_numpy(mixed.astype(np.uint8)) + mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output) + mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8)) + output[i].copy_(mixed) return lam def __call__(self, batch, _=None): @@ -302,13 +324,15 @@ class FastCollateMixup(Mixup): half = 'half' in self.mode if half: batch_size //= 2 - output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + is_tensor = isinstance(batch[0][0], torch.Tensor) + output_dtype = batch[0][0].dtype if is_tensor else torch.uint8 # always uint8 if numpy src + output = torch.zeros((batch_size, *batch[0][0].shape), dtype=output_dtype) if self.mode == 'elem' or self.mode == 'half': - lam = self._mix_elem_collate(output, batch, half=half) + lam = self._mix_elem_collate(output, batch, half=half, is_tensor=is_tensor) elif self.mode == 'pair': - lam = self._mix_pair_collate(output, batch) + lam = self._mix_pair_collate(output, batch, is_tensor=is_tensor) else: - lam = self._mix_batch_collate(output, batch) + lam = self._mix_batch_collate(output, batch, is_tensor=is_tensor) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') target = target[:batch_size] diff --git a/timm/data/prefetcher_cuda.py b/timm/data/prefetcher_cuda.py index 4f1c4e10..9432df59 100644 --- a/timm/data/prefetcher_cuda.py +++ b/timm/data/prefetcher_cuda.py @@ -7,25 +7,34 @@ from .random_erasing import RandomErasing class PrefetcherCuda: - def __init__(self, - loader, - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD, - fp16=False, - re_prob=0., - re_mode='const', - re_count=1, - re_num_splits=0): + def __init__( + self, + loader, + device: torch.device = torch.device('cuda'), + dtype=torch.float32, + normalize=True, + normalize_shape=(1, 3, 1, 1), + mean=IMAGENET_DEFAULT_MEAN, + std=IMAGENET_DEFAULT_STD, + num_aug_splits=0, + re_prob=0., + re_mode='const', + re_count=1 + ): self.loader = loader - self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) - self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) - self.fp16 = fp16 - if fp16: - self.mean = self.mean.half() - self.std = self.std.half() + self.device = device + self.dtype = dtype + if normalize: + self.mean = torch.tensor( + [x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape) + self.std = torch.tensor( + [x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape) + else: + self.mean = None + self.std = None if re_prob > 0.: self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits) + probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits, device=device) else: self.random_erasing = None @@ -35,12 +44,11 @@ class PrefetcherCuda: for next_input, next_target in self.loader: with torch.cuda.stream(stream): - next_input = next_input.cuda(non_blocking=True) - next_target = next_target.cuda(non_blocking=True) - if self.fp16: - next_input = next_input.half().sub_(self.mean).div_(self.std) - else: - next_input = next_input.float().sub_(self.mean).div_(self.std) + next_input = next_input.to(device=self.device, non_blocking=True) + next_input = next_input.to(dtype=self.dtype) + if self.mean is not None: + next_input.sub_(self.mean).div_(self.std) + next_target = next_target.to(device=self.device, non_blocking=True) if self.random_erasing is not None: next_input = self.random_erasing(next_input) @@ -76,4 +84,4 @@ class PrefetcherCuda: @mixup_enabled.setter def mixup_enabled(self, x): if isinstance(self.loader.collate_fn, FastCollateMixup): - self.loader.collate_fn.mixup_enabled = x \ No newline at end of file + self.loader.collate_fn.mixup_enabled = x diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index 78967d10..65d085a9 100644 --- a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -38,21 +38,20 @@ class RandomErasing: 'const' - erase block is constant color of 0 for all channels 'rand' - erase block is same per-channel random (normal) color 'pixel' - erase block is per-pixel random (normal) color - max_count: maximum number of erasing blocks per image, area per box is scaled by count. + count: maximum number of erasing blocks per image, area per box is scaled by count. per-image count is randomly chosen between 1 and this value. """ def __init__( self, probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, - mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'): + mode='const', count=1, num_splits=0): self.probability = probability self.min_area = min_area self.max_area = max_area max_aspect = max_aspect or 1 / min_aspect self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) - self.min_count = min_count - self.max_count = max_count or min_count + self.count = count self.num_splits = num_splits mode = mode.lower() self.rand_color = False @@ -63,14 +62,13 @@ class RandomErasing: self.per_pixel = True # per pixel random normal else: assert not mode or mode == 'const' - self.device = device def _erase(self, img, chan, img_h, img_w, dtype): + device = img.device if random.random() > self.probability: return area = img_h * img_w - count = self.min_count if self.min_count == self.max_count else \ - random.randint(self.min_count, self.max_count) + count = random.randint(1, self.count) if self.count > 1 else self.count for _ in range(count): for attempt in range(10): target_area = random.uniform(self.min_area, self.max_area) * area / count @@ -81,17 +79,76 @@ class RandomErasing: top = random.randint(0, img_h - h) left = random.randint(0, img_w - w) img[:, top:top + h, left:left + w] = _get_pixels( - self.per_pixel, self.rand_color, (chan, h, w), - dtype=dtype, device=self.device) + self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=device) break - def __call__(self, input): - if len(input.size()) == 3: - self._erase(input, *input.size(), input.dtype) + def __call__(self, x): + if len(x.size()) == 3: + self._erase(x, *x.shape, x.dtype) else: - batch_size, chan, img_h, img_w = input.size() + batch_size, chan, img_h, img_w = x.shape # skip first slice of batch if num_splits is set (for clean portion of samples) batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 for i in range(batch_start, batch_size): - self._erase(input[i], chan, img_h, img_w, input.dtype) - return input + self._erase(x[i], chan, img_h, img_w, x.dtype) + return x + + +class RandomErasingMasked: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This variant of RandomErasing is intended to be applied to either a batch + or single image tensor after it has been normalized by dataset mean and std. + Args: + probability: Probability that the Random Erasing operation will be performed for each box (count) + min_area: Minimum percentage of erased area wrt input image area. + max_area: Maximum percentage of erased area wrt input image area. + min_aspect: Minimum aspect ratio of erased area. + count: maximum number of erasing blocks per image, area per box is scaled by count. + per-image count is between 0 and this value. + """ + + def __init__( + self, + probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None, + mode='const', count=1, num_splits=0): + self.probability = probability + self.min_area = min_area + self.max_area = max_area + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + self.mode = mode # FIXME currently ignored, add back options besides normal mean=0, std=1 noise? + self.count = count + self.num_splits = num_splits + + @torch.no_grad() + def __call__(self, x: torch.Tensor) -> torch.Tensor: + device = x.device + batch_size, _, img_h, img_w = x.shape + batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0 + + # NOTE simplified from v1 with with one count value and same prob applied for all + enable = (torch.empty((batch_size, self.count), device=device).uniform_() < self.probability).float() + enable = enable / enable.sum(dim=1, keepdim=True).clamp(min=1) + target_area = torch.empty( + (batch_size, self.count), device=device).uniform_(self.min_area, self.max_area) * enable + aspect_ratio = torch.empty((batch_size, self.count), device=device).uniform_(*self.log_aspect_ratio).exp() + h_coord = torch.arange(0, img_h, device=device).unsqueeze(-1).expand(-1, self.count).float() + w_coord = torch.arange(0, img_w, device=device).unsqueeze(-1).expand(-1, self.count).float() + h_mid = torch.rand((batch_size, self.count), device=device) * img_h + w_mid = torch.rand((batch_size, self.count), device=device) * img_w + noise = torch.empty_like(x[0]).normal_() + + for i in range(batch_start, batch_size): + h_half = (img_h / 2) * torch.sqrt(target_area[i] * aspect_ratio[i]) # 1/2 box h + h_mask = (h_coord > (h_mid[i] - h_half)) & (h_coord < (h_mid[i] + h_half)) + w_half = (img_w / 2) * torch.sqrt(target_area[i] / aspect_ratio[i]) # 1/2 box w + w_mask = (w_coord > (w_mid[i] - w_half)) & (w_coord < (w_mid[i] + w_half)) + #mask = (h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1) + #x[i].copy_(torch.where(mask, noise, x[i])) + mask = ~(h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1) + x[i] = x[i].where(mask, noise) + #x[i].masked_scatter_(mask, noise) + return x diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 4220304f..03f0e825 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -1,5 +1,7 @@ import torch import torchvision.transforms.functional as F +from torchvision.transforms import InterpolationMode + from PIL import Image import warnings import math @@ -30,29 +32,40 @@ class ToTensor: return torch.from_numpy(np_img).to(dtype=self.dtype) -_pil_interpolation_to_str = { - Image.NEAREST: 'PIL.Image.NEAREST', - Image.BILINEAR: 'PIL.Image.BILINEAR', - Image.BICUBIC: 'PIL.Image.BICUBIC', - Image.LANCZOS: 'PIL.Image.LANCZOS', - Image.HAMMING: 'PIL.Image.HAMMING', - Image.BOX: 'PIL.Image.BOX', -} +class ToTensorNormalize: + def __init__(self, mean, std, dtype=torch.float32, device=torch.device('cpu')): + self.dtype = dtype + mean = torch.as_tensor(mean, dtype=dtype, device=device) + std = torch.as_tensor(std, dtype=dtype, device=device) + if (std == 0).any(): + raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype)) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + self.mean = mean + self.std = std -def _pil_interp(method): - if method == 'bicubic': - return Image.BICUBIC - elif method == 'lanczos': - return Image.LANCZOS - elif method == 'hamming': - return Image.HAMMING - else: - # default bilinear, do we want to allow nearest? - return Image.BILINEAR + def __call__(self, pil_img): + mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} + img = torch.from_numpy( + np.array(pil_img, mode_to_nptype.get(pil_img.mode, np.uint8)) + ) + if pil_img.mode == '1': + img = 255 * img + img = img.view(pil_img.size[1], pil_img.size[0], len(pil_img.getbands())) + img = img.permute((2, 0, 1)) + if isinstance(img, torch.ByteTensor): + img = img.to(self.dtype) + img.sub_(self.mean * 255.).div_(self.std * 255.) + else: + img = img.to(self.dtype) + img.sub_(self.mean).div_(self.std) + return img -_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) +_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC) class RandomResizedCropAndInterpolation: @@ -82,7 +95,7 @@ class RandomResizedCropAndInterpolation: if interpolation == 'random': self.interpolation = _RANDOM_INTERPOLATION else: - self.interpolation = _pil_interp(interpolation) + self.interpolation = InterpolationMode(interpolation) self.scale = scale self.ratio = ratio @@ -146,9 +159,9 @@ class RandomResizedCropAndInterpolation: def __repr__(self): if isinstance(self.interpolation, (tuple, list)): - interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation]) + interpolate_str = ' '.join([x.value for x in self.interpolation]) else: - interpolate_str = _pil_interpolation_to_str[self.interpolation] + interpolate_str = self.interpolation.value format_string = self.__class__.__name__ + '(size={0}'.format(self.size) format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 16e08a39..1c8d15e2 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -4,59 +4,50 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M Hacked together by / Copyright 2020 Ross Wightman """ import math +from typing import Union, Tuple import torch from torchvision import transforms -from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform -from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor +from timm.data.config import PreprocessCfg, AugCfg +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT from timm.data.random_erasing import RandomErasing +from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensorNormalize def transforms_noaug_train( - img_size=224, + img_size: Union[int, Tuple[int]] = 224, interpolation='bilinear', - use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + normalize=False, ): if interpolation == 'random': # random interpolation not supported with no-aug interpolation = 'bilinear' tfl = [ - transforms.Resize(img_size, _pil_interp(interpolation)), + transforms.Resize(img_size, transforms.InterpolationMode(interpolation)), transforms.CenterCrop(img_size) ] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: + if normalize: tfl += [ transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) + transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std)) ] + else: + # (pre)fetcher and collate will handle tensor conversion and normalize + tfl += [ToNumpy()] return transforms.Compose(tfl) def transforms_imagenet_train( - img_size=224, - scale=None, - ratio=None, - hflip=0.5, - vflip=0., - color_jitter=0.4, - auto_augment=None, + img_size: Union[int, Tuple[int]] = 224, interpolation='random', - use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - re_prob=0., - re_mode='const', - re_count=1, - re_num_splits=0, + aug_cfg=AugCfg(), + normalize=False, separate=False, ): """ @@ -66,18 +57,24 @@ def transforms_imagenet_train( * a portion of the data through the secondary transform * normalizes and converts the branches above with the third, final transform """ - scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range - ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range + scale_range = tuple(aug_cfg.scale_range or (0.08, 1.0)) # default imagenet scale range + ratio_range = tuple(aug_cfg.ratio_range or (3. / 4., 4. / 3.)) # default imagenet ratio range + + # 'primary' train transforms include random resize + crop w/ optional horizontal and vertical flipping aug. + # This is the core of standard ImageNet ResNet and Inception pre-processing primary_tfl = [ - RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)] - if hflip > 0.: - primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)] - if vflip > 0.: - primary_tfl += [transforms.RandomVerticalFlip(p=vflip)] + RandomResizedCropAndInterpolation(img_size, scale=scale_range, ratio=ratio_range, interpolation=interpolation)] + if aug_cfg.hflip_prob > 0.: + primary_tfl += [transforms.RandomHorizontalFlip(p=aug_cfg.hflip_prob)] + if aug_cfg.vflip_prob > 0.: + primary_tfl += [transforms.RandomVerticalFlip(p=aug_cfg.vflip_prob)] + # 'secondary' transform stage includes either color jitter (could add lighting too) or auto-augmentations + # such as AutoAugment, RandAugment, AugMix, etc secondary_tfl = [] - if auto_augment: - assert isinstance(auto_augment, str) + if aug_cfg.auto_augment: + aa = aug_cfg.auto_augment + assert isinstance(aa, str) if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: @@ -87,58 +84,63 @@ def transforms_imagenet_train( img_mean=tuple([min(255, round(255 * x)) for x in mean]), ) if interpolation and interpolation != 'random': - aa_params['interpolation'] = _pil_interp(interpolation) - if auto_augment.startswith('rand'): - secondary_tfl += [rand_augment_transform(auto_augment, aa_params)] - elif auto_augment.startswith('augmix'): + aa_params['interpolation'] = interpolation + if aa.startswith('rand'): + secondary_tfl += [rand_augment_transform(aa, aa_params)] + elif aa.startswith('augmix'): aa_params['translate_pct'] = 0.3 - secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)] + secondary_tfl += [augment_and_mix_transform(aa, aa_params)] else: - secondary_tfl += [auto_augment_transform(auto_augment, aa_params)] - elif color_jitter is not None: + secondary_tfl += [auto_augment_transform(aa, aa_params)] + elif aug_cfg.color_jitter is not None: # color jitter is enabled when not using AA - if isinstance(color_jitter, (list, tuple)): + cj = aug_cfg.color_jitter + if isinstance(cj, (list, tuple)): # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation # or 4 if also augmenting hue - assert len(color_jitter) in (3, 4) + assert len(cj) in (3, 4) else: # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue - color_jitter = (float(color_jitter),) * 3 - secondary_tfl += [transforms.ColorJitter(*color_jitter)] + cj = (float(cj),) * 3 + secondary_tfl += [transforms.ColorJitter(*cj)] + # 'final' transform stage includes normalization, followed by optional random erasing and tensor conversion final_tfl = [] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - final_tfl += [ToNumpy()] - else: + if normalize: final_tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) + ToTensorNormalize(mean=mean, std=std) ] - if re_prob > 0.: - final_tfl.append( - RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu')) + if aug_cfg.re_prob > 0.: + final_tfl.append(RandomErasing( + aug_cfg.re_prob, + mode=aug_cfg.re_mode, + count=aug_cfg.re_count, + num_splits=aug_cfg.num_aug_splits)) + else: + # when normalize disabled, (pre)fetcher and collate will handle tensor conversion and normalize + final_tfl += [ToNumpy()] if separate: + # return each transform stage separately return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl) else: return transforms.Compose(primary_tfl + secondary_tfl + final_tfl) def transforms_imagenet_eval( - img_size=224, + img_size: Union[int, Tuple[int]] = 224, crop_pct=None, interpolation='bilinear', - use_prefetcher=False, mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD): + std=IMAGENET_DEFAULT_STD, + normalize=False, +): crop_pct = crop_pct or DEFAULT_CROP_PCT if isinstance(img_size, (tuple, list)): assert len(img_size) == 2 if img_size[-1] == img_size[-2]: + # FIXME handle case where img is square and we want non aspect preserving resize # fall-back to older behaviour so Resize scales to shortest edge if target is square scale_size = int(math.floor(img_size[0] / crop_pct)) else: @@ -147,27 +149,87 @@ def transforms_imagenet_eval( scale_size = int(math.floor(img_size / crop_pct)) tfl = [ - transforms.Resize(scale_size, _pil_interp(interpolation)), + transforms.Resize(scale_size, transforms.InterpolationMode(interpolation)), transforms.CenterCrop(img_size), ] - if use_prefetcher: - # prefetcher and collate will handle tensor conversion and norm - tfl += [ToNumpy()] - else: + if normalize: tfl += [ - transforms.ToTensor(), - transforms.Normalize( - mean=torch.tensor(mean), - std=torch.tensor(std)) + ToTensorNormalize(mean=mean, std=std) ] + else: + # (pre)fetcher and collate will handle tensor conversion and normalize + tfl += [ToNumpy()] return transforms.Compose(tfl) +def create_transform_v2( + cfg=PreprocessCfg(), + is_training=False, + normalize=False, + separate=False, + tf_preprocessing=False, +): + """ + + Args: + cfg: Pre-processing configuration + is_training (bool): Create transform for training pre-processing + tf_preprocessing (bool): Use Tensorflow pre-processing (for validation) + normalize (bool): Enable normalization in transforms (otherwise handled by fetcher/pre-fetcher) + separate (bool): Return transforms separated into stages (for train) + + Returns: + + """ + input_size = cfg.input_size + if isinstance(input_size, (tuple, list)): + img_size = input_size[-2:] + else: + img_size = input_size + + if tf_preprocessing: + assert not normalize, "Expecting normalization to be handled in (pre)fetcher w/ TF preprocessing" + assert not separate, "Separate transforms not supported for TF preprocessing" + from timm.data.tf_preprocessing import TfPreprocessTransform + transform = TfPreprocessTransform( + is_training=is_training, size=img_size, interpolation=cfg.interpolation) + else: + if is_training and cfg.aug is None: + assert not separate, "Cannot perform split augmentation with no_aug" + transform = transforms_noaug_train( + img_size, + interpolation=cfg.interpolation, + normalize=normalize, + mean=cfg.mean, + std=cfg.std) + elif is_training: + transform = transforms_imagenet_train( + img_size, + interpolation=cfg.interpolation, + mean=cfg.mean, + std=cfg.std, + aug_cfg=cfg.aug, + normalize=normalize, + separate=separate) + else: + assert not separate, "Separate transforms not supported for validation preprocessing" + transform = transforms_imagenet_eval( + img_size, + interpolation=cfg.interpolation, + crop_pct=cfg.crop_pct, + mean=cfg.mean, + std=cfg.std, + normalize=normalize, + ) + + return transform + + def create_transform( input_size, is_training=False, - use_fetcher=False, + use_prefetcher=False, no_aug=False, scale=None, ratio=None, @@ -191,7 +253,8 @@ def create_transform( else: img_size = input_size - if tf_preprocessing and use_fetcher: + normalize_in_transform = not use_prefetcher + if tf_preprocessing and use_prefetcher: assert not separate, "Separate transforms not supported for TF preprocessing" from timm.data.tf_preprocessing import TfPreprocessTransform transform = TfPreprocessTransform( @@ -202,35 +265,41 @@ def create_transform( transform = transforms_noaug_train( img_size, interpolation=interpolation, - use_prefetcher=use_fetcher, mean=mean, - std=std) + std=std, + normalize=normalize_in_transform, + ) elif is_training: - transform = transforms_imagenet_train( - img_size, - scale=scale, - ratio=ratio, - hflip=hflip, - vflip=vflip, + aug_cfg = AugCfg( + scale_range=scale, + ratio_range=ratio, + hflip_prob=hflip, + vflip_prob=vflip, color_jitter=color_jitter, auto_augment=auto_augment, - interpolation=interpolation, - use_prefetcher=use_fetcher, - mean=mean, - std=std, re_prob=re_prob, re_mode=re_mode, re_count=re_count, - re_num_splits=re_num_splits, - separate=separate) + num_aug_splits=re_num_splits, + ) + transform = transforms_imagenet_train( + img_size, + interpolation=interpolation, + mean=mean, + std=std, + aug_cfg=aug_cfg, + normalize=normalize_in_transform, + separate=separate + ) else: - assert not separate, "Separate transforms not supported for validation preprocessing" + assert not separate, "Separate transforms not supported for validation pre-processing" transform = transforms_imagenet_eval( img_size, interpolation=interpolation, - use_prefetcher=use_fetcher, mean=mean, std=std, - crop_pct=crop_pct) + crop_pct=crop_pct, + normalize=normalize_in_transform, + ) return transform diff --git a/timm/models/helpers.py b/timm/models/helpers.py index adfef550..39f44c87 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__) def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') - state_dict_key = 'state_dict' + state_dict_key = '' if isinstance(checkpoint, dict): if use_ema and 'state_dict_ema' in checkpoint: state_dict_key = 'state_dict_ema' - if state_dict_key and state_dict_key in checkpoint: + elif use_ema and 'model_ema' in checkpoint: + state_dict_key = 'model_ema' + elif 'state_dict' in checkpoint: + state_dict_key = 'state_dict' + elif 'model' in checkpoint: + state_dict_key = 'model' + if state_dict_key: + state_dict = checkpoint[state_dict_key] new_state_dict = OrderedDict() - for k, v in checkpoint[state_dict_key].items(): + for k, v in state_dict.items(): # strip `module.` prefix name = k[7:] if k.startswith('module') else k new_state_dict[name] = v diff --git a/train.py b/train.py index cca814fd..1e95c831 100755 --- a/train.py +++ b/train.py @@ -30,7 +30,8 @@ import torchvision.utils from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\ TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn -from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\ + PreprocessCfg, AugCfg, MixupCfg, AugMixDataset from timm.models import create_model, safe_model_name, convert_splitbn_model from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import optimizer_kwargs @@ -283,10 +284,11 @@ def main(): else: _logger.info('Training with a single process on 1 device.') - mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None - random_seed(args.seed, 0) # Set all random seeds the same for model/state init (mandatory for XLA) + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense' + train_state = setup_train_task(args, dev_env, mixup_active) train_cfg = train_state.train_cfg @@ -421,11 +423,9 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool): _logger.info( f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') - # setup augmentation batch splits for contrastive loss or split bn - assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense' # enable split bn (separate bn stats per batch-portion) if args.split_bn: - assert args.aug_splits > 1 or args.resplit + assert args.aug_splits > 1 model = convert_splitbn_model(model, max(args.aug_splits, 2)) train_state = setup_model_and_optimizer( @@ -481,7 +481,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool): return train_state -def setup_data(args, default_cfg, dev_env, mixup_active): +def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool): data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary) # create the train and eval datasets @@ -489,18 +489,18 @@ def setup_data(args, default_cfg, dev_env, mixup_active): args.dataset, root=args.data_dir, split=args.train_split, is_training=True, batch_size=args.batch_size, repeats=args.epoch_repeats) + dataset_eval = create_dataset( - args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) + args.dataset, + root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size) # setup mixup / cutmix - collate_fn = None + mixup_cfg = None if mixup_active: - mixup_args = dict( - mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + mixup_cfg = MixupCfg( prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode, + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, label_smoothing=args.smoothing, num_classes=args.num_classes) - assert not args.aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup(**mixup_args) # wrap dataset in AugMix helper if args.aug_splits > 1: @@ -510,46 +510,72 @@ def setup_data(args, default_cfg, dev_env, mixup_active): train_interpolation = args.train_interpolation if args.no_aug or not train_interpolation: train_interpolation = data_config['interpolation'] - loader_train = create_loader( - dataset_train, + + if args.no_aug: + train_aug_cfg = None + else: + train_aug_cfg = AugCfg( + re_prob=args.reprob, + re_mode=args.remode, + re_count=args.recount, + ratio_range=args.ratio, + scale_range=args.scale, + hflip_prob=args.hflip, + vflip_prob=args.vflip, + color_jitter=args.color_jitter, + auto_augment=args.aa, + num_aug_splits=args.aug_splits, + ) + + train_pp_cfg = PreprocessCfg( input_size=data_config['input_size'], - batch_size=args.batch_size, - is_training=True, - no_aug=args.no_aug, - re_prob=args.reprob, - re_mode=args.remode, - re_count=args.recount, - re_split=args.resplit, - scale=args.scale, - ratio=args.ratio, - hflip=args.hflip, - vflip=args.vflip, - color_jitter=args.color_jitter, - auto_augment=args.aa, - num_aug_splits=args.aug_splits, interpolation=train_interpolation, + crop_pct=data_config['crop_pct'], mean=data_config['mean'], std=data_config['std'], + aug=train_aug_cfg, + ) + + # if using PyTorch XLA and RandomErasing is enabled, we must normalize and do RE in transforms on CPU + normalize_in_transform = dev_env.type_xla and args.reprob > 0 + + dataset_train.transform = create_transform_v2( + cfg=train_pp_cfg, is_training=True, normalize=normalize_in_transform) + + loader_train = create_loader_v2( + dataset_train, + batch_size=args.batch_size, + is_training=True, + normalize=not normalize_in_transform, + pp_cfg=train_pp_cfg, + mix_cfg=mixup_cfg, num_workers=args.workers, - collate_fn=collate_fn, pin_memory=args.pin_mem, use_multi_epochs_loader=args.use_multi_epochs_loader ) + eval_pp_cfg = PreprocessCfg( + input_size=data_config['input_size'], + interpolation=data_config['interpolation'], + crop_pct=data_config['crop_pct'], + mean=data_config['mean'], + std=data_config['std'], + ) + + dataset_eval.transform = create_transform_v2( + cfg=eval_pp_cfg, is_training=False, normalize=normalize_in_transform) + eval_workers = args.workers if 'tfds' in args.dataset: # FIXME reduce validation issues when using TFDS w/ workers and distributed training eval_workers = min(2, args.workers) - loader_eval = create_loader( + loader_eval = create_loader_v2( dataset_eval, - input_size=data_config['input_size'], batch_size=args.validation_batch_size_multiplier * args.batch_size, is_training=False, - interpolation=data_config['interpolation'], - mean=data_config['mean'], - std=data_config['std'], + normalize=not normalize_in_transform, + pp_cfg=eval_pp_cfg, num_workers=eval_workers, - crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, ) return data_config, loader_eval, loader_train @@ -700,8 +726,12 @@ def evaluate( loss = loss_fn(output, target) # FIXME, explictly marking step for XLA use since I'm not using the parallel xm loader - # need to investigate whether parallel loader wrapper is helpful on tpu-vm or only usefor for 2-vm setup. - dev_env.mark_step() + # need to investigate whether parallel loader wrapper is helpful on tpu-vm or only use for 2-vm setup. + if dev_env.type_xla: + dev_env.mark_step() + elif dev_env.type_cuda: + dev_env.synchronize() + tracker.mark_iter_step_end() losses_m.update(loss, output.size(0)) accuracy_m.update(output, target) diff --git a/validate.py b/validate.py index cee359c3..f4dc84e8 100755 --- a/validate.py +++ b/validate.py @@ -20,7 +20,8 @@ from collections import OrderedDict from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models -from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet +from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config, RealLabelsImagenet, \ + PreprocessCfg from timm.utils import natural_key, setup_default_logging @@ -141,18 +142,22 @@ def validate(args): else: real_labels = None - crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] - loader = create_loader( - dataset, + eval_pp_cfg = PreprocessCfg( input_size=data_config['input_size'], - batch_size=args.batch_size, interpolation=data_config['interpolation'], + crop_pct=1.0 if test_time_pool else data_config['crop_pct'], mean=data_config['mean'], std=data_config['std'], + ) + + dataset.transform = create_transform_v2(cfg=eval_pp_cfg, is_training=False) + + loader = create_loader_v2( + dataset, + batch_size=args.batch_size, + pp_cfg=eval_pp_cfg, num_workers=args.workers, - crop_pct=crop_pct, - pin_memory=args.pin_mem, - tf_preprocessing=args.tf_preprocessing) + pin_memory=args.pin_mem) logger = Monitor(logger=_logger) tracker = Tracker() @@ -175,16 +180,17 @@ def validate(args): loss = criterion(output, target) if dev_env.type_cuda: - torch.cuda.synchronize() + dev_env.synchronize() tracker.mark_iter_step_end() - losses.update(loss.detach(), sample.size(0)) + if dev_env.type_xla: + dev_env.mark_step() + if real_labels is not None: real_labels.add_result(output) - accuracy.update(output.detach(), target) - if dev_env.type_xla: - dev_env.mark_step() + losses.update(loss.detach(), sample.size(0)) + accuracy.update(output.detach(), target) tracker.mark_iter() if step_idx % args.log_freq == 0: @@ -212,7 +218,7 @@ def validate(args): top5=round(top5a, 4), top5_err=round(100 - top5a, 4), param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], - cropt_pct=crop_pct, + cropt_pct=eval_pp_cfg.crop_pct, interpolation=data_config['interpolation']) logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results)