Merge pull request #218 from rwightman/cutmix

CutMix + MixUp overhaul
pull/227/head
Ross Wightman 4 years ago committed by GitHub
commit b423bc8362
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ from .dataset import Dataset, DatasetTar, AugMixDataset
from .transforms import *
from .loader import create_loader
from .transforms_factory import create_transform
from .mixup import mixup_batch, FastCollateMixup
from .mixup import Mixup, FastCollateMixup
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .real_labels import RealLabelsImagenet

@ -92,7 +92,7 @@ class Dataset(data.Dataset):
return img, target
def __len__(self):
return len(self.imgs)
return len(self.samples)
def filenames(self, indices=[], basename=False):
if indices:

@ -1,5 +1,12 @@
""" Mixup
Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412
""" Mixup and Cutmix
Papers:
mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
Code Reference:
CutMix: https://github.com/clovaai/CutMix-PyTorch
Hacked together by / Copyright 2020 Ross Wightman
"""
@ -17,40 +24,230 @@ def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
on_value = 1. - smoothing + off_value
y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
return lam*y1 + (1. - lam)*y2
return y1 * lam + y2 * (1. - lam)
def rand_bbox(img_shape, lam, margin=0., count=None):
""" Standard CutMix bounding-box
Generates a random square bbox based on lambda value. This impl includes
support for enforcing a border margin as percent of bbox dimensions.
Args:
img_shape (tuple): Image shape as tuple
lam (float): Cutmix lambda value
margin (float): Percentage of bbox dimension to enforce as margin (reduce amount of box outside image)
count (int): Number of bbox to generate
"""
ratio = np.sqrt(1 - lam)
img_h, img_w = img_shape[-2:]
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
yl = np.clip(cy - cut_h // 2, 0, img_h)
yh = np.clip(cy + cut_h // 2, 0, img_h)
xl = np.clip(cx - cut_w // 2, 0, img_w)
xh = np.clip(cx + cut_w // 2, 0, img_w)
return yl, yh, xl, xh
def rand_bbox_minmax(img_shape, minmax, count=None):
""" Min-Max CutMix bounding-box
Inspired by Darknet cutmix impl, generates a random rectangular bbox
based on min/max percent values applied to each dimension of the input image.
Typical defaults for minmax are usually in the .2-.3 for min and .8-.9 range for max.
def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
lam = 1.
if not disable:
lam = np.random.beta(alpha, alpha)
input = input.mul(lam).add_(1 - lam, input.flip(0))
target = mixup_target(target, num_classes, lam, smoothing)
return input, target
Args:
img_shape (tuple): Image shape as tuple
minmax (tuple or list): Min and max bbox ratios (as percent of image size)
count (int): Number of bbox to generate
"""
assert len(minmax) == 2
img_h, img_w = img_shape[-2:]
cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
yl = np.random.randint(0, img_h - cut_h, size=count)
xl = np.random.randint(0, img_w - cut_w, size=count)
yu = yl + cut_h
xu = xl + cut_w
return yl, yu, xl, xu
class FastCollateMixup:
def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
""" Generate bbox and apply lambda correction.
"""
if ratio_minmax is not None:
yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
else:
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam
def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000):
class Mixup:
""" Mixup/Cutmix that applies different params to each element or whole batch
Args:
mixup_alpha (float): mixup alpha value, mixup is active if > 0.
cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
prob (float): probability of applying mixup or cutmix per batch or element
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
label_smoothing (float): apply label smoothing to the mixed target tensor
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
if self.cutmix_minmax is not None:
assert len(self.cutmix_minmax) == 2
# force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
self.cutmix_alpha = 1.0
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mixup_enabled = True
self.elementwise = elementwise
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
def __call__(self, batch):
batch_size = len(batch)
lam = 1.
def _params_per_elem(self, batch_size):
lam = np.ones(batch_size, dtype=np.float32)
use_cutmix = np.zeros(batch_size, dtype=np.bool)
if self.mixup_enabled:
lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand(batch_size) < self.switch_prob
lam_mix = np.where(
use_cutmix,
np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
elif self.cutmix_alpha > 0.:
use_cutmix = np.ones(batch_size, dtype=np.bool)
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = np.where(np.random.rand(batch_size) < self.mix_prob, lam_mix.astype(np.float32), lam)
return lam, use_cutmix
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
def _params_per_batch(self):
lam = 1.
use_cutmix = False
if self.mixup_enabled and np.random.rand() < self.mix_prob:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
use_cutmix = np.random.rand() < self.switch_prob
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.mixup_alpha > 0.:
lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
elif self.cutmix_alpha > 0.:
use_cutmix = True
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = float(lam_mix)
return lam, use_cutmix
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
def _mix_elem(self, x):
batch_size = len(x)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
x_orig = x.clone() # need to keep an unmodified original for mixing source
for i in range(batch_size):
mixed = batch[i][0].astype(np.float32) * lam + \
batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed)
tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
j = batch_size - i - 1
lam = lam_batch[i]
if lam != 1.:
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
def _mix_batch(self, x):
lam, use_cutmix = self._params_per_batch()
if lam == 1.:
return 1.
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
x.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
x[:, :, yl:yh, xl:xh] = x.flip(0)[:, :, yl:yh, xl:xh]
else:
x_flipped = x.flip(0).mul_(1. - lam)
x.mul_(lam).add_(x_flipped)
return lam
def __call__(self, x, target):
assert len(x) % 2 == 0, 'Batch size should be even when using this'
lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
return x, target
class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
A Mixup impl that's performed while collating the batches.
"""
def _mix_elem_collate(self, output, batch):
batch_size = len(batch)
lam_batch, use_cutmix = self._params_per_elem(batch_size)
for i in range(batch_size):
j = batch_size - i - 1
lam = lam_batch[i]
mixed = batch[i][0]
if lam != 1.:
if use_cutmix[i]:
mixed = mixed.copy()
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
lam_batch[i] = lam
np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
return torch.tensor(lam_batch).unsqueeze(1)
def _mix_batch_collate(self, output, batch):
batch_size = len(batch)
lam, use_cutmix = self._params_per_batch()
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
for i in range(batch_size):
j = batch_size - i - 1
mixed = batch[i][0]
if lam != 1.:
if use_cutmix:
mixed = mixed.copy() # don't want to modify the original while iterating
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
return lam
def __call__(self, batch, _=None):
batch_size = len(batch)
assert batch_size % 2 == 0, 'Batch size should be even when using this'
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
if self.elementwise:
lam = self._mix_elem_collate(output, batch)
else:
lam = self._mix_batch_collate(output, batch)
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
return output, target
return tensor, target

@ -28,7 +28,7 @@ except ImportError:
from torch.nn.parallel import DistributedDataParallel as DDP
has_apex = False
from timm.data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_batch, AugMixDataset
from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, convert_splitbn_model
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
@ -156,7 +156,17 @@ parser.add_argument('--recount', type=int, default=1,
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
parser.add_argument('--mixup', type=float, default=0.0,
help='Mixup alpha, mixup enabled if > 0. (default: 0.)')
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-elem', action='store_true', default=False,
help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
parser.add_argument('--smoothing', type=float, default=0.1,
@ -388,9 +398,18 @@ def main():
dataset_train = Dataset(train_dir)
collate_fn = None
if args.prefetcher and args.mixup > 0:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_args = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if args.prefetcher:
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup(**mixup_args)
else:
mixup_fn = Mixup(**mixup_args)
if num_aug_splits > 1:
dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
@ -452,17 +471,14 @@ def main():
if args.jsd:
assert num_aug_splits > 1 # JSD only valid with aug splits set
train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.mixup > 0.:
# smoothing is handled with mixup label transform
elif mixup_active:
# smoothing is handled with mixup target transform
train_loss_fn = SoftTargetCrossEntropy().cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
elif args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
validate_loss_fn = nn.CrossEntropyLoss().cuda()
eval_metric = args.eval_metric
best_metric = None
@ -490,7 +506,7 @@ def main():
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema)
use_amp=use_amp, model_ema=model_ema, mixup_fn=mixup_fn)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
@ -530,11 +546,13 @@ def main():
def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None, mixup_fn=None):
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
loader.mixup_enabled = False
elif mixup_fn is not None:
mixup_fn.mixup_enabled = False
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
@ -550,11 +568,8 @@ def train_epoch(
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if args.mixup > 0.:
input, target = mixup_batch(
input, target,
alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)
if mixup_fn is not None:
input, target = mixup_fn(input, target)
output = model(input)

Loading…
Cancel
Save