More cutmix/mixup overhaul, ready to kick-off some trials.

pull/218/head
Ross Wightman 4 years ago
parent 92f2d0d65d
commit f471c17c9d

@@ -15,17 +15,6 @@ import numpy as np
 import torch
 import math
 import numbers
-from enum import IntEnum
-
-
-class MixupMode(IntEnum):
-    MIXUP = 0
-    CUTMIX = 1
-    RANDOM = 2
-
-    @classmethod
-    def from_str(cls, value):
-        return cls[value.upper()]


 def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
@@ -50,30 +39,49 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
     return input, target


-def calc_ratio(lam, minmax=None):
-    ratio = math.sqrt(1 - lam)
-    if minmax is not None:
-        if isinstance(minmax, numbers.Number):
-            minmax = (minmax, 1 - minmax)
-        ratio = np.clip(ratio, minmax[0], minmax[1])
-    return ratio
-
-
-def rand_bbox(size, ratio):
-    H, W = size[-2:]
-    cut_h, cut_w = int(H * ratio), int(W * ratio)
-    cy, cx = np.random.randint(H), np.random.randint(W)
-    yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
-    xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
-    return yl, yh, xl, xh
+def rand_bbox(size, lam, border=0., count=None):
+    ratio = math.sqrt(1 - lam)
+    img_h, img_w = size[-2:]
+    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
+    margin_y, margin_x = int(border * cut_h), int(border * cut_w)
+    cy = np.random.randint(0 + margin_y, img_h - margin_y, size=count)
+    cx = np.random.randint(0 + margin_x, img_w - margin_x, size=count)
+    yl = np.clip(cy - cut_h // 2, 0, img_h)
+    yh = np.clip(cy + cut_h // 2, 0, img_h)
+    xl = np.clip(cx - cut_w // 2, 0, img_w)
+    xh = np.clip(cx + cut_w // 2, 0, img_w)
+    return yl, yh, xl, xh
+
+
+def rand_bbox_minmax(size, minmax, count=None):
+    assert len(minmax) == 2
+    img_h, img_w = size[-2:]
+    cut_h = np.random.randint(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
+    cut_w = np.random.randint(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
+    yl = np.random.randint(0, img_h - cut_h, size=count)
+    xl = np.random.randint(0, img_w - cut_w, size=count)
+    yu = yl + cut_h
+    xu = xl + cut_w
+    return yl, yu, xl, xu
+
+
+def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, count=None):
+    if ratio_minmax is not None:
+        yl, yu, xl, xu = rand_bbox_minmax(img_shape, ratio_minmax, count=count)
+    else:
+        yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
+    if correct_lam or ratio_minmax is not None:
+        bbox_area = (yu - yl) * (xu - xl)
+        lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1])
+    return (yl, yu, xl, xu), lam


 def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
     lam = 1.
     if not disable:
         lam = np.random.beta(alpha, alpha)
     if lam != 1:
-        yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
+        yl, yh, xl, xh = rand_bbox(input.size(), lam)
         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
         if correct_lam:
             lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
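
Editor's note: the helpers added above split the old calc_ratio/rand_bbox pair into rand_bbox (box size derived from sqrt(1 - lam)), rand_bbox_minmax (box size drawn from an explicit min/max fraction range), and cutmix_bbox_and_lam, which can recompute lambda from the area of the box that actually survives clipping. A minimal standalone sketch of that correction, with the same math inlined and purely illustrative values (not part of the commit):

# Illustrative sketch only: the rand_bbox / cutmix_bbox_and_lam math above,
# inlined for a single 224x224 image and lam = 0.6 (border/margin = 0).
import numpy as np

lam = 0.6
img_h = img_w = 224
ratio = np.sqrt(1. - lam)                      # cut size scales with sqrt(1 - lam)
cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
cy, cx = np.random.randint(img_h), np.random.randint(img_w)
yl, yh = np.clip(cy - cut_h // 2, 0, img_h), np.clip(cy + cut_h // 2, 0, img_h)
xl, xh = np.clip(cx - cut_w // 2, 0, img_w), np.clip(cx + cut_w // 2, 0, img_w)
# correct_lam=True path: recompute lambda from the clipped box area so the pasted
# sample's label is not over-weighted when the box spills off the image edge.
lam_corrected = 1. - (yh - yl) * (xh - xl) / float(img_h * img_w)
print((yl, yh, xl, xh), lam_corrected)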
@@ -81,101 +89,135 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa
     return input, target


-def _resolve_mode(mode):
-    mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
-    if mode == MixupMode.RANDOM:
-        mode = MixupMode(np.random.rand() > 0.7)
-    return mode  # will be one of cutmix or mixup
-
-
 def mix_batch(
-        input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
-    mode = _resolve_mode(mode)
-    if mode == MixupMode.CUTMIX:
-        return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+        input, target, mixup_alpha=0.2, cutmix_alpha=0., prob=1.0, switch_prob=.5,
+        num_classes=1000, smoothing=0.1, disable=False):
+    # FIXME test this version
+    if np.random.rand() > prob:
+        return input, target
+    use_cutmix = cutmix_alpha > 0. and np.random.rand() <= switch_prob
+    if use_cutmix:
+        return cutmix_batch(input, target, cutmix_alpha, num_classes, smoothing, disable)
     else:
-        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
+        return mixup_batch(input, target, mixup_alpha, num_classes, smoothing, disable)


 class FastCollateMixup:
-    """Fast Collate Mixup that applies different params to each element + flipped pair
+    """Fast Collate Mixup/Cutmix that applies different params to each element or whole batch

     NOTE once experiments are done, one of the three variants will remain with this class name
     """
-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
+    def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
+                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+        """
+        Args:
+            mixup_alpha (float): mixup alpha value, mixup is active if > 0.
+            cutmix_alpha (float): cutmix alpha value, cutmix is active if > 0.
+            cutmix_minmax (float): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None
+            prob (float): probability of applying mixup or cutmix per batch or element
+            switch_prob (float): probability of using cutmix instead of mixup when both active
+            elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+            label_smoothing (float):
+            num_classes (int):
+        """
         self.mixup_alpha = mixup_alpha
+        self.cutmix_alpha = cutmix_alpha
+        self.cutmix_minmax = cutmix_minmax
+        if self.cutmix_minmax is not None:
+            assert len(self.cutmix_minmax) == 2
+            # force cutmix alpha == 1.0 when minmax active to keep logic simple & safe
+            self.cutmix_alpha = 1.0
+        self.prob = prob
+        self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
-        self.mixup_enabled = True
-        self.correct_lam = True  # correct lambda based on clipped area for cutmix
-        self.ratio_minmax = None  # (0.2, 0.8)
+        self.elementwise = elementwise
+        self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
+        self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)

-    def _do_mix(self, tensor, batch):
-        batch_size = len(batch)
-        lam_out = torch.ones(batch_size)
-        for i in range(batch_size):
-            j = batch_size - i - 1
-            lam = 1.
-            if self.mixup_enabled:
-                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-
-            if _resolve_mode(self.mode) == MixupMode.CUTMIX:
-                mixed = batch[i][0].astype(np.float32)
-                if lam != 1:
-                    ratio = calc_ratio(lam)
-                    yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
-                    if self.correct_lam:
-                        lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
-                    else:
-                        lam_out[i] = lam
-            else:
-                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-                lam_out[i] = lam
-            np.round(mixed, out=mixed)
-            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-        return lam_out.unsqueeze(1)
+    def _mix_elem(self, output, batch):
+        batch_size = len(batch)
+        lam_out = np.ones(batch_size)
+        use_cutmix = np.zeros(batch_size).astype(np.bool)
+        if self.mixup_enabled:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand(batch_size) < self.switch_prob
+                lam_mix = np.where(
+                    use_cutmix,
+                    np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size),
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size))
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha, size=batch_size)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = np.ones(batch_size).astype(np.bool)
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out)
+
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            lam = lam_out[i]
+            mixed = batch[i][0].astype(np.float32)
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                    lam_out[i] = lam
+                else:
+                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+                    lam_out[i] = lam
+            np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return torch.tensor(lam_out).unsqueeze(1)
+
+    def _mix_batch(self, output, batch):
+        batch_size = len(batch)
+        lam = 1.
+        use_cutmix = False
+        if self.mixup_enabled and np.random.rand() < self.prob:
+            if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
+                use_cutmix = np.random.rand() < self.switch_prob
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) if use_cutmix else \
+                    np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.mixup_alpha > 0.:
+                lam_mix = np.random.beta(self.mixup_alpha, self.mixup_alpha)
+            elif self.cutmix_alpha > 0.:
+                use_cutmix = True
+                lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
+            else:
+                assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
+            lam = lam_mix
+
+        if use_cutmix:
+            (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+
+        for i in range(batch_size):
+            j = batch_size - i - 1
+            mixed = batch[i][0].astype(np.float32)
+            if lam != 1.:
+                if use_cutmix:
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
+                else:
+                    mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam)
+            np.round(mixed, out=mixed)
+            output[i] += torch.from_numpy(mixed.astype(np.uint8))
+        return lam

     def __call__(self, batch):
         batch_size = len(batch)
         assert batch_size % 2 == 0, 'Batch size should be even when using this'
-        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
-        lam = self._do_mix(tensor, batch)
+        output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
+        if self.elementwise:
+            lam = self._mix_elem(output, batch)
+        else:
+            lam = self._mix_batch(output, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
-        return tensor, target
-
-
-class FastCollateMixupBatchwise(FastCollateMixup):
-    """Fast Collate Mixup that applies same params to whole batch
-
-    NOTE this is for experimentation, may remove at some point
-    """
-    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
-        super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)
-
-    def _do_mix(self, tensor, batch):
-        batch_size = len(batch)
-        lam = 1.
-        cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
-        if self.mixup_enabled:
-            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-            if cutmix:
-                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
-                if self.correct_lam:
-                    lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
-
-        for i in range(batch_size):
-            j = batch_size - i - 1
-            if cutmix:
-                mixed = batch[i][0].astype(np.float32)
-                if lam != 1:
-                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
-            else:
-                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
-            np.round(mixed, out=mixed)
-            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-        return lam
+        return output, target
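
Editor's note: the reworked FastCollateMixup above folds the per-batch and per-element behaviours into one class behind the elementwise flag, which is why the FastCollateMixupBatchwise subclass is removed in the same hunk. A hedged usage sketch follows; the class and argument names come from the diff, while the shapes, labels, and values are purely illustrative, and it assumes the module above (including its mixup_target helper) is in scope. In train.py the instance is simply passed to the DataLoader as collate_fn when --prefetcher is active (see the main() hunk further down).

import numpy as np
import torch

# Fake "prefetcher-style" samples, i.e. (uint8 CHW array, int label) tuples.
# Batch size must be even per the assert in __call__.  Values are illustrative only.
batch = [(np.random.randint(0, 255, (3, 224, 224), dtype=np.uint8), i % 10) for i in range(8)]

collate = FastCollateMixup(
    mixup_alpha=0.2, cutmix_alpha=1.0, cutmix_minmax=None, prob=1.0, switch_prob=0.5,
    elementwise=False,  # True would instead draw a per-sample lam / per-sample box
    label_smoothing=0.1, num_classes=10)

images, targets = collate(batch)   # images: torch.uint8 tensor of shape [8, 3, 224, 224]
print(images.shape, targets.shape)  # targets are soft labels from mixup_target, shape [8, 10]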

@@ -157,8 +157,16 @@ parser.add_argument('--resplit', action='store_true', default=False,
                     help='Do not random erase first (clean) augmentation split')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
-parser.add_argument('--mixup-mode', type=str, default='mixup',
-                    help='Mixup mode. One of "mixup", "cutmix", "random" (default: "mixup")')
+parser.add_argument('--cutmix', type=float, default=0.0,
+                    help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
+parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
+                    help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
+parser.add_argument('--mixup-prob', type=float, default=1.0,
+                    help='Probability of performing mixup or cutmix when either/both is enabled')
+parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
+                    help='Probability of switching to cutmix when both mixup and cutmix enabled')
+parser.add_argument('--mixup-elem', action='store_true', default=False,
+                    help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
@@ -390,9 +398,12 @@ def main():
     dataset_train = Dataset(train_dir)

     collate_fn = None
-    if args.prefetcher and args.mixup > 0:
+    if args.prefetcher and (args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None):
         assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
-        collate_fn = FastCollateMixup(args.mixup, args.smoothing, args.num_classes, args.mixup_mode)
+        collate_fn = FastCollateMixup(
+            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
+            label_smoothing=args.smoothing, num_classes=args.num_classes)

     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
@@ -555,8 +566,9 @@ def train_epoch(
         if args.mixup > 0.:
             input, target = mix_batch(
                 input, target,
-                alpha=args.mixup, num_classes=args.num_classes, smoothing=args.smoothing,
-                disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch, mode=args.mixup_mode)
+                mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, prob=args.mixup_prob,
+                switch_prob=args.mixup_switch_prob, num_classes=args.num_classes, smoothing=args.smoothing,
+                disable=args.mixup_off_epoch and epoch >= args.mixup_off_epoch)

         output = model(input)
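
Editor's note: for intuition on what the lam threaded through mix_batch and the collate class ultimately does to the labels: mixup_target (called above but not touched by this commit) is expected to blend label-smoothed one-hot targets of the batch with its flipped pairing, mirroring the input.flip(0) / j = batch_size - i - 1 pixel pairing. A small sketch of that assumed behaviour, not the actual implementation, which lives elsewhere in the mixup module:

import torch

def soft_targets(target, num_classes, lam, smoothing=0.1):
    # Assumed behaviour only: label-smoothed one-hot per sample, blended with the
    # flipped batch by lam -- mirrors the pixel pairing used in the diff above.
    off = smoothing / num_classes
    on = 1. - smoothing + off
    y = torch.full((target.size(0), num_classes), off).scatter_(1, target.view(-1, 1), on)
    return y * lam + y.flip(0) * (1. - lam)

t = soft_targets(torch.tensor([3, 7]), num_classes=10, lam=0.6)
# sample 0 ends up with ~0.55 mass on class 3 and ~0.37 on class 7 (rest is smoothing)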
