From b3cb5f32752c22312e0abcbfb599eabc2e14a4bf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 16 Feb 2020 20:08:17 -0800 Subject: [PATCH 1/5] Working on CutMix impl as per #8, integrating with Mixup, currently experimenting... --- timm/data/__init__.py | 2 +- timm/data/dataset.py | 2 +- timm/data/mixup.py | 184 +++++++++++++++++++++++++++++++++++++++--- train.py | 11 ++- 4 files changed, 183 insertions(+), 16 deletions(-) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index ee2240b4..8e261617 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -4,6 +4,6 @@ from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform -from .mixup import mixup_batch, FastCollateMixup +from .mixup import mix_batch, FastCollateMixup, FastCollateMixupBatchwise, FastCollateMixupElementwise from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 2ce79e7e..4c6bef96 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -89,7 +89,7 @@ class Dataset(data.Dataset): return img, target def __len__(self): - return len(self.imgs) + return len(self.samples) def filenames(self, indices=[], basename=False): if indices: diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 4678472d..1b298af0 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -1,5 +1,30 @@ +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by Ross Wightman +""" + import numpy as np import torch +import math +from enum import IntEnum + + +class MixupMode(IntEnum): + MIXUP = 0 + CUTMIX = 1 + RANDOM = 2 + + @classmethod + def from_str(cls, value): + return cls[value.upper()] def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): @@ -12,7 +37,7 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): on_value = 1. - smoothing + off_value y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) - return lam*y1 + (1. - lam)*y2 + return y1 * lam + y2 * (1. - lam) def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): @@ -24,28 +49,167 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab return input, target +def rand_bbox(size, ratio): + H, W = size[-2:] + ratio = max(min(ratio, 0.8), 0.2) + cut_h, cut_w = int(H * ratio), int(W * ratio) + cy, cx = np.random.randint(H), np.random.randint(W) + yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H) + xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W) + return yl, yh, xl, xh + + +def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): + lam = 1. + if not disable: + lam = np.random.beta(alpha, alpha) + if lam != 1: + ratio = math.sqrt(1. - lam) + yl, yh, xl, xh = rand_bbox(input.size(), ratio) + input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] + target = mixup_target(target, num_classes, lam, smoothing) + return input, target + + +def _resolve_mode(mode): + mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode + if mode == MixupMode.RANDOM: + mode = MixupMode(np.random.rand() > 0.5) + return mode # will be one of cutmix or mixup + + +def mix_batch( + input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP): + mode = _resolve_mode(mode) + if mode == MixupMode.CUTMIX: + return mixup_batch(input, target, alpha, num_classes, smoothing, disable) + else: + return cutmix_batch(input, target, alpha, num_classes, smoothing, disable) + + class FastCollateMixup: + """Fast Collate Mixup that applies different params to each element + flipped pair - def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): + NOTE once experiments are done, one of the three variants will remain with this class name + """ + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): self.mixup_alpha = mixup_alpha self.label_smoothing = label_smoothing self.num_classes = num_classes + self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode self.mixup_enabled = True + self.correct_lam = False # correct lambda based on clipped area for cutmix + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) + for i in range(batch_size//2): + j = batch_size - i - 1 + lam = 1. + if self.mixup_enabled: + lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + + if _resolve_mode(self.mode) == MixupMode.CUTMIX: + mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32) + ratio = math.sqrt(1. - lam) + if lam != 1: + yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) + mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed_j[:, yl:yh, xl:xh] = batch[i][0][:, yl:yh, xl:xh].astype(np.float32) + if self.correct_lam: + lam_corrected = (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + lam_out[i] -= lam_corrected + lam_out[j] -= lam_corrected + else: + lam_out[i] = lam + lam_out[j] = lam + else: + mixed_i = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + mixed_j = batch[j][0].astype(np.float32) * lam + batch[i][0].astype(np.float32) * (1 - lam) + lam_out[i] = lam + lam_out[j] = lam + np.round(mixed_i, out=mixed_i) + np.round(mixed_j, out=mixed_j) + tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + return lam_out def __call__(self, batch): batch_size = len(batch) + assert batch_size % 2 == 0, 'Batch size should be even when using this' + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + lam = self._do_mix(tensor, batch) + target = torch.tensor([b[1] for b in batch], dtype=torch.int64) + target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu') + + return tensor, target + + +class FastCollateMixupElementwise(FastCollateMixup): + """Fast Collate Mixup that applies different params to each batch element + + NOTE this is for experimentation, may remove at some point + """ + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): + super(FastCollateMixupElementwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode) + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) + for i in range(batch_size): + lam = 1. + if self.mixup_enabled: + lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + + if _resolve_mode(self.mode) == MixupMode.CUTMIX: + mixed = batch[i][0].astype(np.float32) + ratio = math.sqrt(1. - lam) + if lam != 1: + yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) + mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + if self.correct_lam: + lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + else: + lam_out[i] = lam + else: + mixed = batch[i][0].astype(np.float32) * lam + \ + batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + lam_out[i] = lam + np.round(mixed, out=mixed) + tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam_out + + +class FastCollateMixupBatchwise(FastCollateMixup): + """Fast Collate Mixup that applies same params to whole batch + + NOTE this is for experimentation, may remove at some point + """ + + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): + super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode) + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) lam = 1. + cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + if cutmix and self.correct_lam: + ratio = math.sqrt(1. - lam) + yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio) + lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) - target = torch.tensor([b[1] for b in batch], dtype=torch.int64) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') - - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) for i in range(batch_size): - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + if cutmix: + mixed = batch[i][0].astype(np.float32) + if lam != 1: + mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + else: + mixed = batch[i][0].astype(np.float32) * lam + \ + batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - - return tensor, target + return lam diff --git a/train.py b/train.py index 32beb418..f8fdf698 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,8 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mix_batch, AugMixDataset,\ + FastCollateMixupElementwise, FastCollateMixupBatchwise from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy @@ -134,6 +135,8 @@ parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') +parser.add_argument('--mixup-mode', type=str, default='mixup', + help='Mixup mode ("mixup", "cutmix", "random", default: "mixup")') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, @@ -352,7 +355,7 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) + collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) @@ -504,10 +507,10 @@ def train_epoch( if not args.prefetcher: input, target = input.cuda(), target.cuda() if args.mixup > 0.: - input, target = mixup_batch( + input, target = mix_batch( input, target, alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, - disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode) output = model(input) From 670c61b28ffd123267df4126fc700fd8f2837d22 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 17 Feb 2020 11:00:54 -0800 Subject: [PATCH 2/5] Some cutmix/mixup cleanup/fixes --- timm/data/mixup.py | 55 +++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 1b298af0..a59fa1f3 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -14,6 +14,7 @@ Hacked together by Ross Wightman import numpy as np import torch import math +import numbers from enum import IntEnum @@ -49,9 +50,17 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab return input, target +def calc_ratio(lam, minmax=None): + ratio = math.sqrt(1 - lam) + if minmax is not None: + if isinstance(minmax, numbers.Number): + minmax = (minmax, 1 - minmax) + ratio = np.clip(ratio, minmax[0], minmax[1]) + return ratio + + def rand_bbox(size, ratio): H, W = size[-2:] - ratio = max(min(ratio, 0.8), 0.2) cut_h, cut_w = int(H * ratio), int(W * ratio) cy, cx = np.random.randint(H), np.random.randint(W) yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H) @@ -59,14 +68,15 @@ def rand_bbox(size, ratio): return yl, yh, xl, xh -def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): +def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False): lam = 1. if not disable: lam = np.random.beta(alpha, alpha) if lam != 1: - ratio = math.sqrt(1. - lam) - yl, yh, xl, xh = rand_bbox(input.size(), ratio) + yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam)) input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] + if correct_lam: + lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) target = mixup_target(target, num_classes, lam, smoothing) return input, target @@ -82,9 +92,9 @@ def mix_batch( input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP): mode = _resolve_mode(mode) if mode == MixupMode.CUTMIX: - return mixup_batch(input, target, alpha, num_classes, smoothing, disable) - else: return cutmix_batch(input, target, alpha, num_classes, smoothing, disable) + else: + return mixup_batch(input, target, alpha, num_classes, smoothing, disable) class FastCollateMixup: @@ -99,6 +109,7 @@ class FastCollateMixup: self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode self.mixup_enabled = True self.correct_lam = False # correct lambda based on clipped area for cutmix + self.ratio_minmax = None # (0.2, 0.8) def _do_mix(self, tensor, batch): batch_size = len(batch) @@ -111,7 +122,7 @@ class FastCollateMixup: if _resolve_mode(self.mode) == MixupMode.CUTMIX: mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32) - ratio = math.sqrt(1. - lam) + ratio = calc_ratio(lam, self.ratio_minmax) if lam != 1: yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) @@ -132,7 +143,7 @@ class FastCollateMixup: np.round(mixed_j, out=mixed_j) tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8)) tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8)) - return lam_out + return lam_out.unsqueeze(1) def __call__(self, batch): batch_size = len(batch) @@ -140,7 +151,7 @@ class FastCollateMixup: tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) lam = self._do_mix(tensor, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) - target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu') + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') return tensor, target @@ -157,27 +168,27 @@ class FastCollateMixupElementwise(FastCollateMixup): batch_size = len(batch) lam_out = torch.ones(batch_size) for i in range(batch_size): + j = batch_size - i - 1 lam = 1. if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) if _resolve_mode(self.mode) == MixupMode.CUTMIX: mixed = batch[i][0].astype(np.float32) - ratio = math.sqrt(1. - lam) if lam != 1: + ratio = calc_ratio(lam) yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) - mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) if self.correct_lam: lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) else: lam_out[i] = lam else: - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) lam_out[i] = lam np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - return lam_out + return lam_out.unsqueeze(1) class FastCollateMixupBatchwise(FastCollateMixup): @@ -191,25 +202,23 @@ class FastCollateMixupBatchwise(FastCollateMixup): def _do_mix(self, tensor, batch): batch_size = len(batch) - lam_out = torch.ones(batch_size) lam = 1. cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) - if cutmix and self.correct_lam: - ratio = math.sqrt(1. - lam) - yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio) - lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + if cutmix: + yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam)) + if self.correct_lam: + lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) for i in range(batch_size): + j = batch_size - i - 1 if cutmix: mixed = batch[i][0].astype(np.float32) if lam != 1: - mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) - lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) else: - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam From f471c17c9d12dea9385edc395816fb5d57cf3412 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Aug 2020 00:10:33 -0700 Subject: [PATCH 3/5] More cutmix/mixup overhaul, ready to kick-off some trials. --- timm/data/mixup.py | 240 ++++++++++++++++++++++++++------------------- train.py | 24 +++-- 2 files changed, 159 insertions(+), 105 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index cf6df6f6..de19c616 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -15,17 +15,6 @@ import numpy as np import torch import math import numbers -from enum import IntEnum - - -class MixupMode(IntEnum): - MIXUP = 0 - CUTMIX = 1 - RANDOM = 2 - - @classmethod - def from_str(cls, value): - return cls[value.upper()] def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): @@ -50,30 +39,49 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab return input, target -def calc_ratio(lam, minmax=None): +def rand_bbox(size, lam, border=0., count=None): ratio = math.sqrt(1 - lam) - if minmax is not None: - if isinstance(minmax, numbers.Number): - minmax = (minmax, 1 - minmax) - ratio = np.clip(ratio, minmax[0], minmax[1]) - return ratio - - -def rand_bbox(size, ratio): - H, W = size[-2:] - cut_h, cut_w = int(H * ratio), int(W * ratio) - cy, cx = np.random.randint(H), np.random.randint(W) - yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H) - xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W) + img_h, img_w = size[-2:] + cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) + margin_y, margin_x = int(border * cut_h), int(border * cut_w) + cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) + cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) + yl = np.clip(cy - cut_h // 2, 0, img_h) + yh = np.clip(cy + cut_h // 2, 0, img_h) + xl = np.clip(cx - cut_w // 2, 0, img_w) + xh = np.clip(cx + cut_w // 2, 0, img_w) return yl, yh, xl, xh +def rand_bbox_minmax(size, minmax, count=None): + assert len(minmax) == 2 + img_h, img_w = size[-2:] + cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) + cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) + yl = np.random.randint(0, img_h - cut_h, size=count) + xl = np.random.randint(0, img_w - cut_w, size=count) + yu = yl + cut_h + xu = xl + cut_w + return yl, yu, xl, xu + + +def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + if ratio_minmax is not None: + yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) + else: + yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) + if correct_lam or ratio_minmax is not None: + bbox_area = (yu - yl) * (xu - xl) + lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1]) + return (yl, yu, xl, xu), lam + + def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False): lam = 1. if not disable: lam = np.random.beta(alpha, alpha) if lam != 1: - yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam)) + yl, yh, xl, xh = rand_bbox(input.size(), lam) input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] if correct_lam: lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) @@ -81,101 +89,135 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa return input, target -def _resolve_mode(mode): - mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode - if mode == MixupMode.RANDOM: - mode = MixupMode(np.random.rand() > 0.7) - return mode # will be one of cutmix or mixup - - def mix_batch( - input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP): - mode = _resolve_mode(mode) - if mode == MixupMode.CUTMIX: - return cutmix_batch(input, target, alpha, num_classes, smoothing, disable) + input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5, + num_classes=1000, smoothing=0.1, disable=False): + # FIXME test this version + if np.random.rand() > prob: + return input, target + use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob + if use_cutmix: + return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable) else: - return mixup_batch(input, target, alpha, num_classes, smoothing, disable) + return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable) class FastCollateMixup: - """Fast Collate Mixup that applies different params to each element + flipped pair + """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch NOTE once experiments are done, one of the three variants will remain with this class name + """ - def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): + def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, + elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000): + """ + + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None + prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of using cutmix instead of mixup when both active + elementwise (bool): apply mixup/cutmix params per batch element instead of per batch + label_smoothing (float): + num_classes (int): + """ self.mixup_alpha = mixup_alpha + self.cutmix_alpha = cutmix_alpha + self.cutmix_minmax = cutmix_minmax + if self.cutmix_minmax is not None: + assert len(self.cutmix_minmax) == 2 + # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe + self.cutmix_alpha = 1.0 + self.prob = prob + self.switch_prob = switch_prob self.label_smoothing = label_smoothing self.num_classes = num_classes - self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode - self.mixup_enabled = True - self.correct_lam = True # correct lambda based on clipped area for cutmix - self.ratio_minmax = None # (0.2, 0.8) + self.elementwise = elementwise + self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix + self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) - def _do_mix(self, tensor, batch): + def _mix_elem(self, output, batch): batch_size = len(batch) - lam_out = torch.ones(batch_size) + lam_out = np.ones(batch_size) + use_cutmix = np.zeros(batch_size).astype(np.bool) + if self.mixup_enabled: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand(batch_size) < self.switch_prob + lam_mix = np.where( + use_cutmix, + np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size), + np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) + elif self.cutmix_alpha > 0.: + use_cutmix = np.ones(batch_size).astype(np.bool) + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) + else: + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out) + for i in range(batch_size): j = batch_size - i - 1 - lam = 1. - if self.mixup_enabled: - lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) - - if _resolve_mode(self.mode) == MixupMode.CUTMIX: - mixed = batch[i][0].astype(np.float32) - if lam != 1: - ratio = calc_ratio(lam) - yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) + lam = lam_out[i] + mixed = batch[i][0].astype(np.float32) + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) - if self.correct_lam: - lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) - else: - lam_out[i] = lam + lam_out[i] = lam + else: + mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) + lam_out[i] = lam + np.round(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return torch.tensor(lam_out).unsqueeze(1) + + def _mix_batch(self, output, batch): + batch_size = len(batch) + lam = 1. + use_cutmix = False + if self.mixup_enabled and np.random.rand() < self.prob: + if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: + use_cutmix = np.random.rand() < self.switch_prob + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ + np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.mixup_alpha > 0.: + lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha) + elif self.cutmix_alpha > 0.: + use_cutmix = True + lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) else: - mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - lam_out[i] = lam - np.round(mixed, out=mixed) - tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - return lam_out.unsqueeze(1) + assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." + lam = lam_mix + + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + + for i in range(batch_size): + j = batch_size - i - 1 + mixed = batch[i][0].astype(np.float32) + if lam != 1.: + if use_cutmix: + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + else: + mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.round(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam def __call__(self, batch): batch_size = len(batch) assert batch_size % 2 == 0, 'Batch size should be even when using this' - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - lam = self._do_mix(tensor, batch) + output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + if self.elementwise: + lam = self._mix_elem(output, batch) + else: + lam = self._mix_batch(output, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') - return tensor, target - - -class FastCollateMixupBatchwise(FastCollateMixup): - """Fast Collate Mixup that applies same params to whole batch - - NOTE this is for experimentation, may remove at some point - """ - - def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): - super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode) + return output, target - def _do_mix(self, tensor, batch): - batch_size = len(batch) - lam = 1. - cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX - if self.mixup_enabled: - lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) - if cutmix: - yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam)) - if self.correct_lam: - lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) - - for i in range(batch_size): - j = batch_size - i - 1 - if cutmix: - mixed = batch[i][0].astype(np.float32) - if lam != 1: - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) - else: - mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.round(mixed, out=mixed) - tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - return lam diff --git a/train.py b/train.py index af1b81dd..3b038d53 100755 --- a/train.py +++ b/train.py @@ -157,8 +157,16 @@ parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') -parser.add_argument('--mixup-mode', type=str, default='mixup', - help='Mixup mode. One of "mixup", "cutmix", "random" (default: "mixup")') +parser.add_argument('--cutmix', type=float, default=0.0, + help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') +parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') +parser.add_argument('--mixup-prob', type=float, default=1.0, + help='Probability of performing mixup or cutmix when either/both is enabled') +parser.add_argument('--mixup-switch-prob', type=float, default=0.5, + help='Probability of switching to cutmix when both mixup and cutmix enabled') +parser.add_argument('--mixup-elem', action='store_true', default=False, + help='Apply mixup/cutmix params uniquely per batch element instead of per batch.') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='Turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, @@ -390,9 +398,12 @@ def main(): dataset_train = Dataset(train_dir) collate_fn = None - if args.prefetcher and args.mixup > 0: + if args.prefetcher and (args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None): assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode) + collate_fn = FastCollateMixup( + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, + prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem, + label_smoothing=args.smoothing, num_classes=args.num_classes) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) @@ -555,8 +566,9 @@ def train_epoch( if args.mixup > 0.: input, target = mix_batch( input, target, - alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, - disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode) + mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, prob=args.mixup_prob, + switch_prob=args.mixup_switch_prob, num_classes=args.num_classes, smoothing=args.smoothing, + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) output = model(input) From cd23f553974f0cda6b07761ba2d4a88e82236966 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Aug 2020 12:17:43 -0700 Subject: [PATCH 4/5] Fix mixed prec issues with new mixup code --- timm/data/mixup.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index de19c616..a018ea07 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) if correct_lam or ratio_minmax is not None: bbox_area = (yu - yl) * (xu - xl) - lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1]) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) return (yl, yu, xl, xu), lam @@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa yl, yh, xl, xh = rand_bbox(input.size(), lam) input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] if correct_lam: - lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) + lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1]) target = mixup_target(target, num_classes, lam, smoothing) return input, target @@ -139,7 +139,7 @@ class FastCollateMixup: def _mix_elem(self, output, batch): batch_size = len(batch) - lam_out = np.ones(batch_size) + lam_out = np.ones(batch_size, dtype=np.float32) use_cutmix = np.zeros(batch_size).astype(np.bool) if self.mixup_enabled: if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: @@ -155,22 +155,23 @@ class FastCollateMixup: lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out) + lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out) for i in range(batch_size): j = batch_size - i - 1 lam = lam_out[i] - mixed = batch[i][0].astype(np.float32) + mixed = batch[i][0] if lam != 1.: if use_cutmix[i]: + mixed = mixed.copy() (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] lam_out[i] = lam else: - mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) lam_out[i] = lam - np.round(mixed, out=mixed) + np.round(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) return torch.tensor(lam_out).unsqueeze(1) @@ -190,7 +191,7 @@ class FastCollateMixup: lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam = lam_mix + lam = float(lam_mix) if use_cutmix: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( @@ -198,13 +199,14 @@ class FastCollateMixup: for i in range(batch_size): j = batch_size - i - 1 - mixed = batch[i][0].astype(np.float32) + mixed = batch[i][0] if lam != 1.: if use_cutmix: - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed = mixed.copy() + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: - mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.round(mixed, out=mixed) + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.round(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam From 8c9814e3f500e8b37aae86dd4db10aba2c295bd2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 12 Aug 2020 17:01:32 -0700 Subject: [PATCH 5/5] Final cleanup of mixup/cutmix. Element/batch modes working with both collate (prefetcher active) and without prefetcher. --- timm/data/__init__.py | 2 +- timm/data/mixup.py | 214 ++++++++++++++++++++++++------------------ train.py | 41 ++++---- 3 files changed, 143 insertions(+), 114 deletions(-) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index e1886fcc..15617859 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -4,7 +4,7 @@ from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform -from .mixup import mix_batch, FastCollateMixup +from .mixup import Mixup, FastCollateMixup from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform from .real_labels import RealLabelsImagenet diff --git a/timm/data/mixup.py b/timm/data/mixup.py index a018ea07..63861bc7 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -10,11 +10,8 @@ CutMix: https://github.com/clovaai/CutMix-PyTorch Hacked together by / Copyright 2020 Ross Wightman """ - import numpy as np import torch -import math -import numbers def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): @@ -30,20 +27,21 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): return y1 * lam + y2 * (1. - lam) -def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): - lam = 1. - if not disable: - lam = np.random.beta(alpha, alpha) - input = input.mul(lam).add_(1 - lam, input.flip(0)) - target = mixup_target(target, num_classes, lam, smoothing) - return input, target - +def rand_bbox(img_shape, lam, margin=0., count=None): + """ Standard CutMix bounding-box + Generates a random square bbox based on lambda value. This impl includes + support for enforcing a border margin as percent of bbox dimensions. -def rand_bbox(size, lam, border=0., count=None): - ratio = math.sqrt(1 - lam) - img_h, img_w = size[-2:] + Args: + img_shape (tuple): Image shape as tuple + lam (float): Cutmix lambda value + margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image) + count (int): Number of bbox to generate + """ + ratio = np.sqrt(1 - lam) + img_h, img_w = img_shape[-2:] cut_h, cut_w = int(img_h * ratio), int(img_w * ratio) - margin_y, margin_x = int(border * cut_h), int(border * cut_w) + margin_y, margin_x = int(margin * cut_h), int(margin * cut_w) cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count) cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count) yl = np.clip(cy - cut_h // 2, 0, img_h) @@ -53,9 +51,20 @@ def rand_bbox(size, lam, border=0., count=None): return yl, yh, xl, xh -def rand_bbox_minmax(size, minmax, count=None): +def rand_bbox_minmax(img_shape, minmax, count=None): + """ Min-Max CutMix bounding-box + Inspired by Darknet cutmix impl, generates a random rectangular bbox + based on min/max percent values applied to each dimension of the input image. + + Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max. + + Args: + img_shape (tuple): Image shape as tuple + minmax (tuple or list): Min and max bbox ratios (as percent of image size) + count (int): Number of bbox to generate + """ assert len(minmax) == 2 - img_h, img_w = size[-2:] + img_h, img_w = img_shape[-2:] cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count) cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count) yl = np.random.randint(0, img_h - cut_h, size=count) @@ -66,6 +75,8 @@ def rand_bbox_minmax(size, minmax, count=None): def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None): + """ Generate bbox and apply lambda correction. + """ if ratio_minmax is not None: yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count) else: @@ -76,52 +87,22 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou return (yl, yu, xl, xu), lam -def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False): - lam = 1. - if not disable: - lam = np.random.beta(alpha, alpha) - if lam != 1: - yl, yh, xl, xh = rand_bbox(input.size(), lam) - input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] - if correct_lam: - lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1]) - target = mixup_target(target, num_classes, lam, smoothing) - return input, target - - -def mix_batch( - input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5, - num_classes=1000, smoothing=0.1, disable=False): - # FIXME test this version - if np.random.rand() > prob: - return input, target - use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob - if use_cutmix: - return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable) - else: - return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable) - - -class FastCollateMixup: - """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch - - NOTE once experiments are done, one of the three variants will remain with this class name +class Mixup: + """ Mixup/Cutmix that applies different params to each element or whole batch + Args: + mixup_alpha (float): mixup alpha value, mixup is active if > 0. + cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. + cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None. + prob (float): probability of applying mixup or cutmix per batch or element + switch_prob (float): probability of switching to cutmix instead of mixup when both are active + elementwise (bool): apply mixup/cutmix params per batch element instead of per batch + correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders + label_smoothing (float): apply label smoothing to the mixed target tensor + num_classes (int): number of classes for target """ def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5, elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000): - """ - - Args: - mixup_alpha (float): mixup alpha value, mixup is active if > 0. - cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0. - cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None - prob (float): probability of applying mixup or cutmix per batch or element - switch_prob (float): probability of using cutmix instead of mixup when both active - elementwise (bool): apply mixup/cutmix params per batch element instead of per batch - label_smoothing (float): - num_classes (int): - """ self.mixup_alpha = mixup_alpha self.cutmix_alpha = cutmix_alpha self.cutmix_minmax = cutmix_minmax @@ -129,7 +110,7 @@ class FastCollateMixup: assert len(self.cutmix_minmax) == 2 # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe self.cutmix_alpha = 1.0 - self.prob = prob + self.mix_prob = prob self.switch_prob = switch_prob self.label_smoothing = label_smoothing self.num_classes = num_classes @@ -137,10 +118,9 @@ class FastCollateMixup: self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop) - def _mix_elem(self, output, batch): - batch_size = len(batch) - lam_out = np.ones(batch_size, dtype=np.float32) - use_cutmix = np.zeros(batch_size).astype(np.bool) + def _params_per_elem(self, batch_size): + lam = np.ones(batch_size, dtype=np.float32) + use_cutmix = np.zeros(batch_size, dtype=np.bool) if self.mixup_enabled: if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: use_cutmix = np.random.rand(batch_size) < self.switch_prob @@ -151,35 +131,17 @@ class FastCollateMixup: elif self.mixup_alpha > 0.: lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size) elif self.cutmix_alpha > 0.: - use_cutmix = np.ones(batch_size).astype(np.bool) + use_cutmix = np.ones(batch_size, dtype=np.bool) lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out) - - for i in range(batch_size): - j = batch_size - i - 1 - lam = lam_out[i] - mixed = batch[i][0] - if lam != 1.: - if use_cutmix[i]: - mixed = mixed.copy() - (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] - lam_out[i] = lam - else: - mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - lam_out[i] = lam - np.round(mixed, out=mixed) - output[i] += torch.from_numpy(mixed.astype(np.uint8)) - return torch.tensor(lam_out).unsqueeze(1) + lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam) + return lam, use_cutmix - def _mix_batch(self, output, batch): - batch_size = len(batch) + def _params_per_batch(self): lam = 1. use_cutmix = False - if self.mixup_enabled and np.random.rand() < self.prob: + if self.mixup_enabled and np.random.rand() < self.mix_prob: if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: use_cutmix = np.random.rand() < self.switch_prob lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \ @@ -192,17 +154,84 @@ class FastCollateMixup: else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." lam = float(lam_mix) + return lam, use_cutmix + def _mix_elem(self, x): + batch_size = len(x) + lam_batch, use_cutmix = self._params_per_elem(batch_size) + x_orig = x.clone() # need to keep an unmodified original for mixing source + for i in range(batch_size): + j = batch_size - i - 1 + lam = lam_batch[i] + if lam != 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + x[i] = x[i] * lam + x_orig[j] * (1 - lam) + return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1) + + def _mix_batch(self, x): + lam, use_cutmix = self._params_per_batch() + if lam == 1.: + return 1. if use_cutmix: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( - output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh] + else: + x_flipped = x.flip(0).mul_(1. - lam) + x.mul_(lam).add_(x_flipped) + return lam + + def __call__(self, x, target): + assert len(x) % 2 == 0, 'Batch size should be even when using this' + lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + return x, target + + +class FastCollateMixup(Mixup): + """ Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch + A Mixup impl that's performed while collating the batches. + """ + + def _mix_elem_collate(self, output, batch): + batch_size = len(batch) + lam_batch, use_cutmix = self._params_per_elem(batch_size) for i in range(batch_size): j = batch_size - i - 1 + lam = lam_batch[i] mixed = batch[i][0] if lam != 1.: - if use_cutmix: + if use_cutmix[i]: mixed = mixed.copy() + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] + lam_batch[i] = lam + else: + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + lam_batch[i] = lam + np.round(mixed, out=mixed) + output[i] += torch.from_numpy(mixed.astype(np.uint8)) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_batch_collate(self, output, batch): + batch_size = len(batch) + lam, use_cutmix = self._params_per_batch() + if use_cutmix: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + for i in range(batch_size): + j = batch_size - i - 1 + mixed = batch[i][0] + if lam != 1.: + if use_cutmix: + mixed = mixed.copy() # don't want to modify the original while iterating mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) @@ -210,16 +239,15 @@ class FastCollateMixup: output[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam - def __call__(self, batch): + def __call__(self, batch, _=None): batch_size = len(batch) assert batch_size % 2 == 0, 'Batch size should be even when using this' output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) if self.elementwise: - lam = self._mix_elem(output, batch) + lam = self._mix_elem_collate(output, batch) else: - lam = self._mix_batch(output, batch) + lam = self._mix_batch_collate(output, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') - return output, target diff --git a/train.py b/train.py index 3b038d53..c28bd266 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,7 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mix_batch, AugMixDataset +from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy @@ -398,12 +398,18 @@ def main(): dataset_train = Dataset(train_dir) collate_fn = None - if args.prefetcher and (args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None): - assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup( + mixup_fn = None + mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None + if mixup_active: + mixup_args = dict( mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax, prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem, label_smoothing=args.smoothing, num_classes=args.num_classes) + if args.prefetcher: + assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) + collate_fn = FastCollateMixup(**mixup_args) + else: + mixup_fn = Mixup(**mixup_args) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) @@ -465,17 +471,14 @@ def main(): if args.jsd: assert num_aug_splits > 1 # JSD only valid with aug splits set train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() - elif args.mixup > 0.: - # smoothing is handled with mixup label transform + elif mixup_active: + # smoothing is handled with mixup target transform train_loss_fn = SoftTargetCrossEntropy().cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() elif args.smoothing: train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() else: train_loss_fn = nn.CrossEntropyLoss().cuda() - validate_loss_fn = train_loss_fn + validate_loss_fn = nn.CrossEntropyLoss().cuda() eval_metric = args.eval_metric best_metric = None @@ -503,7 +506,7 @@ def main(): train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - use_amp=use_amp, model_ema=model_ema) + use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: @@ -543,11 +546,13 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None): + lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None): - if args.prefetcher and args.mixup > 0 and loader.mixup_enabled: - if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: + if args.prefetcher and loader.mixup_enabled: loader.mixup_enabled = False + elif mixup_fn is not None: + mixup_fn.mixup_enabled = False batch_time_m = AverageMeter() data_time_m = AverageMeter() @@ -563,12 +568,8 @@ def train_epoch( data_time_m.update(time.time() - end) if not args.prefetcher: input, target = input.cuda(), target.cuda() - if args.mixup > 0.: - input, target = mix_batch( - input, target, - mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, prob=args.mixup_prob, - switch_prob=args.mixup_switch_prob, num_classes=args.num_classes, smoothing=args.smoothing, - disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) + if mixup_fn is not None: + input, target = mixup_fn(input, target) output = model(input)