From b3cb5f32752c22312e0abcbfb599eabc2e14a4bf Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 16 Feb 2020 20:08:17 -0800 Subject: [PATCH] Working on CutMix impl as per #8, integrating with Mixup, currently experimenting... --- timm/data/__init__.py | 2 +- timm/data/dataset.py | 2 +- timm/data/mixup.py | 184 +++++++++++++++++++++++++++++++++++++++--- train.py | 11 ++- 4 files changed, 183 insertions(+), 16 deletions(-) diff --git a/timm/data/__init__.py b/timm/data/__init__.py index ee2240b4..8e261617 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -4,6 +4,6 @@ from .dataset import Dataset, DatasetTar, AugMixDataset from .transforms import * from .loader import create_loader from .transforms_factory import create_transform -from .mixup import mixup_batch, FastCollateMixup +from .mixup import mix_batch, FastCollateMixup, FastCollateMixupBatchwise, FastCollateMixupElementwise from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform diff --git a/timm/data/dataset.py b/timm/data/dataset.py index 2ce79e7e..4c6bef96 100644 --- a/timm/data/dataset.py +++ b/timm/data/dataset.py @@ -89,7 +89,7 @@ class Dataset(data.Dataset): return img, target def __len__(self): - return len(self.imgs) + return len(self.samples) def filenames(self, indices=[], basename=False): if indices: diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 4678472d..1b298af0 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -1,5 +1,30 @@ +""" Mixup and Cutmix + +Papers: +mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412) + +CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899) + +Code Reference: +CutMix: https://github.com/clovaai/CutMix-PyTorch + +Hacked together by Ross Wightman +""" + import numpy as np import torch +import math +from enum import IntEnum + + +class MixupMode(IntEnum): + MIXUP = 0 + CUTMIX = 1 + RANDOM = 2 + + @classmethod + def from_str(cls, value): + return cls[value.upper()] def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): @@ -12,7 +37,7 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): on_value = 1. - smoothing + off_value y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) - return lam*y1 + (1. - lam)*y2 + return y1 * lam + y2 * (1. - lam) def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): @@ -24,28 +49,167 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab return input, target +def rand_bbox(size, ratio): + H, W = size[-2:] + ratio = max(min(ratio, 0.8), 0.2) + cut_h, cut_w = int(H * ratio), int(W * ratio) + cy, cx = np.random.randint(H), np.random.randint(W) + yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H) + xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W) + return yl, yh, xl, xh + + +def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): + lam = 1. + if not disable: + lam = np.random.beta(alpha, alpha) + if lam != 1: + ratio = math.sqrt(1. - lam) + yl, yh, xl, xh = rand_bbox(input.size(), ratio) + input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] + target = mixup_target(target, num_classes, lam, smoothing) + return input, target + + +def _resolve_mode(mode): + mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode + if mode == MixupMode.RANDOM: + mode = MixupMode(np.random.rand() > 0.5) + return mode # will be one of cutmix or mixup + + +def mix_batch( + input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP): + mode = _resolve_mode(mode) + if mode == MixupMode.CUTMIX: + return mixup_batch(input, target, alpha, num_classes, smoothing, disable) + else: + return cutmix_batch(input, target, alpha, num_classes, smoothing, disable) + + class FastCollateMixup: + """Fast Collate Mixup that applies different params to each element + flipped pair - def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): + NOTE once experiments are done, one of the three variants will remain with this class name + """ + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): self.mixup_alpha = mixup_alpha self.label_smoothing = label_smoothing self.num_classes = num_classes + self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode self.mixup_enabled = True + self.correct_lam = False # correct lambda based on clipped area for cutmix + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) + for i in range(batch_size//2): + j = batch_size - i - 1 + lam = 1. + if self.mixup_enabled: + lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + + if _resolve_mode(self.mode) == MixupMode.CUTMIX: + mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32) + ratio = math.sqrt(1. - lam) + if lam != 1: + yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) + mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed_j[:, yl:yh, xl:xh] = batch[i][0][:, yl:yh, xl:xh].astype(np.float32) + if self.correct_lam: + lam_corrected = (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + lam_out[i] -= lam_corrected + lam_out[j] -= lam_corrected + else: + lam_out[i] = lam + lam_out[j] = lam + else: + mixed_i = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + mixed_j = batch[j][0].astype(np.float32) * lam + batch[i][0].astype(np.float32) * (1 - lam) + lam_out[i] = lam + lam_out[j] = lam + np.round(mixed_i, out=mixed_i) + np.round(mixed_j, out=mixed_j) + tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + return lam_out def __call__(self, batch): batch_size = len(batch) + assert batch_size % 2 == 0, 'Batch size should be even when using this' + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + lam = self._do_mix(tensor, batch) + target = torch.tensor([b[1] for b in batch], dtype=torch.int64) + target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu') + + return tensor, target + + +class FastCollateMixupElementwise(FastCollateMixup): + """Fast Collate Mixup that applies different params to each batch element + + NOTE this is for experimentation, may remove at some point + """ + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): + super(FastCollateMixupElementwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode) + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) + for i in range(batch_size): + lam = 1. + if self.mixup_enabled: + lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + + if _resolve_mode(self.mode) == MixupMode.CUTMIX: + mixed = batch[i][0].astype(np.float32) + ratio = math.sqrt(1. - lam) + if lam != 1: + yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) + mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + if self.correct_lam: + lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + else: + lam_out[i] = lam + else: + mixed = batch[i][0].astype(np.float32) * lam + \ + batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + lam_out[i] = lam + np.round(mixed, out=mixed) + tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) + return lam_out + + +class FastCollateMixupBatchwise(FastCollateMixup): + """Fast Collate Mixup that applies same params to whole batch + + NOTE this is for experimentation, may remove at some point + """ + + def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP): + super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode) + + def _do_mix(self, tensor, batch): + batch_size = len(batch) + lam_out = torch.ones(batch_size) lam = 1. + cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) + if cutmix and self.correct_lam: + ratio = math.sqrt(1. - lam) + yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio) + lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) - target = torch.tensor([b[1] for b in batch], dtype=torch.int64) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') - - tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) for i in range(batch_size): - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + if cutmix: + mixed = batch[i][0].astype(np.float32) + if lam != 1: + mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + else: + mixed = batch[i][0].astype(np.float32) * lam + \ + batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - - return tensor, target + return lam diff --git a/train.py b/train.py index 32beb418..f8fdf698 100755 --- a/train.py +++ b/train.py @@ -28,7 +28,8 @@ except ImportError: from torch.nn.parallel import DistributedDataParallel as DDP has_apex = False -from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset +from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mix_batch, AugMixDataset,\ + FastCollateMixupElementwise, FastCollateMixupBatchwise from timm.models import create_model, resume_checkpoint, convert_splitbn_model from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy @@ -134,6 +135,8 @@ parser.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') +parser.add_argument('--mixup-mode', type=str, default='mixup', + help='Mixup mode ("mixup", "cutmix", "random", default: "mixup")') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, @@ -352,7 +355,7 @@ def main(): collate_fn = None if args.prefetcher and args.mixup > 0: assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup) - collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes) + collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode) if num_aug_splits > 1: dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits) @@ -504,10 +507,10 @@ def train_epoch( if not args.prefetcher: input, target = input.cuda(), target.cuda() if args.mixup > 0.: - input, target = mixup_batch( + input, target = mix_batch( input, target, alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing, - disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch) + disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode) output = model(input)