Transforms and augmentation work for bits: add RandomErasing support for XLA (pushing it into the transforms), revamp of the transform/preprocessing config, etc. Ongoing...

pull/1239/head
Ross Wightman 3 years ago
parent 847b4af144
commit 40457e5691

@ -14,6 +14,9 @@ import hashlib
import shutil
from collections import OrderedDict
from timm.models.helpers import load_state_dict
from timm.utils import setup_default_logging
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
@ -29,6 +32,7 @@ _TEMP_NAME = './_checkpoint.pth'
def main():
args = parser.parse_args()
setup_default_logging()
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
@ -37,17 +41,8 @@ def main():
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> Loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.checkpoint, map_location='cpu')
state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
new_state_dict = OrderedDict()
if isinstance(checkpoint, dict):
state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
if state_dict_key in checkpoint:
state_dict = checkpoint[state_dict_key]
else:
state_dict = checkpoint
else:
assert False
for k, v in state_dict.items():
if args.clean_aux_bn and 'aux_bn' in k:
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and

@ -13,7 +13,7 @@ import numpy as np
import torch
from timm.models import create_model, apply_test_time_pool
from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.data import ImageDataset, create_loader_v2, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -82,7 +82,7 @@ def main():
else:
model = model.cuda()
loader = create_loader(
loader = create_loader_v2(
ImageDataset(args.data),
input_size=config['input_size'],
batch_size=args.batch_size,

@ -128,6 +128,9 @@ class DeviceEnv:
def mark_step(self):
pass # NO-OP for non-XLA devices
def synchronize(self, tensors: Optional[TensorList] = None):
pass
def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
dist.all_reduce(tensor, op=op)
if average:
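
The hunk above is cut off at the averaging branch. A minimal standalone sketch of the pattern, assuming the branch divides the summed tensor by the process count (the usual way to turn a SUM all-reduce into a mean):

import torch
import torch.distributed as dist

def all_reduce_(tensor: torch.Tensor, op=dist.ReduceOp.SUM, average: bool = False):
    # Reduce the tensor in-place across all processes.
    dist.all_reduce(tensor, op=op)
    if average:
        # Assumption: the truncated branch divides the SUM result by the
        # world size to produce the mean.
        tensor.div_(dist.get_world_size())
    return tensor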

@ -6,7 +6,7 @@ from typing import Optional
import torch
from torch.nn.parallel import DistributedDataParallel, DataParallel
from .device_env import DeviceEnv, DeviceEnvType
from .device_env import DeviceEnv, DeviceEnvType, TensorList
def is_cuda_available():
@ -63,3 +63,6 @@ class DeviceEnvCuda(DeviceEnv):
assert not self.distributed
wrapped = [DataParallel(m, **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def synchronize(self, tensors: Optional[TensorList] = None):
torch.cuda.synchronize(self.device)

@ -8,9 +8,11 @@ from torch.distributed import ReduceOp
try:
import torch_xla.core.xla_model as xm
import torch_xla
_HAS_XLA = True
except ImportError as e:
xm = None
torch_xla = None
_HAS_XLA = False
try:
@ -81,6 +83,9 @@ class DeviceEnvXla(DeviceEnv):
def mark_step(self):
xm.mark_step()
def synchronize(self, tensors: Optional[TensorList] = None):
torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
assert isinstance(tensor, torch.Tensor) # unlike in-place variant, lists/tuples not allowed
op = _PT_TO_XM_OP[op]
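
For context on why mark_step() matters here: XLA tensors are lazy, so nothing executes until the recorded graph is cut. A minimal single-device sketch using plain torch_xla (not the DeviceEnv wrapper in this diff):

import torch
import torch.nn.functional as F
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 2).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(10):
    x = torch.randn(4, 8, device=device)
    y = torch.randint(0, 2, (4,), device=device)
    loss = F.cross_entropy(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    # XLA records ops lazily; mark_step() cuts the graph here and
    # triggers compilation + execution of everything queued so far.
    xm.mark_step()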

@ -89,7 +89,6 @@ def setup_model_and_optimizer(
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
# FIXME this is not implemented yet, do a hack job before proper TrainState serialization?
load_train_state(
train_state,
resume_path,
@ -141,11 +140,7 @@ def setup_model_and_optimizer_deepspeed(
if resume_path:
# FIXME deepspeed resumes differently
load_legacy_checkpoint(
train_state,
resume_path,
load_opt=resume_opt,
log_info=dev_env.primary)
assert False
if dev_env.distributed:
train_state = dataclasses.replace(

@ -4,9 +4,9 @@ from .config import resolve_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .loader import create_loader_v2, PreprocessCfg, AugCfg, MixupCfg
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform
from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy
from .transforms_factory import create_transform_v2, create_transform

@ -41,6 +41,22 @@ _HPARAMS_DEFAULT = dict(
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
def _pil_interp(method):
def _convert(m):
if method == 'bicubic':
return Image.BICUBIC
elif method == 'lanczos':
return Image.LANCZOS
elif method == 'hamming':
return Image.HAMMING
else:
return Image.BILINEAR
if isinstance(method, (list, tuple)):
return [_convert(m) if isinstance(m, str) else m for m in method]
else:
return _convert(method) if isinstance(method, str) else method
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', Image.BILINEAR)
if isinstance(interpolation, (list, tuple)):
@ -325,7 +341,7 @@ class AugmentOp:
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
resample=_pil_interp(hparams['interpolation']) if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
# If magnitude_std is > 0, we introduce some randomness
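
A quick illustration of the string-to-PIL mapping that the new resample= line relies on (import path per this diff; _pil_interp is private and used here only for illustration):

from PIL import Image
from timm.data.auto_augment import _pil_interp

assert _pil_interp('bicubic') == Image.BICUBIC
assert _pil_interp('unrecognized') == Image.BILINEAR      # falls back to bilinear
# lists/tuples convert element-wise; non-string entries pass through unchanged
assert _pil_interp(['lanczos', Image.HAMMING]) == [Image.LANCZOS, Image.HAMMING]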

@ -30,7 +30,7 @@ def fast_collate(batch):
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=batch[0][0].dtype)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
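
The dtype change above matters once transforms emit normalized float tensors instead of uint8 images (e.g. the ToTensorNormalize transform added later in this commit). A self-contained sketch of the fixed behaviour:

import torch

# two (image, label) samples already converted to float tensors
batch = [(torch.randn(3, 8, 8), 0), (torch.randn(3, 8, 8), 1)]

targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
# dtype now follows the input instead of being hard-coded to torch.uint8
tensor = torch.zeros((len(batch), *batch[0][0].shape), dtype=batch[0][0].dtype)
for i in range(len(batch)):
    tensor[i].copy_(batch[i][0])

assert tensor.dtype == torch.float32   # previously truncated to uint8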

@ -1,10 +1,53 @@
import logging
from dataclasses import dataclass
from typing import Tuple, Optional, Union
from .constants import *
_logger = logging.getLogger(__name__)
@dataclass
class AugCfg:
scale_range: Tuple[float, float] = (0.08, 1.0)
ratio_range: Tuple[float, float] = (3 / 4, 4 / 3)
hflip_prob: float = 0.5
vflip_prob: float = 0.
color_jitter: float = 0.4
auto_augment: Optional[str] = None
re_prob: float = 0.
re_mode: str = 'const'
re_count: int = 1
num_aug_splits: int = 0
@dataclass
class PreprocessCfg:
input_size: Tuple[int, int, int] = (3, 224, 224)
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD
interpolation: str = 'bilinear'
crop_pct: float = 0.875
aug: AugCfg = None
@dataclass
class MixupCfg:
prob: float = 1.0
switch_prob: float = 0.5
mixup_alpha: float = 1.
cutmix_alpha: float = 0.
cutmix_minmax: Optional[Tuple[float, float]] = None
mode: str = 'batch'
correct_lam: bool = True
label_smoothing: float = 0.1
num_classes: int = 0
def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
new_config = {}
default_cfg = default_cfg
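
A sketch of wiring the new config dataclasses together, mirroring how train.py uses them later in this commit (the RandAugment policy string is just an example value):

from timm.data import PreprocessCfg, AugCfg, MixupCfg

aug_cfg = AugCfg(
    hflip_prob=0.5,
    color_jitter=0.4,
    auto_augment='rand-m9-mstd0.5',   # example RandAugment policy
    re_prob=0.25,                     # RandomErasing probability
)
pp_cfg = PreprocessCfg(
    input_size=(3, 224, 224),
    interpolation='bicubic',
    crop_pct=0.875,
    aug=aug_cfg,                      # aug=None disables train-time augmentation
)
mix_cfg = MixupCfg(
    mixup_alpha=0.2, cutmix_alpha=1.0, label_smoothing=0.1, num_classes=1000)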

@ -2,7 +2,7 @@ import torch
from .constants import *
from .random_erasing import RandomErasing
from. mixup import FastCollateMixup
from .mixup import FastCollateMixup
class FetcherXla:
@ -12,31 +12,55 @@ class FetcherXla:
class Fetcher:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
device=None,
dtype=None,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
def __init__(
self,
loader,
device: torch.device,
dtype=torch.float32,
normalize=True,
normalize_shape=(1, 3, 1, 1),
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
num_aug_splits=0,
use_mp_loader=False,
):
self.loader = loader
self.device = torch.device(device)
self.dtype = dtype or torch.float32
self.mean = torch.tensor([x * 255 for x in mean], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std], dtype=self.dtype, device=self.device).view(1, 3, 1, 1)
self.dtype = dtype
if normalize:
self.mean = torch.tensor(
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
self.std = torch.tensor(
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
else:
self.mean = None
self.std = None
if re_prob > 0.:
# NOTE RandomErasing shouldn't be used here w/ XLA devices
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device=device)
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
else:
self.random_erasing = None
self.use_mp_loader = use_mp_loader
if use_mp_loader:
# FIXME testing for TPU use
import torch_xla.distributed.parallel_loader as pl
self._loader = pl.MpDeviceLoader(loader, device)
else:
self._loader = loader
print('re', self.random_erasing, self.mean, self.std)
def __iter__(self):
for sample, target in self.loader:
sample = sample.to(device=self.device, dtype=self.dtype).sub_(self.mean).div_(self.std)
target = target.to(device=self.device)
for sample, target in self._loader:
if not self.use_mp_loader:
sample = sample.to(device=self.device)
target = target.to(device=self.device)
sample = sample.to(dtype=self.dtype)
if self.mean is not None:
sample.sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
sample = self.random_erasing(sample)
yield sample, target
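
A note on the x * 255 scaling of mean/std above: fast_collate hands the fetcher uint8 batches in 0..255, so scaling the 0..1 ImageNet constants by 255 lets conversion and normalization happen in one fused on-device step. A small sketch verifying the equivalence:

import torch

MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)

x = torch.randint(0, 256, (2, 3, 4, 4), dtype=torch.uint8)  # collated uint8 batch
mean = torch.tensor([m * 255 for m in MEAN]).view(1, 3, 1, 1)
std = torch.tensor([s * 255 for s in STD]).view(1, 3, 1, 1)

# one fused step, as in Fetcher.__iter__
y = x.float().sub_(mean).div_(std)

# equivalent to the classic ToTensor() -> Normalize(MEAN, STD) pair
ref = ((x.float() / 255) - torch.tensor(MEAN).view(1, 3, 1, 1)) \
    / torch.tensor(STD).view(1, 3, 1, 1)
assert torch.allclose(y, ref, atol=1e-5)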

@ -6,74 +6,52 @@ https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#d
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Tuple, Optional, Union, Callable
import torch.utils.data
from timm.bits import DeviceEnv
from .fetcher import Fetcher
from .prefetcher_cuda import PrefetcherCuda
from .collate import fast_collate
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .config import PreprocessCfg, AugCfg, MixupCfg
from .distributed_sampler import OrderedDistributedSampler
from .fetcher import Fetcher
from .mixup import FastCollateMixup
from .prefetcher_cuda import PrefetcherCuda
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
dev_env=None,
no_aug=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_split=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
num_aug_splits=0,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
crop_pct=None,
collate_fn=None,
pin_memory=False,
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
def create_loader_v2(
dataset: torch.utils.data.Dataset,
batch_size: int,
is_training: bool = False,
dev_env: Optional[DeviceEnv] = None,
normalize=True,
pp_cfg: PreprocessCfg = PreprocessCfg(),
mix_cfg: MixupCfg = None,
num_workers: int = 1,
collate_fn: Optional[Callable] = None,
pin_memory: bool = False,
use_multi_epochs_loader: bool = False,
persistent_workers: bool = True,
):
re_num_splits = 0
if re_split:
# apply RE to second half of batch if no aug split otherwise line up with aug split
re_num_splits = num_aug_splits or 2
dataset.transform = create_transform(
input_size,
is_training=is_training,
use_fetcher=True,
no_aug=no_aug,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=num_aug_splits > 0,
)
"""
Args:
dataset:
batch_size:
is_training:
dev_env:
normalize:
pp_cfg:
mix_cfg:
num_workers:
collate_fn:
pin_memory:
use_multi_epochs_loader:
persistent_workers:
Returns:
"""
if dev_env is None:
dev_env = DeviceEnv.instance()
@ -85,10 +63,24 @@ def create_loader(
else:
# This will add extra duplicate entries to result in equal num
# of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
sampler = OrderedDistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
if collate_fn is None:
collate_fn = fast_collate
if mix_cfg is not None and mix_cfg.prob > 0:
collate_fn = FastCollateMixup(
mixup_alpha=mix_cfg.mixup_alpha,
cutmix_alpha=mix_cfg.cutmix_alpha,
cutmix_minmax=mix_cfg.cutmix_minmax,
prob=mix_cfg.prob,
switch_prob=mix_cfg.switch_prob,
mode=mix_cfg.mode,
correct_lam=mix_cfg.correct_lam,
label_smoothing=mix_cfg.label_smoothing,
num_classes=mix_cfg.num_classes,
)
else:
collate_fn = fast_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
@ -110,13 +102,18 @@ def create_loader(
loader = loader_class(dataset, **loader_args)
fetcher_kwargs = dict(
mean=mean,
std=std,
re_prob=re_prob if is_training and not no_aug else 0.,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
normalize=normalize,
mean=pp_cfg.mean,
std=pp_cfg.std,
)
if normalize and is_training and pp_cfg.aug is not None:
fetcher_kwargs.update(dict(
re_prob=pp_cfg.aug.re_prob,
re_mode=pp_cfg.aug.re_mode,
re_count=pp_cfg.aug.re_count,
num_aug_splits=pp_cfg.aug.num_aug_splits,
))
if dev_env.type_cuda:
loader = PrefetcherCuda(loader, **fetcher_kwargs)
else:
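
Putting the pieces together, a minimal end-to-end sketch of the v2 API as used by validate.py/train.py later in this commit (the data path is a placeholder; assumes the bits device environment has already been initialized):

from timm.data import (ImageDataset, PreprocessCfg, AugCfg,
                       create_transform_v2, create_loader_v2)

pp_cfg = PreprocessCfg(
    input_size=(3, 224, 224),
    interpolation='bicubic',
    aug=AugCfg(re_prob=0.25),
)
dataset = ImageDataset('/path/to/train')   # placeholder data dir
# normalize left False here, so the fetcher/prefetcher normalizes on device
dataset.transform = create_transform_v2(cfg=pp_cfg, is_training=True)
loader = create_loader_v2(
    dataset,
    batch_size=32,
    is_training=True,
    pp_cfg=pp_cfg,
    num_workers=4,
)
for sample, target in loader:
    pass   # sample arrives on-device, normalized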

@ -102,7 +102,7 @@ class Mixup:
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
mode='batch', correct_lam=True, label_smoothing=0., num_classes=0):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
@ -113,6 +113,8 @@ class Mixup:
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
if label_smoothing > 0.:
assert num_classes > 0
self.num_classes = num_classes
self.mode = mode
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
@ -218,17 +220,30 @@ class Mixup:
return x, target
def blend(a, b, lam, is_tensor=False, round_output=True):
if is_tensor:
blend = a.to(dtype=torch.float32) * lam + b.to(dtype=torch.float32) * (1 - lam)
if round_output:
torch.round(blend, out=blend)
else:
blend = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
if round_output:
np.rint(blend, out=blend)
return blend
class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
A Mixup impl that's performed while collating the batches.
"""
def _mix_elem_collate(self, output, batch, half=False):
def _mix_elem_collate(self, output, batch, half=False, is_tensor=False):
batch_size = len(batch)
num_elem = batch_size // 2 if half else batch_size
assert len(output) == num_elem
lam_batch, use_cutmix = self._params_per_elem(num_elem)
round_output = output.dtype == torch.uint8
for i in range(num_elem):
j = batch_size - i - 1
lam = lam_batch[i]
@ -236,22 +251,23 @@ class FastCollateMixup(Mixup):
if lam != 1.:
if use_cutmix[i]:
if not half:
mixed = mixed.copy()
mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
output[i].copy_(mixed)
if half:
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
return torch.tensor(lam_batch).unsqueeze(1)
def _mix_pair_collate(self, output, batch):
def _mix_pair_collate(self, output, batch, is_tensor=False):
batch_size = len(batch)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
round_output = output.dtype == torch.uint8
for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
@ -262,24 +278,30 @@ class FastCollateMixup(Mixup):
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
patch_i = mixed_i[:, yl:yh, xl:xh]
patch_i = patch_i.clone() if is_tensor else patch_i.copy() # don't want to modify while iterating
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
mixed_j[:, yl:yh, xl:xh] = patch_i
lam_batch[i] = lam
else:
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
mixed_temp = blend(mixed_i, mixed_j, lam, is_tensor, round_output)
mixed_j = blend(mixed_j, mixed_i, lam, is_tensor, round_output)
mixed_i = mixed_temp
np.rint(mixed_j, out=mixed_j)
np.rint(mixed_i, out=mixed_i)
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
if is_tensor:
mixed_i = mixed_i.to(dtype=output.dtype)
mixed_j = mixed_j.to(dtype=output.dtype)
else:
mixed_i = torch.from_numpy(mixed_i.astype(np.uint8))
mixed_j = torch.from_numpy(mixed_j.astype(np.uint8))
output[i].copy_(mixed_i)
output[j].copy_(mixed_j)
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return torch.tensor(lam_batch).unsqueeze(1)
def _mix_batch_collate(self, output, batch):
def _mix_batch_collate(self, output, batch, is_tensor=False):
batch_size = len(batch)
lam, use_cutmix = self._params_per_batch()
round_output = output.dtype == torch.uint8
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
@ -288,12 +310,12 @@ class FastCollateMixup(Mixup):
mixed = batch[i][0]
if lam != 1.:
if use_cutmix:
mixed = mixed.copy() # don't want to modify the original while iterating
mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
output[i].copy_(mixed)
return lam
def __call__(self, batch, _=None):
@ -302,13 +324,15 @@ class FastCollateMixup(Mixup):
half = 'half' in self.mode
if half:
batch_size //= 2
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
is_tensor = isinstance(batch[0][0], torch.Tensor)
output_dtype = batch[0][0].dtype if is_tensor else torch.uint8 # always uint8 if numpy src
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=output_dtype)
if self.mode == 'elem' or self.mode == 'half':
lam = self._mix_elem_collate(output, batch, half=half)
lam = self._mix_elem_collate(output, batch, half=half, is_tensor=is_tensor)
elif self.mode == 'pair':
lam = self._mix_pair_collate(output, batch)
lam = self._mix_pair_collate(output, batch, is_tensor=is_tensor)
else:
lam = self._mix_batch_collate(output, batch)
lam = self._mix_batch_collate(output, batch, is_tensor=is_tensor)
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
target = target[:batch_size]
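
The blend() helper above unifies what were separate numpy and tensor mixing paths; rounding only applies when the collate output is uint8. Restated standalone to show the two paths agree:

import numpy as np
import torch

def blend(a, b, lam, is_tensor=False, round_output=True):
    # same logic as the helper added in this diff
    if is_tensor:
        out = a.to(dtype=torch.float32) * lam + b.to(dtype=torch.float32) * (1 - lam)
        if round_output:
            torch.round(out, out=out)
    else:
        out = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
        if round_output:
            np.rint(out, out=out)
    return out

a = np.random.randint(0, 256, (3, 4, 4), dtype=np.uint8)
b = np.random.randint(0, 256, (3, 4, 4), dtype=np.uint8)
t = blend(torch.from_numpy(a), torch.from_numpy(b), 0.3, is_tensor=True)
n = blend(a, b, 0.3, is_tensor=False)
assert np.allclose(t.numpy(), n, atol=1)   # paths agree up to rounding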

@ -7,25 +7,34 @@ from .random_erasing import RandomErasing
class PrefetcherCuda:
def __init__(self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
def __init__(
self,
loader,
device: torch.device = torch.device('cuda'),
dtype=torch.float32,
normalize=True,
normalize_shape=(1, 3, 1, 1),
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_aug_splits=0,
re_prob=0.,
re_mode='const',
re_count=1
):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
self.device = device
self.dtype = dtype
if normalize:
self.mean = torch.tensor(
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
self.std = torch.tensor(
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
else:
self.mean = None
self.std = None
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits, device=device)
else:
self.random_erasing = None
@ -35,12 +44,11 @@ class PrefetcherCuda:
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
next_input = next_input.to(device=self.device, non_blocking=True)
next_input = next_input.to(dtype=self.dtype)
if self.mean is not None:
next_input.sub_(self.mean).div_(self.std)
next_target = next_target.to(device=self.device, non_blocking=True)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
@ -76,4 +84,4 @@ class PrefetcherCuda:
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
self.loader.collate_fn.mixup_enabled = x

@ -38,21 +38,20 @@ class RandomErasing:
'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-channel random (normal) color
'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
"""
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
mode='const', count=1, num_splits=0):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.min_count = min_count
self.max_count = max_count or min_count
self.count = count
self.num_splits = num_splits
mode = mode.lower()
self.rand_color = False
@ -63,14 +62,13 @@ class RandomErasing:
self.per_pixel = True # per pixel random normal
else:
assert not mode or mode == 'const'
self.device = device
def _erase(self, img, chan, img_h, img_w, dtype):
device = img.device
if random.random() > self.probability:
return
area = img_h * img_w
count = self.min_count if self.min_count == self.max_count else \
random.randint(self.min_count, self.max_count)
count = random.randint(1, self.count) if self.count > 1 else self.count
for _ in range(count):
for attempt in range(10):
target_area = random.uniform(self.min_area, self.max_area) * area / count
@ -81,17 +79,76 @@ class RandomErasing:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=device)
break
def __call__(self, input):
if len(input.size()) == 3:
self._erase(input, *input.size(), input.dtype)
def __call__(self, x):
if len(x.size()) == 3:
self._erase(x, *x.shape, x.dtype)
else:
batch_size, chan, img_h, img_w = input.size()
batch_size, chan, img_h, img_w = x.shape
# skip first slice of batch if num_splits is set (for clean portion of samples)
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
for i in range(batch_start, batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype)
return input
self._erase(x[i], chan, img_h, img_w, x.dtype)
return x
class RandomErasingMasked:
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
This variant of RandomErasing is intended to be applied to either a batch
or single image tensor after it has been normalized by dataset mean and std.
Args:
probability: Probability that the Random Erasing operation will be performed for each box (count)
min_area: Minimum percentage of erased area wrt input image area.
max_area: Maximum percentage of erased area wrt input image area.
min_aspect: Minimum aspect ratio of erased area.
count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is between 0 and this value.
"""
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', count=1, num_splits=0):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.mode = mode # FIXME currently ignored, add back options besides normal mean=0, std=1 noise?
self.count = count
self.num_splits = num_splits
@torch.no_grad()
def __call__(self, x: torch.Tensor) -> torch.Tensor:
device = x.device
batch_size, _, img_h, img_w = x.shape
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
# NOTE simplified from v1 with one count value and the same prob applied for all
enable = (torch.empty((batch_size, self.count), device=device).uniform_() < self.probability).float()
enable = enable / enable.sum(dim=1, keepdim=True).clamp(min=1)
target_area = torch.empty(
(batch_size, self.count), device=device).uniform_(self.min_area, self.max_area) * enable
aspect_ratio = torch.empty((batch_size, self.count), device=device).uniform_(*self.log_aspect_ratio).exp()
h_coord = torch.arange(0, img_h, device=device).unsqueeze(-1).expand(-1, self.count).float()
w_coord = torch.arange(0, img_w, device=device).unsqueeze(-1).expand(-1, self.count).float()
h_mid = torch.rand((batch_size, self.count), device=device) * img_h
w_mid = torch.rand((batch_size, self.count), device=device) * img_w
noise = torch.empty_like(x[0]).normal_()
for i in range(batch_start, batch_size):
h_half = (img_h / 2) * torch.sqrt(target_area[i] * aspect_ratio[i]) # 1/2 box h
h_mask = (h_coord > (h_mid[i] - h_half)) & (h_coord < (h_mid[i] + h_half))
w_half = (img_w / 2) * torch.sqrt(target_area[i] / aspect_ratio[i]) # 1/2 box w
w_mask = (w_coord > (w_mid[i] - w_half)) & (w_coord < (w_mid[i] + w_half))
#mask = (h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
#x[i].copy_(torch.where(mask, noise, x[i]))
mask = ~(h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
x[i] = x[i].where(mask, noise)
#x[i].masked_scatter_(mask, noise)
return x
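
Usage sketch for the new masked variant (import path per this diff): unlike the v1 class it expects a 4-D batch that has already been normalized, and erases boxes with unit-normal noise:

import torch
from timm.data.random_erasing import RandomErasingMasked

erase = RandomErasingMasked(probability=0.5, count=3)
x = torch.randn(8, 3, 224, 224)   # batch normalized to roughly mean 0 / std 1
x = erase(x)                      # up to `count` noise boxes per image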

@ -1,5 +1,7 @@
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
from PIL import Image
import warnings
import math
@ -30,29 +32,40 @@ class ToTensor:
return torch.from_numpy(np_img).to(dtype=self.dtype)
_pil_interpolation_to_str = {
Image.NEAREST: 'PIL.Image.NEAREST',
Image.BILINEAR: 'PIL.Image.BILINEAR',
Image.BICUBIC: 'PIL.Image.BICUBIC',
Image.LANCZOS: 'PIL.Image.LANCZOS',
Image.HAMMING: 'PIL.Image.HAMMING',
Image.BOX: 'PIL.Image.BOX',
}
class ToTensorNormalize:
def __init__(self, mean, std, dtype=torch.float32, device=torch.device('cpu')):
self.dtype = dtype
mean = torch.as_tensor(mean, dtype=dtype, device=device)
std = torch.as_tensor(std, dtype=dtype, device=device)
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
self.mean = mean
self.std = std
def _pil_interp(method):
if method == 'bicubic':
return Image.BICUBIC
elif method == 'lanczos':
return Image.LANCZOS
elif method == 'hamming':
return Image.HAMMING
else:
# default bilinear, do we want to allow nearest?
return Image.BILINEAR
def __call__(self, pil_img):
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(pil_img, mode_to_nptype.get(pil_img.mode, np.uint8))
)
if pil_img.mode == '1':
img = 255 * img
img = img.view(pil_img.size[1], pil_img.size[0], len(pil_img.getbands()))
img = img.permute((2, 0, 1))
if isinstance(img, torch.ByteTensor):
img = img.to(self.dtype)
img.sub_(self.mean * 255.).div_(self.std * 255.)
else:
img = img.to(self.dtype)
img.sub_(self.mean).div_(self.std)
return img
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
class RandomResizedCropAndInterpolation:
@ -82,7 +95,7 @@ class RandomResizedCropAndInterpolation:
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = _pil_interp(interpolation)
self.interpolation = InterpolationMode(interpolation)
self.scale = scale
self.ratio = ratio
@ -146,9 +159,9 @@ class RandomResizedCropAndInterpolation:
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
interpolate_str = ' '.join([x.value for x in self.interpolation])
else:
interpolate_str = _pil_interpolation_to_str[self.interpolation]
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
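
An eval-style pipeline sketch showing where ToTensorNormalize slots in, replacing the former ToTensor() + Normalize() pair with one fused conversion (constants imported per this repo's layout):

from PIL import Image
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import ToTensorNormalize

tfl = transforms.Compose([
    transforms.Resize(256, transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(224),
    # single fused step: PIL -> tensor -> normalized float32
    ToTensorNormalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
])
img = Image.new('RGB', (320, 320))   # stand-in image
x = tfl(img)                         # CHW float32, normalized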

@ -4,59 +4,50 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from typing import Union, Tuple
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import _pil_interp, RandomResizedCropAndInterpolation, ToNumpy, ToTensor
from timm.data.config import PreprocessCfg, AugCfg
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.random_erasing import RandomErasing
from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensorNormalize
def transforms_noaug_train(
img_size=224,
img_size: Union[int, Tuple[int]] = 224,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
normalize=False,
):
if interpolation == 'random':
# random interpolation not supported with no-aug
interpolation = 'bilinear'
tfl = [
transforms.Resize(img_size, _pil_interp(interpolation)),
transforms.Resize(img_size, transforms.InterpolationMode(interpolation)),
transforms.CenterCrop(img_size)
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
if normalize:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
]
else:
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl += [ToNumpy()]
return transforms.Compose(tfl)
def transforms_imagenet_train(
img_size=224,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
img_size: Union[int, Tuple[int]] = 224,
interpolation='random',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
aug_cfg=AugCfg(),
normalize=False,
separate=False,
):
"""
@ -66,18 +57,24 @@ def transforms_imagenet_train(
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
scale_range = tuple(aug_cfg.scale_range or (0.08, 1.0)) # default imagenet scale range
ratio_range = tuple(aug_cfg.ratio_range or (3. / 4., 4. / 3.)) # default imagenet ratio range
# 'primary' train transforms include random resize + crop w/ optional horizontal and vertical flipping aug.
# This is the core of standard ImageNet ResNet and Inception pre-processing
primary_tfl = [
RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
RandomResizedCropAndInterpolation(img_size, scale=scale_range, ratio=ratio_range, interpolation=interpolation)]
if aug_cfg.hflip_prob > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=aug_cfg.hflip_prob)]
if aug_cfg.vflip_prob > 0.:
primary_tfl += [transforms.RandomVerticalFlip(p=aug_cfg.vflip_prob)]
# 'secondary' transform stage includes either color jitter (could add lighting too) or auto-augmentations
# such as AutoAugment, RandAugment, AugMix, etc
secondary_tfl = []
if auto_augment:
assert isinstance(auto_augment, str)
if aug_cfg.auto_augment:
aa = aug_cfg.auto_augment
assert isinstance(aa, str)
if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size)
else:
@ -87,58 +84,63 @@ def transforms_imagenet_train(
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = _pil_interp(interpolation)
if auto_augment.startswith('rand'):
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
elif auto_augment.startswith('augmix'):
aa_params['interpolation'] = interpolation
if aa.startswith('rand'):
secondary_tfl += [rand_augment_transform(aa, aa_params)]
elif aa.startswith('augmix'):
aa_params['translate_pct'] = 0.3
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
secondary_tfl += [augment_and_mix_transform(aa, aa_params)]
else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
elif color_jitter is not None:
secondary_tfl += [auto_augment_transform(aa, aa_params)]
elif aug_cfg.color_jitter is not None:
# color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)):
cj = aug_cfg.color_jitter
if isinstance(cj, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
assert len(cj) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
cj = (float(cj),) * 3
secondary_tfl += [transforms.ColorJitter(*cj)]
# 'final' transform stage includes normalization, followed by optional random erasing and tensor conversion
final_tfl = []
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()]
else:
if normalize:
final_tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
ToTensorNormalize(mean=mean, std=std)
]
if re_prob > 0.:
final_tfl.append(
RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
if aug_cfg.re_prob > 0.:
final_tfl.append(RandomErasing(
aug_cfg.re_prob,
mode=aug_cfg.re_mode,
count=aug_cfg.re_count,
num_splits=aug_cfg.num_aug_splits))
else:
# when normalize disabled, (pre)fetcher and collate will handle tensor conversion and normalize
final_tfl += [ToNumpy()]
if separate:
# return each transform stage separately
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
else:
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
def transforms_imagenet_eval(
img_size=224,
img_size: Union[int, Tuple[int]] = 224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
std=IMAGENET_DEFAULT_STD,
normalize=False,
):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, (tuple, list)):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# FIXME handle case where img is square and we want non aspect preserving resize
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
@ -147,27 +149,87 @@ def transforms_imagenet_eval(
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, _pil_interp(interpolation)),
transforms.Resize(scale_size, transforms.InterpolationMode(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
if normalize:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
ToTensorNormalize(mean=mean, std=std)
]
else:
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl += [ToNumpy()]
return transforms.Compose(tfl)
def create_transform_v2(
cfg=PreprocessCfg(),
is_training=False,
normalize=False,
separate=False,
tf_preprocessing=False,
):
"""
Args:
cfg: Pre-processing configuration
is_training (bool): Create transform for training pre-processing
tf_preprocessing (bool): Use Tensorflow pre-processing (for validation)
normalize (bool): Enable normalization in transforms (otherwise handled by fetcher/pre-fetcher)
separate (bool): Return transforms separated into stages (for train)
Returns:
A transform callable, or a tuple of (primary, secondary, final) transforms when separate=True
"""
input_size = cfg.input_size
if isinstance(input_size, (tuple, list)):
img_size = input_size[-2:]
else:
img_size = input_size
if tf_preprocessing:
assert not normalize, "Expecting normalization to be handled in (pre)fetcher w/ TF preprocessing"
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=cfg.interpolation)
else:
if is_training and cfg.aug is None:
assert not separate, "Cannot perform split augmentation with no_aug"
transform = transforms_noaug_train(
img_size,
interpolation=cfg.interpolation,
normalize=normalize,
mean=cfg.mean,
std=cfg.std)
elif is_training:
transform = transforms_imagenet_train(
img_size,
interpolation=cfg.interpolation,
mean=cfg.mean,
std=cfg.std,
aug_cfg=cfg.aug,
normalize=normalize,
separate=separate)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"
transform = transforms_imagenet_eval(
img_size,
interpolation=cfg.interpolation,
crop_pct=cfg.crop_pct,
mean=cfg.mean,
std=cfg.std,
normalize=normalize,
)
return transform
def create_transform(
input_size,
is_training=False,
use_fetcher=False,
use_prefetcher=False,
no_aug=False,
scale=None,
ratio=None,
@ -191,7 +253,8 @@ def create_transform(
else:
img_size = input_size
if tf_preprocessing and use_fetcher:
normalize_in_transform = not use_prefetcher
if tf_preprocessing and use_prefetcher:
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
@ -202,35 +265,41 @@ def create_transform(
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_fetcher,
mean=mean,
std=std)
std=std,
normalize=normalize_in_transform,
)
elif is_training:
transform = transforms_imagenet_train(
img_size,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
aug_cfg = AugCfg(
scale_range=scale,
ratio_range=ratio,
hflip_prob=hflip,
vflip_prob=vflip,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_fetcher,
mean=mean,
std=std,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=separate)
num_aug_splits=re_num_splits,
)
transform = transforms_imagenet_train(
img_size,
interpolation=interpolation,
mean=mean,
std=std,
aug_cfg=aug_cfg,
normalize=normalize_in_transform,
separate=separate
)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"
assert not separate, "Separate transforms not supported for validation pre-processing"
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_fetcher,
mean=mean,
std=std,
crop_pct=crop_pct)
crop_pct=crop_pct,
normalize=normalize_in_transform,
)
return transform

@ -24,13 +24,20 @@ _logger = logging.getLogger(__name__)
def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = 'state_dict'
state_dict_key = ''
if isinstance(checkpoint, dict):
if use_ema and 'state_dict_ema' in checkpoint:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
elif use_ema and 'model_ema' in checkpoint:
state_dict_key = 'model_ema'
elif 'state_dict' in checkpoint:
state_dict_key = 'state_dict'
elif 'model' in checkpoint:
state_dict_key = 'model'
if state_dict_key:
state_dict = checkpoint[state_dict_key]
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
for k, v in state_dict.items():
# strip `module.` prefix
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
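
The rewritten helper now resolves the checkpoint key in priority order: state_dict_ema / model_ema when use_ema is set, then state_dict, then model, else the raw dict itself. A small sketch (temp path is a placeholder; assumes the helper returns the cleaned dict, as in timm):

import torch
from collections import OrderedDict
from timm.models.helpers import load_state_dict

ckpt = {'state_dict': OrderedDict({'module.fc.weight': torch.zeros(2, 2)})}
torch.save(ckpt, '/tmp/ckpt.pth')

sd = load_state_dict('/tmp/ckpt.pth', use_ema=False)
assert 'fc.weight' in sd   # 'module.' prefix stripped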

@ -30,7 +30,8 @@ import torchvision.utils
from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, Monitor, Tracker,\
TrainState, TrainServices, TrainCfg, CheckpointManager, AccuracyTopK, AvgTensor, distribute_bn
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
from timm.models import create_model, safe_model_name, convert_splitbn_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import optimizer_kwargs
@ -283,10 +284,11 @@ def main():
else:
_logger.info('Training with a single process on 1 device.')
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
random_seed(args.seed, 0) # Set all random seeds the same for model/state init (mandatory for XLA)
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'
train_state = setup_train_task(args, dev_env, mixup_active)
train_cfg = train_state.train_cfg
@ -421,11 +423,9 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
# setup augmentation batch splits for contrastive loss or split bn
assert args.aug_splits == 0 or args.aug_splits > 1, 'A split of 1 makes no sense'
# enable split bn (separate bn stats per batch-portion)
if args.split_bn:
assert args.aug_splits > 1 or args.resplit
assert args.aug_splits > 1
model = convert_splitbn_model(model, max(args.aug_splits, 2))
train_state = setup_model_and_optimizer(
@ -481,7 +481,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
return train_state
def setup_data(args, default_cfg, dev_env, mixup_active):
def setup_data(args, default_cfg, dev_env: DeviceEnv, mixup_active: bool):
data_config = resolve_data_config(vars(args), default_cfg=default_cfg, verbose=dev_env.primary)
# create the train and eval datasets
@ -489,18 +489,18 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
args.dataset,
root=args.data_dir, split=args.train_split, is_training=True,
batch_size=args.batch_size, repeats=args.epoch_repeats)
dataset_eval = create_dataset(
args.dataset, root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
args.dataset,
root=args.data_dir, split=args.val_split, is_training=False, batch_size=args.batch_size)
# setup mixup / cutmix
collate_fn = None
mixup_cfg = None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
mixup_cfg = MixupCfg(
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
label_smoothing=args.smoothing, num_classes=args.num_classes)
assert not args.aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
# wrap dataset in AugMix helper
if args.aug_splits > 1:
@ -510,46 +510,72 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
train_interpolation = args.train_interpolation
if args.no_aug or not train_interpolation:
train_interpolation = data_config['interpolation']
loader_train = create_loader(
dataset_train,
if args.no_aug:
train_aug_cfg = None
else:
train_aug_cfg = AugCfg(
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
ratio_range=args.ratio,
scale_range=args.scale,
hflip_prob=args.hflip,
vflip_prob=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_splits=args.aug_splits,
)
train_pp_cfg = PreprocessCfg(
input_size=data_config['input_size'],
batch_size=args.batch_size,
is_training=True,
no_aug=args.no_aug,
re_prob=args.reprob,
re_mode=args.remode,
re_count=args.recount,
re_split=args.resplit,
scale=args.scale,
ratio=args.ratio,
hflip=args.hflip,
vflip=args.vflip,
color_jitter=args.color_jitter,
auto_augment=args.aa,
num_aug_splits=args.aug_splits,
interpolation=train_interpolation,
crop_pct=data_config['crop_pct'],
mean=data_config['mean'],
std=data_config['std'],
aug=train_aug_cfg,
)
# if using PyTorch XLA and RandomErasing is enabled, we must normalize and do RE in transforms on CPU
normalize_in_transform = dev_env.type_xla and args.reprob > 0
dataset_train.transform = create_transform_v2(
cfg=train_pp_cfg, is_training=True, normalize=normalize_in_transform)
loader_train = create_loader_v2(
dataset_train,
batch_size=args.batch_size,
is_training=True,
normalize=not normalize_in_transform,
pp_cfg=train_pp_cfg,
mix_cfg=mixup_cfg,
num_workers=args.workers,
collate_fn=collate_fn,
pin_memory=args.pin_mem,
use_multi_epochs_loader=args.use_multi_epochs_loader
)
eval_pp_cfg = PreprocessCfg(
input_size=data_config['input_size'],
interpolation=data_config['interpolation'],
crop_pct=data_config['crop_pct'],
mean=data_config['mean'],
std=data_config['std'],
)
dataset_eval.transform = create_transform_v2(
cfg=eval_pp_cfg, is_training=False, normalize=normalize_in_transform)
eval_workers = args.workers
if 'tfds' in args.dataset:
# FIXME reduce validation issues when using TFDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
loader_eval = create_loader_v2(
dataset_eval,
input_size=data_config['input_size'],
batch_size=args.validation_batch_size_multiplier * args.batch_size,
is_training=False,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
normalize=not normalize_in_transform,
pp_cfg=eval_pp_cfg,
num_workers=eval_workers,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)
return data_config, loader_eval, loader_train
@ -700,8 +726,12 @@ def evaluate(
loss = loss_fn(output, target)
# FIXME, explicitly marking step for XLA use since I'm not using the parallel xm loader
# need to investigate whether parallel loader wrapper is helpful on tpu-vm or only usefor for 2-vm setup.
dev_env.mark_step()
# need to investigate whether parallel loader wrapper is helpful on tpu-vm or only use for 2-vm setup.
if dev_env.type_xla:
dev_env.mark_step()
elif dev_env.type_cuda:
dev_env.synchronize()
tracker.mark_iter_step_end()
losses_m.update(loss, output.size(0))
accuracy_m.update(output, target)

@ -20,7 +20,8 @@ from collections import OrderedDict
from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config, RealLabelsImagenet, \
PreprocessCfg
from timm.utils import natural_key, setup_default_logging
@ -141,18 +142,22 @@ def validate(args):
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
eval_pp_cfg = PreprocessCfg(
input_size=data_config['input_size'],
batch_size=args.batch_size,
interpolation=data_config['interpolation'],
crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
mean=data_config['mean'],
std=data_config['std'],
)
dataset.transform = create_transform_v2(cfg=eval_pp_cfg, is_training=False)
loader = create_loader_v2(
dataset,
batch_size=args.batch_size,
pp_cfg=eval_pp_cfg,
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
pin_memory=args.pin_mem)
logger = Monitor(logger=_logger)
tracker = Tracker()
@ -175,16 +180,17 @@ def validate(args):
loss = criterion(output, target)
if dev_env.type_cuda:
torch.cuda.synchronize()
dev_env.synchronize()
tracker.mark_iter_step_end()
losses.update(loss.detach(), sample.size(0))
if dev_env.type_xla:
dev_env.mark_step()
if real_labels is not None:
real_labels.add_result(output)
accuracy.update(output.detach(), target)
if dev_env.type_xla:
dev_env.mark_step()
losses.update(loss.detach(), sample.size(0))
accuracy.update(output.detach(), target)
tracker.mark_iter()
if step_idx % args.log_freq == 0:
@ -212,7 +218,7 @@ def validate(args):
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
cropt_pct=crop_pct,
cropt_pct=eval_pp_cfg.crop_pct,
interpolation=data_config['interpolation'])
logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results)
