Transforms and augmentation work for bits: add RandomErasing support for XLA (pushing it into the transforms), revamp the transform/preprocessing config, etc. Ongoing...

pull/1239/head
Ross Wightman 3 years ago
parent 847b4af144
commit 40457e5691

@@ -14,6 +14,9 @@ import hashlib
 import shutil
 from collections import OrderedDict
 
+from timm.models.helpers import load_state_dict
+from timm.utils import setup_default_logging
+
 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -29,6 +32,7 @@ _TEMP_NAME = './_checkpoint.pth'
 
 def main():
     args = parser.parse_args()
+    setup_default_logging()
 
     if os.path.exists(args.output):
         print("Error: Output filename ({}) already exists.".format(args.output))
@@ -37,17 +41,8 @@ def main():
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if args.checkpoint and os.path.isfile(args.checkpoint):
         print("=> Loading checkpoint '{}'".format(args.checkpoint))
-        checkpoint = torch.load(args.checkpoint, map_location='cpu')
+        state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
         new_state_dict = OrderedDict()
-        if isinstance(checkpoint, dict):
-            state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
-            if state_dict_key in checkpoint:
-                state_dict = checkpoint[state_dict_key]
-            else:
-                state_dict = checkpoint
-        else:
-            assert False
-
         for k, v in state_dict.items():
             if args.clean_aux_bn and 'aux_bn' in k:
                 # If all aux_bn keys are removed, the SplitBN layers will end up as normal and

@@ -13,7 +13,7 @@ import numpy as np
 import torch
 
 from timm.models import create_model, apply_test_time_pool
-from timm.data import ImageDataset, create_loader, resolve_data_config
+from timm.data import ImageDataset, create_loader_v2, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging
 
 torch.backends.cudnn.benchmark = True
@@ -82,7 +82,7 @@ def main():
     else:
         model = model.cuda()
 
-    loader = create_loader(
+    loader = create_loader_v2(
         ImageDataset(args.data),
         input_size=config['input_size'],
         batch_size=args.batch_size,

@@ -128,6 +128,9 @@ class DeviceEnv:
     def mark_step(self):
         pass  # NO-OP for non-XLA devices
 
+    def synchronize(self, tensors: Optional[TensorList] = None):
+        pass
+
     def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
         dist.all_reduce(tensor, op=op)
         if average:

@@ -6,7 +6,7 @@ from typing import Optional
 import torch
 from torch.nn.parallel import DistributedDataParallel, DataParallel
 
-from .device_env import DeviceEnv, DeviceEnvType
+from .device_env import DeviceEnv, DeviceEnvType, TensorList
 
 
 def is_cuda_available():
@@ -63,3 +63,6 @@ class DeviceEnvCuda(DeviceEnv):
             assert not self.distributed
             wrapped = [DataParallel(m, **kwargs) for m in modules]
         return wrapped[0] if len(wrapped) == 1 else wrapped
+
+    def synchronize(self, tensors: Optional[TensorList] = None):
+        torch.cuda.synchronize(self.device)

@@ -8,9 +8,11 @@ from torch.distributed import ReduceOp
 
 try:
     import torch_xla.core.xla_model as xm
+    import torch_xla
     _HAS_XLA = True
 except ImportError as e:
     xm = None
+    torch_xla = None
     _HAS_XLA = False
 
 try:
@@ -81,6 +83,9 @@ class DeviceEnvXla(DeviceEnv):
     def mark_step(self):
         xm.mark_step()
 
+    def synchronize(self, tensors: Optional[TensorList] = None):
+        torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)
+
     def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
         assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
         op = _PT_TO_XM_OP[op]

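The new synchronize() hook gives device-agnostic code a single call that blocks until pending device work completes: a no-op on the base DeviceEnv, torch.cuda.synchronize() on CUDA, and an XLA sync (via the private torch_xla._XLAC API) on TPU. A minimal sketch of the kind of timing code this enables; the work being timed is hypothetical:

import time

from timm.bits import DeviceEnv

dev_env = DeviceEnv.instance()
start = time.perf_counter()
# ... launch some device work here (e.g. a forward/backward step) ...
dev_env.synchronize()  # ensure the work is actually finished before reading the clock
elapsed = time.perf_counter() - start
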
@@ -89,7 +89,6 @@ def setup_model_and_optimizer(
     train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
 
     if resume_path:
-        # FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
         load_train_state(
             train_state,
             resume_path,
@@ -141,11 +140,7 @@ def setup_model_and_optimizer_deepspeed(
 
     if resume_path:
         # FIXME deepspeed resumes differently
-        load_legacy_checkpoint(
-            train_state,
-            resume_path,
-            load_opt=resume_opt,
-            log_info=dev_env.primary)
+        assert False
 
     if dev_env.distributed:
         train_state = dataclasses.replace(

@@ -4,9 +4,9 @@ from .config import resolve_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
-from .loader import create_loader
+from .loader import create_loader_v2, PreprocessCfg, AugCfg, MixupCfg
 from .mixup import Mixup, FastCollateMixup
 from .parsers import create_parser
 from .real_labels import RealLabelsImagenet
-from .transforms import *
-from .transforms_factory import create_transform
+from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy
+from .transforms_factory import create_transform_v2, create_transform

@@ -41,6 +41,22 @@ _HPARAMS_DEFAULT = dict(
 _RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
 
 
+def _pil_interp(method):
+    def _convert(m):
+        if m == 'bicubic':
+            return Image.BICUBIC
+        elif m == 'lanczos':
+            return Image.LANCZOS
+        elif m == 'hamming':
+            return Image.HAMMING
+        else:
+            return Image.BILINEAR
+    if isinstance(method, (list, tuple)):
+        return [_convert(m) if isinstance(m, str) else m for m in method]
+    else:
+        return _convert(method) if isinstance(method, str) else method
+
+
 def _interpolation(kwargs):
     interpolation = kwargs.pop('resample', Image.BILINEAR)
     if isinstance(interpolation, (list, tuple)):
@@ -325,7 +341,7 @@ class AugmentOp:
         self.hparams = hparams.copy()
         self.kwargs = dict(
             fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
-            resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
+            resample=_pil_interp(hparams['interpolation']) if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
         )
 
         # If magnitude_std is > 0, we introduce some randomness

@@ -30,7 +30,7 @@ def fast_collate(batch):
     elif isinstance(batch[0][0], torch.Tensor):
         targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         assert len(targets) == batch_size
-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=batch[0][0].dtype)
         for i in range(batch_size):
             tensor[i].copy_(batch[i][0])
         return tensor, targets

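The collate change above matters once upstream transforms can emit float tensors (see ToTensorNormalize later in this commit): the batch tensor now inherits the sample dtype instead of being forced to uint8. A quick sketch with synthetic data:

import torch
from timm.data.collate import fast_collate

# synthetic (image, label) samples; float32 now survives collation intact
batch = [(torch.rand(3, 224, 224), 0), (torch.rand(3, 224, 224), 1)]
images, targets = fast_collate(batch)
assert images.dtype == torch.float32  # previously always torch.uint8
assert images.shape == (2, 3, 224, 224) and targets.tolist() == [0, 1]
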
@@ -1,10 +1,53 @@
 import logging
+from dataclasses import dataclass
+from typing import Tuple, Optional, Union
 
 from .constants import *
 
 
 _logger = logging.getLogger(__name__)
 
 
+@dataclass
+class AugCfg:
+    scale_range: Tuple[float, float] = (0.08, 1.0)
+    ratio_range: Tuple[float, float] = (3 / 4, 4 / 3)
+    hflip_prob: float = 0.5
+    vflip_prob: float = 0.
+
+    color_jitter: float = 0.4
+    auto_augment: Optional[str] = None
+
+    re_prob: float = 0.
+    re_mode: str = 'const'
+    re_count: int = 1
+
+    num_aug_splits: int = 0
+
+
+@dataclass
+class PreprocessCfg:
+    input_size: Tuple[int, int, int] = (3, 224, 224)
+    mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN
+    std: Tuple[float, ...] = IMAGENET_DEFAULT_STD
+    interpolation: str = 'bilinear'
+    crop_pct: float = 0.875
+    aug: AugCfg = None
+
+
+@dataclass
+class MixupCfg:
+    prob: float = 1.0
+    switch_prob: float = 0.5
+    mixup_alpha: float = 1.
+    cutmix_alpha: float = 0.
+    cutmix_minmax: Optional[Tuple[float, float]] = None
+    mode: str = 'batch'
+    correct_lam: bool = True
+    label_smoothing: float = 0.1
+    num_classes: int = 0
+
+
 def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
     new_config = {}
     default_cfg = default_cfg

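A sketch of how the new config dataclasses compose; the values are arbitrary, and only fields defined above are used:

from timm.data.config import PreprocessCfg, AugCfg, MixupCfg

aug_cfg = AugCfg(
    hflip_prob=0.5,
    color_jitter=0.4,
    auto_augment='rand-m9-mstd0.5',  # RandAugment spec string
    re_prob=0.25,                    # RandomErasing probability
)
pp_cfg = PreprocessCfg(
    input_size=(3, 224, 224),
    interpolation='bicubic',
    aug=aug_cfg,                     # aug=None means no-aug pre-processing
)
mix_cfg = MixupCfg(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)
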
@@ -2,7 +2,7 @@ import torch
 
 from .constants import *
 from .random_erasing import RandomErasing
-from. mixup import FastCollateMixup
+from .mixup import FastCollateMixup
 
 
 class FetcherXla:
@@ -12,31 +12,55 @@ class FetcherXla:
 
 class Fetcher:
 
-    def __init__(self,
-                 loader,
-                 mean=IMAGENET_DEFAULT_MEAN,
-                 std=IMAGENET_DEFAULT_STD,
-                 device=None,
-                 dtype=None,
-                 re_prob=0.,
-                 re_mode='const',
-                 re_count=1,
-                 re_num_splits=0):
+    def __init__(
+            self,
+            loader,
+            device: torch.device,
+            dtype=torch.float32,
+            normalize=True,
+            normalize_shape=(1, 3, 1, 1),
+            mean=IMAGENET_DEFAULT_MEAN,
+            std=IMAGENET_DEFAULT_STD,
+            re_prob=0.,
+            re_mode='const',
+            re_count=1,
+            num_aug_splits=0,
+            use_mp_loader=False,
+    ):
         self.loader = loader
         self.device = torch.device(device)
-        self.dtype = dtype or torch.float32
-        self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
-        self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
+        self.dtype = dtype
+        if normalize:
+            self.mean = torch.tensor(
+                [x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
+            self.std = torch.tensor(
+                [x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
+        else:
+            self.mean = None
+            self.std = None
         if re_prob > 0.:
+            # NOTE RandomErasing shouldn't be used here w/ XLA devices
            self.random_erasing = RandomErasing(
-                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device)
+                probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
         else:
             self.random_erasing = None
+        self.use_mp_loader = use_mp_loader
+        if use_mp_loader:
+            # FIXME testing for TPU use
+            import torch_xla.distributed.parallel_loader as pl
+            self._loader = pl.MpDeviceLoader(loader, device)
+        else:
+            self._loader = loader
+        print('re', self.random_erasing, self.mean, self.std)
 
     def __iter__(self):
-        for sample, target in self.loader:
-            sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std)
-            target = target.to(device=self.device)
+        for sample, target in self._loader:
+            if not self.use_mp_loader:
+                sample = sample.to(device=self.device)
+                target = target.to(device=self.device)
+            sample = sample.to(dtype=self.dtype)
+            if self.mean is not None:
+                sample.sub_(self.mean).div_(self.std)
             if self.random_erasing is not None:
                 sample = self.random_erasing(sample)
             yield sample, target

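Note the x * 255 scaling of mean and std in the (pre)fetchers: samples arrive as uint8 in [0, 255] (via ToNumpy and fast_collate), so folding the divide-by-255 into mean and std makes the in-place sub_/div_ equivalent to ToTensor() followed by Normalize(). A sketch of the math, assuming the usual ImageNet constants:

import torch

IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

x = torch.randint(0, 256, (8, 3, 224, 224), dtype=torch.uint8)  # batch as collated
mean = torch.tensor([m * 255 for m in IMAGENET_DEFAULT_MEAN]).view(1, 3, 1, 1)
std = torch.tensor([s * 255 for s in IMAGENET_DEFAULT_STD]).view(1, 3, 1, 1)
x = x.float().sub_(mean).div_(std)  # (255v - 255m) / (255s) == (v - m) / s
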
@@ -6,74 +6,52 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Tuple, Optional, Union, Callable
 
 import torch.utils.data
 
 from timm.bits import DeviceEnv
 
-from .fetcher import Fetcher
-from .prefetcher_cuda import PrefetcherCuda
 from .collate import fast_collate
-from .transforms_factory import create_transform
-from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .config import PreprocessCfg, AugCfg, MixupCfg
 from .distributed_sampler import OrderedDistributedSampler
+from .fetcher import Fetcher
+from .mixup import FastCollateMixup
+from .prefetcher_cuda import PrefetcherCuda
 
 
-def create_loader(
-        dataset,
-        input_size,
-        batch_size,
-        is_training=False,
-        dev_env=None,
-        no_aug=False,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_split=False,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
-        num_aug_splits=0,
-        interpolation='bilinear',
-        mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD,
-        num_workers=1,
-        crop_pct=None,
-        collate_fn=None,
-        pin_memory=False,
-        tf_preprocessing=False,
-        use_multi_epochs_loader=False,
-        persistent_workers=True,
+def create_loader_v2(
+        dataset: torch.utils.data.Dataset,
+        batch_size: int,
+        is_training: bool = False,
+        dev_env: Optional[DeviceEnv] = None,
+        normalize=True,
+        pp_cfg: PreprocessCfg = PreprocessCfg(),
+        mix_cfg: MixupCfg = None,
+        num_workers: int = 1,
+        collate_fn: Optional[Callable] = None,
+        pin_memory: bool = False,
+        use_multi_epochs_loader: bool = False,
+        persistent_workers: bool = True,
 ):
-    re_num_splits = 0
-    if re_split:
-        # apply RE to second half of batch if no aug split otherwise line up with aug split
-        re_num_splits = num_aug_splits or 2
-    dataset.transform = create_transform(
-        input_size,
-        is_training=is_training,
-        use_fetcher=True,
-        no_aug=no_aug,
-        scale=scale,
-        ratio=ratio,
-        hflip=hflip,
-        vflip=vflip,
-        color_jitter=color_jitter,
-        auto_augment=auto_augment,
-        interpolation=interpolation,
-        mean=mean,
-        std=std,
-        crop_pct=crop_pct,
-        tf_preprocessing=tf_preprocessing,
-        re_prob=re_prob,
-        re_mode=re_mode,
-        re_count=re_count,
-        re_num_splits=re_num_splits,
-        separate=num_aug_splits > 0,
-    )
+    """
+
+    Args:
+        dataset:
+        batch_size:
+        is_training:
+        dev_env:
+        normalize:
+        pp_cfg:
+        mix_cfg:
+        num_workers:
+        collate_fn:
+        pin_memory:
+        use_multi_epochs_loader:
+        persistent_workers:
+
+    Returns:
+
+    """
 
     if dev_env is None:
         dev_env = DeviceEnv.instance()
@@ -85,10 +63,24 @@ def create_loader(
     else:
         # This will add extra duplicate entries to result in equal num
         # of samples per-process, will slightly alter validation results
-        sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
+        sampler = OrderedDistributedSampler(
+            dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
 
     if collate_fn is None:
-        collate_fn = fast_collate
+        if mix_cfg is not None and mix_cfg.prob > 0:
+            collate_fn = FastCollateMixup(
+                mixup_alpha=mix_cfg.mixup_alpha,
+                cutmix_alpha=mix_cfg.cutmix_alpha,
+                cutmix_minmax=mix_cfg.cutmix_minmax,
+                prob=mix_cfg.prob,
+                switch_prob=mix_cfg.switch_prob,
+                mode=mix_cfg.mode,
+                correct_lam=mix_cfg.correct_lam,
+                label_smoothing=mix_cfg.label_smoothing,
+                num_classes=mix_cfg.num_classes,
+            )
+        else:
+            collate_fn = fast_collate
 
     loader_class = torch.utils.data.DataLoader
     if use_multi_epochs_loader:
@@ -110,13 +102,18 @@ def create_loader(
     loader = loader_class(dataset, **loader_args)
 
     fetcher_kwargs = dict(
-        mean=mean,
-        std=std,
-        re_prob=re_prob if is_training and not no_aug else 0.,
-        re_mode=re_mode,
-        re_count=re_count,
-        re_num_splits=re_num_splits
+        normalize=normalize,
+        mean=pp_cfg.mean,
+        std=pp_cfg.std,
     )
+    if normalize and is_training and pp_cfg.aug is not None:
+        fetcher_kwargs.update(dict(
+            re_prob=pp_cfg.aug.re_prob,
+            re_mode=pp_cfg.aug.re_mode,
+            re_count=pp_cfg.aug.re_count,
+            num_aug_splits=pp_cfg.aug.num_aug_splits,
+        ))
     if dev_env.type_cuda:
         loader = PrefetcherCuda(loader, **fetcher_kwargs)
     else:

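A minimal sketch of driving the new create_loader_v2; note that unlike the old create_loader it no longer builds the dataset transform itself, so the transform is assigned separately (create_transform_v2 appears later in this commit). The dataset arguments here are hypothetical:

from timm.data import (
    create_dataset, create_loader_v2, create_transform_v2,
    PreprocessCfg, AugCfg, MixupCfg)

pp_cfg = PreprocessCfg(input_size=(3, 224, 224), aug=AugCfg(re_prob=0.25))

dataset = create_dataset('', root='/data/imagenet', split='train', is_training=True, batch_size=128)
dataset.transform = create_transform_v2(pp_cfg, is_training=True)

loader = create_loader_v2(
    dataset,
    batch_size=128,
    is_training=True,
    pp_cfg=pp_cfg,
    mix_cfg=MixupCfg(mixup_alpha=0.2, num_classes=1000),  # prob > 0 selects FastCollateMixup
    num_workers=4,
)
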
@@ -102,7 +102,7 @@ class Mixup:
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0., num_classes=0):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
@@ -113,6 +113,8 @@ class Mixup:
         self.mix_prob = prob
         self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
+        if label_smoothing > 0.:
+            assert num_classes > 0
         self.num_classes = num_classes
         self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
@@ -218,17 +220,30 @@ class Mixup:
         return x, target
 
 
+def blend(a, b, lam, is_tensor=False, round_output=True):
+    if is_tensor:
+        blend = a.to(dtype=torch.float32) * lam + b.to(dtype=torch.float32) * (1 - lam)
+        if round_output:
+            torch.round(blend, out=blend)
+    else:
+        blend = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
+        if round_output:
+            np.rint(blend, out=blend)
+    return blend
+
+
 class FastCollateMixup(Mixup):
     """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
 
     A Mixup impl that's performed while collating the batches.
     """
 
-    def _mix_elem_collate(self, output, batch, half=False):
+    def _mix_elem_collate(self, output, batch, half=False, is_tensor=False):
         batch_size = len(batch)
         num_elem = batch_size // 2 if half else batch_size
         assert len(output) == num_elem
         lam_batch, use_cutmix = self._params_per_elem(num_elem)
+        round_output = output.dtype == torch.uint8
         for i in range(num_elem):
             j = batch_size - i - 1
             lam = lam_batch[i]
@@ -236,22 +251,23 @@ class FastCollateMixup(Mixup):
             if lam != 1.:
                 if use_cutmix[i]:
                     if not half:
-                        mixed = mixed.copy()
+                        mixed = mixed.clone() if is_tensor else mixed.copy()  # don't want to modify while iterating
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                     lam_batch[i] = lam
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
+            mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
+            output[i].copy_(mixed)
         if half:
             lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
         return torch.tensor(lam_batch).unsqueeze(1)
 
-    def _mix_pair_collate(self, output, batch):
+    def _mix_pair_collate(self, output, batch, is_tensor=False):
         batch_size = len(batch)
         lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        round_output = output.dtype == torch.uint8
         for i in range(batch_size // 2):
             j = batch_size - i - 1
             lam = lam_batch[i]
@@ -262,24 +278,30 @@ class FastCollateMixup(Mixup):
                 if use_cutmix[i]:
                     (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                         output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
-                    patch_i = mixed_i[:, yl:yh, xl:xh].copy()
+                    patch_i = mixed_i[:, yl:yh, xl:xh]
+                    patch_i = patch_i.clone() if is_tensor else patch_i.copy()  # don't want to modify while iterating
                     mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
                     mixed_j[:, yl:yh, xl:xh] = patch_i
                     lam_batch[i] = lam
                 else:
-                    mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
-                    mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
+                    mixed_temp = blend(mixed_i, mixed_j, lam, is_tensor, round_output)
+                    mixed_j = blend(mixed_j, mixed_i, lam, is_tensor, round_output)
                     mixed_i = mixed_temp
-                    np.rint(mixed_j, out=mixed_j)
-                    np.rint(mixed_i, out=mixed_i)
-            output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
-            output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
+            if is_tensor:
+                mixed_i = mixed_i.to(dtype=output.dtype)
+                mixed_j = mixed_j.to(dtype=output.dtype)
+            else:
+                mixed_i = torch.from_numpy(mixed_i.astype(np.uint8))
+                mixed_j = torch.from_numpy(mixed_j.astype(np.uint8))
+            output[i].copy_(mixed_i)
+            output[j].copy_(mixed_j)
         lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
         return torch.tensor(lam_batch).unsqueeze(1)
 
-    def _mix_batch_collate(self, output, batch):
+    def _mix_batch_collate(self, output, batch, is_tensor=False):
         batch_size = len(batch)
         lam, use_cutmix = self._params_per_batch()
+        round_output = output.dtype == torch.uint8
         if use_cutmix:
             (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
                 output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
@@ -288,12 +310,12 @@ class FastCollateMixup(Mixup):
             mixed = batch[i][0]
             if lam != 1.:
                 if use_cutmix:
-                    mixed = mixed.copy()  # don't want to modify the original while iterating
+                    mixed = mixed.clone() if is_tensor else mixed.copy()  # don't want to modify while iterating
                     mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
                 else:
-                    mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                    np.rint(mixed, out=mixed)
-            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+                    mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
+            mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
+            output[i].copy_(mixed)
         return lam
 
     def __call__(self, batch, _=None):
@@ -302,13 +324,15 @@ class FastCollateMixup(Mixup):
         half = 'half' in self.mode
         if half:
             batch_size //= 2
-        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        is_tensor = isinstance(batch[0][0], torch.Tensor)
+        output_dtype = batch[0][0].dtype if is_tensor else torch.uint8  # always uint8 if numpy src
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=output_dtype)
         if self.mode == 'elem' or self.mode == 'half':
-            lam = self._mix_elem_collate(output, batch, half=half)
+            lam = self._mix_elem_collate(output, batch, half=half, is_tensor=is_tensor)
         elif self.mode == 'pair':
-            lam = self._mix_pair_collate(output, batch)
+            lam = self._mix_pair_collate(output, batch, is_tensor=is_tensor)
         else:
-            lam = self._mix_batch_collate(output, batch)
+            lam = self._mix_batch_collate(output, batch, is_tensor=is_tensor)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
         target = target[:batch_size]

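The new blend() helper is what lets FastCollateMixup operate on either numpy arrays or torch tensors: both branches compute lam * a + (1 - lam) * b in float32, rounding only when the destination batch is uint8. A quick check of the tensor path:

import torch
from timm.data.mixup import blend

a = torch.full((3, 4, 4), 200, dtype=torch.uint8)
b = torch.full((3, 4, 4), 100, dtype=torch.uint8)
out = blend(a, b, lam=0.7, is_tensor=True, round_output=True)
assert out.dtype == torch.float32 and torch.all(out == 170.)  # round(0.7 * 200 + 0.3 * 100)
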
@@ -7,25 +7,34 @@ from .random_erasing import RandomErasing
 
 class PrefetcherCuda:
 
-    def __init__(self,
-                 loader,
-                 mean=IMAGENET_DEFAULT_MEAN,
-                 std=IMAGENET_DEFAULT_STD,
-                 fp16=False,
-                 re_prob=0.,
-                 re_mode='const',
-                 re_count=1,
-                 re_num_splits=0):
+    def __init__(
+            self,
+            loader,
+            device: torch.device = torch.device('cuda'),
+            dtype=torch.float32,
+            normalize=True,
+            normalize_shape=(1, 3, 1, 1),
+            mean=IMAGENET_DEFAULT_MEAN,
+            std=IMAGENET_DEFAULT_STD,
+            num_aug_splits=0,
+            re_prob=0.,
+            re_mode='const',
+            re_count=1
+    ):
         self.loader = loader
-        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
-        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
-        self.fp16 = fp16
-        if fp16:
-            self.mean = self.mean.half()
-            self.std = self.std.half()
+        self.device = device
+        self.dtype = dtype
+        if normalize:
+            self.mean = torch.tensor(
+                [x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
+            self.std = torch.tensor(
+                [x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
+        else:
+            self.mean = None
+            self.std = None
         if re_prob > 0.:
             self.random_erasing = RandomErasing(
-                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
+                probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
         else:
             self.random_erasing = None
 
@@ -35,12 +44,11 @@ class PrefetcherCuda:
         for next_input, next_target in self.loader:
             with torch.cuda.stream(stream):
-                next_input = next_input.cuda(non_blocking=True)
-                next_target = next_target.cuda(non_blocking=True)
-                if self.fp16:
-                    next_input = next_input.half().sub_(self.mean).div_(self.std)
-                else:
-                    next_input = next_input.float().sub_(self.mean).div_(self.std)
+                next_input = next_input.to(device=self.device, non_blocking=True)
+                next_input = next_input.to(dtype=self.dtype)
+                if self.mean is not None:
+                    next_input.sub_(self.mean).div_(self.std)
+                next_target = next_target.to(device=self.device, non_blocking=True)
                 if self.random_erasing is not None:
                     next_input = self.random_erasing(next_input)
 
@@ -76,4 +84,4 @@ class PrefetcherCuda:
     @mixup_enabled.setter
     def mixup_enabled(self, x):
         if isinstance(self.loader.collate_fn, FastCollateMixup):
-            self.loader.collate_fn.mixup_enabled = x
\ No newline at end of file
+            self.loader.collate_fn.mixup_enabled = x

@@ -38,21 +38,20 @@ class RandomErasing:
             'const' - erase block is constant color of 0 for all channels
             'rand' - erase block is same per-channel random (normal) color
             'pixel' - erase block is per-pixel random (normal) color
-        max_count: maximum number of erasing blocks per image, area per box is scaled by count.
+        count: maximum number of erasing blocks per image, area per box is scaled by count.
             per-image count is randomly chosen between 1 and this value.
     """
 
     def __init__(
             self,
             probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
-            mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
+            mode='const', count=1, num_splits=0):
         self.probability = probability
         self.min_area = min_area
         self.max_area = max_area
         max_aspect = max_aspect or 1 / min_aspect
         self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
-        self.min_count = min_count
-        self.max_count = max_count or min_count
+        self.count = count
         self.num_splits = num_splits
         mode = mode.lower()
         self.rand_color = False
@@ -63,14 +62,13 @@ class RandomErasing:
             self.per_pixel = True  # per pixel random normal
         else:
             assert not mode or mode == 'const'
-        self.device = device
 
     def _erase(self, img, chan, img_h, img_w, dtype):
+        device = img.device
         if random.random() > self.probability:
             return
         area = img_h * img_w
-        count = self.min_count if self.min_count == self.max_count else \
-            random.randint(self.min_count, self.max_count)
+        count = random.randint(1, self.count) if self.count > 1 else self.count
         for _ in range(count):
             for attempt in range(10):
                 target_area = random.uniform(self.min_area, self.max_area) * area / count
@@ -81,17 +79,76 @@ class RandomErasing:
                     top = random.randint(0, img_h - h)
                     left = random.randint(0, img_w - w)
                     img[:, top:top + h, left:left + w] = _get_pixels(
-                        self.per_pixel, self.rand_color, (chan, h, w),
-                        dtype=dtype, device=self.device)
+                        self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=device)
                     break
 
-    def __call__(self, input):
-        if len(input.size()) == 3:
-            self._erase(input, *input.size(), input.dtype)
+    def __call__(self, x):
+        if len(x.size()) == 3:
+            self._erase(x, *x.shape, x.dtype)
         else:
-            batch_size, chan, img_h, img_w = input.size()
+            batch_size, chan, img_h, img_w = x.shape
             # skip first slice of batch if num_splits is set (for clean portion of samples)
             batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
             for i in range(batch_start, batch_size):
-                self._erase(input[i], chan, img_h, img_w, input.dtype)
-        return input
+                self._erase(x[i], chan, img_h, img_w, x.dtype)
+        return x
+
+
+class RandomErasingMasked:
+    """ Randomly selects a rectangle region in an image and erases its pixels.
+        'Random Erasing Data Augmentation' by Zhong et al.
+        See https://arxiv.org/pdf/1708.04896.pdf
+
+        This variant of RandomErasing is intended to be applied to either a batch
+        or single image tensor after it has been normalized by dataset mean and std.
+    Args:
+        probability: Probability that the Random Erasing operation will be performed for each box (count)
+        min_area: Minimum percentage of erased area wrt input image area.
+        max_area: Maximum percentage of erased area wrt input image area.
+        min_aspect: Minimum aspect ratio of erased area.
+        count: maximum number of erasing blocks per image, area per box is scaled by count.
+            per-image count is between 0 and this value.
+    """
+
+    def __init__(
+            self,
+            probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
+            mode='const', count=1, num_splits=0):
+        self.probability = probability
+        self.min_area = min_area
+        self.max_area = max_area
+        max_aspect = max_aspect or 1 / min_aspect
+        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
+        self.mode = mode  # FIXME currently ignored, add back options besides normal mean=0, std=1 noise?
+        self.count = count
+        self.num_splits = num_splits
+
+    @torch.no_grad()
+    def __call__(self, x: torch.Tensor) -> torch.Tensor:
+        device = x.device
+        batch_size, _, img_h, img_w = x.shape
+        batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
+
+        # NOTE simplified from v1 with one count value and same prob applied for all
+        enable = (torch.empty((batch_size, self.count), device=device).uniform_() < self.probability).float()
+        enable = enable / enable.sum(dim=1, keepdim=True).clamp(min=1)
+        target_area = torch.empty(
+            (batch_size, self.count), device=device).uniform_(self.min_area, self.max_area) * enable
+        aspect_ratio = torch.empty((batch_size, self.count), device=device).uniform_(*self.log_aspect_ratio).exp()
+        h_coord = torch.arange(0, img_h, device=device).unsqueeze(-1).expand(-1, self.count).float()
+        w_coord = torch.arange(0, img_w, device=device).unsqueeze(-1).expand(-1, self.count).float()
+        h_mid = torch.rand((batch_size, self.count), device=device) * img_h
+        w_mid = torch.rand((batch_size, self.count), device=device) * img_w
+        noise = torch.empty_like(x[0]).normal_()
+
+        for i in range(batch_start, batch_size):
+            h_half = (img_h / 2) * torch.sqrt(target_area[i] * aspect_ratio[i])  # 1/2 box h
+            h_mask = (h_coord > (h_mid[i] - h_half)) & (h_coord < (h_mid[i] + h_half))
+            w_half = (img_w / 2) * torch.sqrt(target_area[i] / aspect_ratio[i])  # 1/2 box w
+            w_mask = (w_coord > (w_mid[i] - w_half)) & (w_coord < (w_mid[i] + w_half))
+            #mask = (h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
+            #x[i].copy_(torch.where(mask, noise, x[i]))
+            mask = ~(h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
+            x[i] = x[i].where(mask, noise)
+            #x[i].masked_scatter_(mask, noise)
+        return x

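Where the original RandomErasing draws boxes with per-image Python RNG and in-place slice writes, RandomErasingMasked builds boolean box masks from coordinate grids and applies them with batched tensor ops, the shape of computation XLA can compile; it also assumes the input is already normalized, so N(0, 1) noise is an appropriate fill. A small usage sketch:

import torch
from timm.data.random_erasing import RandomErasingMasked

re = RandomErasingMasked(probability=0.5, count=3)
x = torch.randn(8, 3, 224, 224)  # already-normalized batch, so mean-0/std-1 noise blends in
x = re(x)  # up to `count` boxes per image replaced with gaussian noise
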
@@ -1,5 +1,7 @@
 import torch
 import torchvision.transforms.functional as F
+from torchvision.transforms import InterpolationMode
+
 from PIL import Image
 import warnings
 import math
@@ -30,29 +32,40 @@ class ToTensor:
         return torch.from_numpy(np_img).to(dtype=self.dtype)
 
 
-_pil_interpolation_to_str = {
-    Image.NEAREST: 'PIL.Image.NEAREST',
-    Image.BILINEAR: 'PIL.Image.BILINEAR',
-    Image.BICUBIC: 'PIL.Image.BICUBIC',
-    Image.LANCZOS: 'PIL.Image.LANCZOS',
-    Image.HAMMING: 'PIL.Image.HAMMING',
-    Image.BOX: 'PIL.Image.BOX',
-}
-
-
-def _pil_interp(method):
-    if method == 'bicubic':
-        return Image.BICUBIC
-    elif method == 'lanczos':
-        return Image.LANCZOS
-    elif method == 'hamming':
-        return Image.HAMMING
-    else:
-        # default bilinear, do we want to allow nearest?
-        return Image.BILINEAR
+class ToTensorNormalize:
+
+    def __init__(self, mean, std, dtype=torch.float32, device=torch.device('cpu')):
+        self.dtype = dtype
+        mean = torch.as_tensor(mean, dtype=dtype, device=device)
+        std = torch.as_tensor(std, dtype=dtype, device=device)
+        if (std == 0).any():
+            raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
+        if mean.ndim == 1:
+            mean = mean.view(-1, 1, 1)
+        if std.ndim == 1:
+            std = std.view(-1, 1, 1)
+        self.mean = mean
+        self.std = std
+
+    def __call__(self, pil_img):
+        mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
+        img = torch.from_numpy(
+            np.array(pil_img, mode_to_nptype.get(pil_img.mode, np.uint8))
+        )
+        if pil_img.mode == '1':
+            img = 255 * img
+        img = img.view(pil_img.size[1], pil_img.size[0], len(pil_img.getbands()))
+        img = img.permute((2, 0, 1))
+        if isinstance(img, torch.ByteTensor):
+            img = img.to(self.dtype)
+            img.sub_(self.mean * 255.).div_(self.std * 255.)
+        else:
+            img = img.to(self.dtype)
+            img.sub_(self.mean).div_(self.std)
+        return img
 
 
-_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
+_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
 
 
 class RandomResizedCropAndInterpolation:
@@ -82,7 +95,7 @@ class RandomResizedCropAndInterpolation:
         if interpolation == 'random':
             self.interpolation = _RANDOM_INTERPOLATION
         else:
-            self.interpolation = _pil_interp(interpolation)
+            self.interpolation = InterpolationMode(interpolation)
         self.scale = scale
         self.ratio = ratio
 
@@ -146,9 +159,9 @@ class RandomResizedCropAndInterpolation:
 
     def __repr__(self):
         if isinstance(self.interpolation, (tuple, list)):
-            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
+            interpolate_str = ' '.join([x.value for x in self.interpolation])
         else:
-            interpolate_str = _pil_interpolation_to_str[self.interpolation]
+            interpolate_str = self.interpolation.value
         format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
         format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))

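ToTensorNormalize fuses the PIL-to-tensor conversion and normalization into one transform, skipping the intermediate [0, 1] float tensor that ToTensor() + Normalize() would produce (for uint8 inputs the divide-by-255 is folded into mean and std, mirroring the fetcher trick above). A usage sketch:

from PIL import Image
from timm.data.transforms import ToTensorNormalize

tt = ToTensorNormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
img = Image.new('RGB', (224, 224), color=(128, 128, 128))  # stand-in for a real image
x = tt(img)  # float32 CHW tensor, already normalized
assert x.shape == (3, 224, 224)
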
@@ -4,59 +4,50 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image Models)
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import math
+from typing import Union, Tuple
 
 import torch
 from torchvision import transforms
 
-from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
-from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
+from timm.data.config import PreprocessCfg, AugCfg
+from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
 from timm.data.random_erasing import RandomErasing
+from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensorNormalize
 
 
 def transforms_noaug_train(
-        img_size=224,
+        img_size: Union[int, Tuple[int]] = 224,
         interpolation='bilinear',
-        use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
+        normalize=False,
 ):
     if interpolation == 'random':
         # random interpolation not supported with no-aug
         interpolation = 'bilinear'
     tfl = [
-        transforms.Resize(img_size, _pil_interp(interpolation)),
+        transforms.Resize(img_size, transforms.InterpolationMode(interpolation)),
         transforms.CenterCrop(img_size)
     ]
-    if use_prefetcher:
-        # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
-    else:
+    if normalize:
         tfl += [
             transforms.ToTensor(),
-            transforms.Normalize(
-                mean=torch.tensor(mean),
-                std=torch.tensor(std))
+            transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
         ]
+    else:
+        # (pre)fetcher and collate will handle tensor conversion and normalize
+        tfl += [ToNumpy()]
     return transforms.Compose(tfl)
 
 
 def transforms_imagenet_train(
-        img_size=224,
-        scale=None,
-        ratio=None,
-        hflip=0.5,
-        vflip=0.,
-        color_jitter=0.4,
-        auto_augment=None,
+        img_size: Union[int, Tuple[int]] = 224,
         interpolation='random',
-        use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
-        re_prob=0.,
-        re_mode='const',
-        re_count=1,
-        re_num_splits=0,
+        aug_cfg=AugCfg(),
+        normalize=False,
         separate=False,
 ):
     """
@@ -66,18 +57,24 @@ def transforms_imagenet_train(
      * a portion of the data through the secondary transform
      * normalizes and converts the branches above with the third, final transform
     """
-    scale = tuple(scale or (0.08, 1.0))  # default imagenet scale range
-    ratio = tuple(ratio or (3./4., 4./3.))  # default imagenet ratio range
+    scale_range = tuple(aug_cfg.scale_range or (0.08, 1.0))  # default imagenet scale range
+    ratio_range = tuple(aug_cfg.ratio_range or (3. / 4., 4. / 3.))  # default imagenet ratio range
+
+    # 'primary' train transforms include random resize + crop w/ optional horizontal and vertical flipping aug.
+    # This is the core of standard ImageNet ResNet and Inception pre-processing
     primary_tfl = [
-        RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
-    if hflip > 0.:
-        primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
-    if vflip > 0.:
-        primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
+        RandomResizedCropAndInterpolation(img_size, scale=scale_range, ratio=ratio_range, interpolation=interpolation)]
+    if aug_cfg.hflip_prob > 0.:
+        primary_tfl += [transforms.RandomHorizontalFlip(p=aug_cfg.hflip_prob)]
+    if aug_cfg.vflip_prob > 0.:
+        primary_tfl += [transforms.RandomVerticalFlip(p=aug_cfg.vflip_prob)]
 
+    # 'secondary' transform stage includes either color jitter (could add lighting too) or auto-augmentations
+    # such as AutoAugment, RandAugment, AugMix, etc
     secondary_tfl = []
-    if auto_augment:
-        assert isinstance(auto_augment, str)
+    if aug_cfg.auto_augment:
+        aa = aug_cfg.auto_augment
+        assert isinstance(aa, str)
         if isinstance(img_size, (tuple, list)):
             img_size_min = min(img_size)
         else:
@@ -87,58 +84,63 @@ def transforms_imagenet_train(
             img_mean=tuple([min(255, round(255 * x)) for x in mean]),
         )
         if interpolation and interpolation != 'random':
-            aa_params['interpolation'] = _pil_interp(interpolation)
-        if auto_augment.startswith('rand'):
-            secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
-        elif auto_augment.startswith('augmix'):
+            aa_params['interpolation'] = interpolation
+        if aa.startswith('rand'):
+            secondary_tfl += [rand_augment_transform(aa, aa_params)]
+        elif aa.startswith('augmix'):
             aa_params['translate_pct'] = 0.3
-            secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
+            secondary_tfl += [augment_and_mix_transform(aa, aa_params)]
         else:
-            secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
-    elif color_jitter is not None:
+            secondary_tfl += [auto_augment_transform(aa, aa_params)]
+    elif aug_cfg.color_jitter is not None:
         # color jitter is enabled when not using AA
-        if isinstance(color_jitter, (list, tuple)):
+        cj = aug_cfg.color_jitter
+        if isinstance(cj, (list, tuple)):
             # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
             # or 4 if also augmenting hue
-            assert len(color_jitter) in (3, 4)
+            assert len(cj) in (3, 4)
         else:
             # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
-            color_jitter = (float(color_jitter),) * 3
-        secondary_tfl += [transforms.ColorJitter(*color_jitter)]
+            cj = (float(cj),) * 3
+        secondary_tfl += [transforms.ColorJitter(*cj)]
 
+    # 'final' transform stage includes normalization, followed by optional random erasing and tensor conversion
     final_tfl = []
-    if use_prefetcher:
-        # prefetcher and collate will handle tensor conversion and norm
-        final_tfl += [ToNumpy()]
-    else:
+    if normalize:
         final_tfl += [
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=torch.tensor(mean),
-                std=torch.tensor(std))
+            ToTensorNormalize(mean=mean, std=std)
         ]
-        if re_prob > 0.:
-            final_tfl.append(
-                RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
+        if aug_cfg.re_prob > 0.:
+            final_tfl.append(RandomErasing(
+                aug_cfg.re_prob,
+                mode=aug_cfg.re_mode,
+                count=aug_cfg.re_count,
+                num_splits=aug_cfg.num_aug_splits))
+    else:
+        # when normalize disabled, (pre)fetcher and collate will handle tensor conversion and normalize
+        final_tfl += [ToNumpy()]
 
     if separate:
+        # return each transform stage separately
        return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
     else:
         return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
 
 
 def transforms_imagenet_eval(
-        img_size=224,
+        img_size: Union[int, Tuple[int]] = 224,
         crop_pct=None,
         interpolation='bilinear',
-        use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
-        std=IMAGENET_DEFAULT_STD):
+        std=IMAGENET_DEFAULT_STD,
+        normalize=False,
+):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
 
     if isinstance(img_size, (tuple, list)):
         assert len(img_size) == 2
         if img_size[-1] == img_size[-2]:
+            # FIXME handle case where img is square and we want non aspect preserving resize
             # fall-back to older behaviour so Resize scales to shortest edge if target is square
             scale_size = int(math.floor(img_size[0] / crop_pct))
         else:
@@ -147,27 +149,87 @@ def transforms_imagenet_eval(
         scale_size = int(math.floor(img_size / crop_pct))
 
     tfl = [
-        transforms.Resize(scale_size, _pil_interp(interpolation)),
+        transforms.Resize(scale_size, transforms.InterpolationMode(interpolation)),
         transforms.CenterCrop(img_size),
     ]
-    if use_prefetcher:
-        # prefetcher and collate will handle tensor conversion and norm
-        tfl += [ToNumpy()]
-    else:
+    if normalize:
         tfl += [
-            transforms.ToTensor(),
-            transforms.Normalize(
-                mean=torch.tensor(mean),
-                std=torch.tensor(std))
+            ToTensorNormalize(mean=mean, std=std)
         ]
+    else:
+        # (pre)fetcher and collate will handle tensor conversion and normalize
+        tfl += [ToNumpy()]
 
     return transforms.Compose(tfl)
 
 
+def create_transform_v2(
+        cfg=PreprocessCfg(),
+        is_training=False,
+        normalize=False,
+        separate=False,
+        tf_preprocessing=False,
+):
+    """
+
+    Args:
+        cfg: Pre-processing configuration
+        is_training (bool): Create transform for training pre-processing
+        tf_preprocessing (bool): Use Tensorflow pre-processing (for validation)
+        normalize (bool): Enable normalization in transforms (otherwise handled by fetcher/pre-fetcher)
+        separate (bool): Return transforms separated into stages (for train)
+
+    Returns:
+
+    """
+    input_size = cfg.input_size
+    if isinstance(input_size, (tuple, list)):
+        img_size = input_size[-2:]
+    else:
+        img_size = input_size
+
+    if tf_preprocessing:
+        assert not normalize, "Expecting normalization to be handled in (pre)fetcher w/ TF preprocessing"
+        assert not separate, "Separate transforms not supported for TF preprocessing"
+        from timm.data.tf_preprocessing import TfPreprocessTransform
+        transform = TfPreprocessTransform(
+            is_training=is_training, size=img_size, interpolation=cfg.interpolation)
+    else:
+        if is_training and cfg.aug is None:
+            assert not separate, "Cannot perform split augmentation with no_aug"
+            transform = transforms_noaug_train(
+                img_size,
+                interpolation=cfg.interpolation,
+                normalize=normalize,
+                mean=cfg.mean,
+                std=cfg.std)
+        elif is_training:
+            transform = transforms_imagenet_train(
+                img_size,
+                interpolation=cfg.interpolation,
+                mean=cfg.mean,
+                std=cfg.std,
+                aug_cfg=cfg.aug,
+                normalize=normalize,
+                separate=separate)
+        else:
+            assert not separate, "Separate transforms not supported for validation preprocessing"
+            transform = transforms_imagenet_eval(
+                img_size,
+                interpolation=cfg.interpolation,
+                crop_pct=cfg.crop_pct,
+                mean=cfg.mean,
+                std=cfg.std,
+                normalize=normalize,
+            )
+
+    return transform
+
+
 def create_transform(
         input_size,
         is_training=False,
-        use_fetcher=False,
+        use_prefetcher=False,
         no_aug=False,
         scale=None,
         ratio=None,
@@ -191,7 +253,8 @@ def create_transform(
     else:
         img_size = input_size
 
-    if tf_preprocessing and use_fetcher:
+    normalize_in_transform = not use_prefetcher
+
+    if tf_preprocessing and use_prefetcher:
         assert not separate, "Separate transforms not supported for TF preprocessing"
         from timm.data.tf_preprocessing import TfPreprocessTransform
         transform = TfPreprocessTransform(
@@ -202,35 +265,41 @@ def create_transform(
             transform = transforms_noaug_train(
                 img_size,
                 interpolation=interpolation,
-                use_prefetcher=use_fetcher,
                 mean=mean,
-                std=std)
+                std=std,
+                normalize=normalize_in_transform,
+            )
         elif is_training:
-            transform = transforms_imagenet_train(
-                img_size,
-                scale=scale,
-                ratio=ratio,
-                hflip=hflip,
-                vflip=vflip,
+            aug_cfg = AugCfg(
+                scale_range=scale,
+                ratio_range=ratio,
+                hflip_prob=hflip,
+                vflip_prob=vflip,
                 color_jitter=color_jitter,
                 auto_augment=auto_augment,
-                interpolation=interpolation,
-                use_prefetcher=use_fetcher,
-                mean=mean,
-                std=std,
                 re_prob=re_prob,
                 re_mode=re_mode,
                 re_count=re_count,
-                re_num_splits=re_num_splits,
-                separate=separate)
+                num_aug_splits=re_num_splits,
+            )
+            transform = transforms_imagenet_train(
+                img_size,
+                interpolation=interpolation,
+                mean=mean,
+                std=std,
+                aug_cfg=aug_cfg,
+                normalize=normalize_in_transform,
+                separate=separate
+            )
         else:
-            assert not separate, "Separate transforms not supported for validation preprocessing"
+            assert not separate, "Separate transforms not supported for validation pre-processing"
             transform = transforms_imagenet_eval(
                 img_size,
                 interpolation=interpolation,
-                use_prefetcher=use_fetcher,
                 mean=mean,
                 std=std,
-                crop_pct=crop_pct)
+                crop_pct=crop_pct,
+                normalize=normalize_in_transform,
+            )
 
     return transform

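A sketch of building transforms from a PreprocessCfg with the new create_transform_v2. With normalize=False (the default) the pipeline ends in ToNumpy and leaves normalization to the (pre)fetcher; normalize=True produces normalized tensors directly, which is what the XLA RandomErasing path at the end of this commit requires:

from timm.data import create_transform_v2, PreprocessCfg, AugCfg

cfg = PreprocessCfg(input_size=(3, 224, 224), interpolation='bicubic', aug=AugCfg(re_prob=0.25))

train_tf = create_transform_v2(cfg, is_training=True, normalize=True)  # tensor out, RE on CPU
eval_tf = create_transform_v2(cfg, is_training=False)                  # numpy out, for the fetcher
# typically assigned onto a dataset, e.g. dataset.transform = train_tf
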
@@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__)
 def load_state_dict(checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        state_dict_key = 'state_dict'
+        state_dict_key = ''
         if isinstance(checkpoint, dict):
             if use_ema and 'state_dict_ema' in checkpoint:
                 state_dict_key = 'state_dict_ema'
-        if state_dict_key and state_dict_key in checkpoint:
+            elif use_ema and 'model_ema' in checkpoint:
+                state_dict_key = 'model_ema'
+            elif 'state_dict' in checkpoint:
+                state_dict_key = 'state_dict'
+            elif 'model' in checkpoint:
+                state_dict_key = 'model'
+        if state_dict_key:
+            state_dict = checkpoint[state_dict_key]
             new_state_dict = OrderedDict()
-            for k, v in checkpoint[state_dict_key].items():
+            for k, v in state_dict.items():
                 # strip `module.` prefix
                 name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v

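With the broadened key search, checkpoints saved under 'state_dict', 'model', 'state_dict_ema', or 'model_ema' all resolve without manual unwrapping, and the 'module.' prefix left by DistributedDataParallel is stripped. Usage, with a hypothetical checkpoint path:

from timm.models import create_model
from timm.models.helpers import load_state_dict

model = create_model('resnet50', num_classes=1000)
# with use_ema=True the EMA weights are preferred when present
state_dict = load_state_dict('./output/train/checkpoint-42.pth.tar', use_ema=True)
model.load_state_dict(state_dict)
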
@ -30,7 +30,8 @@ import torchvision.utils
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\ from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\
TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
from timm.models import create_model, safe_model_name, convert_splitbn_model from timm.models import create_model, safe_model_name, convert_splitbn_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import optimizer_kwargs from timm.optim import optimizer_kwargs
@ -283,10 +284,11 @@ def main():
else: else:
_logger.info('Training with a single process on 1 device.') _logger.info('Training with a single process on 1 device.')
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
random_seed(args.seed, 0) # Set all random seeds the same for model/state init (mandatory for XLA) random_seed(args.seed, 0) # Set all random seeds the same for model/state init (mandatory for XLA)
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'
train_state = setup_train_task(args, dev_env, mixup_active) train_state = setup_train_task(args, dev_env, mixup_active)
train_cfg = train_state.train_cfg train_cfg = train_state.train_cfg
@@ -421,11 +423,9 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
     _logger.info(
         f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

-    # setup augmentation batch splits for contrastive loss or split bn
-    assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'
     # enable split bn (separate bn stats per batch-portion)
     if args.split_bn:
-        assert args.aug_splits > 1 or args.resplit
+        assert args.aug_splits > 1
         model = convert_splitbn_model(model, max(args.aug_splits, 2))

     train_state = setup_model_and_optimizer(
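Split BatchNorm keeps separate BN statistics for each augmentation split of a batch. A hedged sketch of the conversion used above (the toy model and split count are illustrative; convert_splitbn_model is timm's real API):

import torch.nn as nn
from timm.models import convert_splitbn_model

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
aug_splits = 2  # must be 0 or >1, per the assertion above
if aug_splits > 1:
    # every BatchNorm2d is rewrapped so each split of the batch updates its own stats
    model = convert_splitbn_model(model, max(aug_splits, 2))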
@@ -481,7 +481,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
     return train_state

-def setup_data(args, default_cfg, dev_env, mixup_active):
+def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
     data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)

     # create the train and eval datasets
@@ -489,18 +489,18 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
         args.dataset,
         root=args.data_dir, split=args.train_split, is_training=True,
         batch_size=args.batch_size, repeats=args.epoch_repeats)
     dataset_eval = create_dataset(
-        args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
+        args.dataset,
+        root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)

     # setup mixup / cutmix
-    collate_fn = None
+    mixup_cfg = None
     if mixup_active:
-        mixup_args = dict(
-            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+        mixup_cfg = MixupCfg(
             prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
             label_smoothing=args.smoothing, num_classes=args.num_classes)
-        assert not args.aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
-        collate_fn = FastCollateMixup(**mixup_args)

     # wrap dataset in AugMix helper
     if args.aug_splits > 1:
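The Mixup/CutMix parameters now travel as a config object instead of a collate function. A hedged sketch of building the MixupCfg above with typical hyperparameters (field names come from the diff; the values are illustrative only, and MixupCfg exists only on this branch):

from timm.data import MixupCfg

mixup_cfg = MixupCfg(
    prob=1.0,               # probability of applying mixup/cutmix to a batch
    switch_prob=0.5,        # chance of switching from mixup to cutmix
    mode='batch',           # apply per batch (vs per pair or per element)
    mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,
    label_smoothing=0.1, num_classes=1000)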
@@ -510,46 +510,72 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
         train_interpolation = data_config['interpolation']
-    loader_train = create_loader(
-        dataset_train,
+    if args.no_aug:
+        train_aug_cfg = None
+    else:
+        train_aug_cfg = AugCfg(
+            re_prob=args.reprob,
+            re_mode=args.remode,
+            re_count=args.recount,
+            ratio_range=args.ratio,
+            scale_range=args.scale,
+            hflip_prob=args.hflip,
+            vflip_prob=args.vflip,
+            color_jitter=args.color_jitter,
+            auto_augment=args.aa,
+            num_aug_splits=args.aug_splits,
+        )
+
+    train_pp_cfg = PreprocessCfg(
         input_size=data_config['input_size'],
-        batch_size=args.batch_size,
-        is_training=True,
-        no_aug=args.no_aug,
-        re_prob=args.reprob,
-        re_mode=args.remode,
-        re_count=args.recount,
-        re_split=args.resplit,
-        scale=args.scale,
-        ratio=args.ratio,
-        hflip=args.hflip,
-        vflip=args.vflip,
-        color_jitter=args.color_jitter,
-        auto_augment=args.aa,
-        num_aug_splits=args.aug_splits,
         interpolation=train_interpolation,
+        crop_pct=data_config['crop_pct'],
         mean=data_config['mean'],
         std=data_config['std'],
+        aug=train_aug_cfg,
+    )
+
+    # if using PyTorch XLA and RandomErasing is enabled, we must normalize and do RE in transforms on CPU
+    normalize_in_transform = dev_env.type_xla and args.reprob > 0
+
+    dataset_train.transform = create_transform_v2(
+        cfg=train_pp_cfg, is_training=True, normalize=normalize_in_transform)
+
+    loader_train = create_loader_v2(
+        dataset_train,
+        batch_size=args.batch_size,
+        is_training=True,
+        normalize=not normalize_in_transform,
+        pp_cfg=train_pp_cfg,
+        mix_cfg=mixup_cfg,
         num_workers=args.workers,
-        collate_fn=collate_fn,
         pin_memory=args.pin_mem,
         use_multi_epochs_loader=args.use_multi_epochs_loader
     )
+
+    eval_pp_cfg = PreprocessCfg(
+        input_size=data_config['input_size'],
+        interpolation=data_config['interpolation'],
+        crop_pct=data_config['crop_pct'],
+        mean=data_config['mean'],
+        std=data_config['std'],
+    )
+
+    dataset_eval.transform = create_transform_v2(
+        cfg=eval_pp_cfg, is_training=False, normalize=normalize_in_transform)

     eval_workers = args.workers
     if 'tfds' in args.dataset:
         # FIXME reduce validation issues when using TFDS w/ workers and distributed training
         eval_workers = min(2, args.workers)
-    loader_eval = create_loader(
+    loader_eval = create_loader_v2(
         dataset_eval,
-        input_size=data_config['input_size'],
         batch_size=args.validation_batch_size_multiplier * args.batch_size,
         is_training=False,
-        interpolation=data_config['interpolation'],
-        mean=data_config['mean'],
-        std=data_config['std'],
+        normalize=not normalize_in_transform,
+        pp_cfg=eval_pp_cfg,
         num_workers=eval_workers,
-        crop_pct=data_config['crop_pct'],
         pin_memory=args.pin_mem,
     )
     return data_config, loader_eval, loader_train
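End to end, the new flow is: build a PreprocessCfg (optionally carrying an AugCfg), attach the resulting transform to the dataset via create_transform_v2, then hand the same cfg to create_loader_v2, which only normalizes on the device when the transform did not. A hedged sketch under those assumptions (paths, sizes, and augmentation values are illustrative; the *_v2 names and cfg dataclasses exist only on this branch):

from timm.data import (
    ImageDataset, PreprocessCfg, AugCfg, create_transform_v2, create_loader_v2)

aug_cfg = AugCfg(hflip_prob=0.5, re_prob=0.25, auto_augment='rand-m9-mstd0.5')
pp_cfg = PreprocessCfg(
    input_size=(3, 224, 224),
    interpolation='bicubic',
    crop_pct=0.875,
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
    aug=aug_cfg,
)

dataset = ImageDataset('/data/train')  # hypothetical path
# On XLA with RandomErasing enabled this would be True, pushing normalize + RE
# into the CPU-side transform; on CUDA it stays in the loader's device path.
normalize_in_transform = False
dataset.transform = create_transform_v2(
    cfg=pp_cfg, is_training=True, normalize=normalize_in_transform)
loader = create_loader_v2(
    dataset,
    batch_size=128,
    is_training=True,
    normalize=not normalize_in_transform,
    pp_cfg=pp_cfg,
    num_workers=4,
)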
@@ -700,8 +726,12 @@ def evaluate(
             loss = loss_fn(output, target)

             # FIXME, explicitly marking step for XLA use since I'm not using the parallel xm loader
-            # need to investigate whether parallel loader wrapper is helpful on tpu-vm or only usefor for 2-vm setup.
-            dev_env.mark_step()
+            # need to investigate whether parallel loader wrapper is helpful on tpu-vm or only useful for 2-vm setup.
+            if dev_env.type_xla:
+                dev_env.mark_step()
+            elif dev_env.type_cuda:
+                dev_env.synchronize()
+
             tracker.mark_iter_step_end()
             losses_m.update(loss, output.size(0))
             accuracy_m.update(output, target)
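The same end-of-step pattern applies anywhere timing matters: XLA traces ops lazily, so mark_step() forces the accumulated graph to execute, while CUDA launches kernels asynchronously, so synchronize() makes step timings honest. A minimal sketch assuming the DeviceEnv interface shown above:

def end_of_step(dev_env):
    # Force pending device work to complete so step timing reflects real compute.
    if dev_env.type_xla:
        dev_env.mark_step()       # flush/execute the lazily accumulated XLA graph
    elif dev_env.type_cuda:
        dev_env.synchronize()     # block until queued CUDA kernels finish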

@@ -20,7 +20,8 @@ from collections import OrderedDict
 from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
-from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
+from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config, RealLabelsImagenet, \
+    PreprocessCfg
 from timm.utils import natural_key, setup_default_logging
@@ -141,18 +142,22 @@ def validate(args):
     else:
         real_labels = None

-    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
-    loader = create_loader(
-        dataset,
+    eval_pp_cfg = PreprocessCfg(
         input_size=data_config['input_size'],
-        batch_size=args.batch_size,
         interpolation=data_config['interpolation'],
+        crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
         mean=data_config['mean'],
         std=data_config['std'],
+    )
+
+    dataset.transform = create_transform_v2(cfg=eval_pp_cfg, is_training=False)
+
+    loader = create_loader_v2(
+        dataset,
+        batch_size=args.batch_size,
+        pp_cfg=eval_pp_cfg,
         num_workers=args.workers,
-        crop_pct=crop_pct,
-        pin_memory=args.pin_mem,
-        tf_preprocessing=args.tf_preprocessing)
+        pin_memory=args.pin_mem)

     logger = Monitor(logger=_logger)
     tracker = Tracker()
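At validation time the PreprocessCfg carries no AugCfg, and test-time pooling forces a full-image crop. A hedged sketch (sizes and normalization constants are illustrative):

from timm.data import PreprocessCfg, create_transform_v2

test_time_pool = True
eval_pp_cfg = PreprocessCfg(
    input_size=(3, 224, 224),
    interpolation='bicubic',
    crop_pct=1.0 if test_time_pool else 0.875,  # pool over the full image when active
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225),
)
transform = create_transform_v2(cfg=eval_pp_cfg, is_training=False)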
@@ -175,16 +180,17 @@ def validate(args):
             loss = criterion(output, target)

             if dev_env.type_cuda:
-                torch.cuda.synchronize()
+                dev_env.synchronize()
             tracker.mark_iter_step_end()

-            losses.update(loss.detach(), sample.size(0))
+            if dev_env.type_xla:
+                dev_env.mark_step()

             if real_labels is not None:
                 real_labels.add_result(output)
-            accuracy.update(output.detach(), target)

-            if dev_env.type_xla:
-                dev_env.mark_step()
+            losses.update(loss.detach(), sample.size(0))
+            accuracy.update(output.detach(), target)

             tracker.mark_iter()
             if step_idx % args.log_freq == 0:
@@ -212,7 +218,7 @@ def validate(args):
         top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
         param_count=round(param_count / 1e6, 2),
         img_size=data_config['input_size'][-1],
-        cropt_pct=crop_pct,
+        cropt_pct=eval_pp_cfg.crop_pct,
         interpolation=data_config['interpolation'])

     logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results)
