From 670c61b28ffd123267df4126fc700fd8f2837d22 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 17 Feb 2020 11:00:54 -0800 Subject: [PATCH] Some cutmix/mixup cleanup/fixes --- timm/data/mixup.py | 55 +++++++++++++++++++++++++++------------------- 1 file changed, 32 insertions(+), 23 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 1b298af0..a59fa1f3 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -14,6 +14,7 @@ Hacked together by Ross Wightman import numpy as np import torch import math +import numbers from enum import IntEnum @@ -49,9 +50,17 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab return input, target +def calc_ratio(lam, minmax=None): + ratio = math.sqrt(1 - lam) + if minmax is not None: + if isinstance(minmax, numbers.Number): + minmax = (minmax, 1 - minmax) + ratio = np.clip(ratio, minmax[0], minmax[1]) + return ratio + + def rand_bbox(size, ratio): H, W = size[-2:] - ratio = max(min(ratio, 0.8), 0.2) cut_h, cut_w = int(H * ratio), int(W * ratio) cy, cx = np.random.randint(H), np.random.randint(W) yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H) @@ -59,14 +68,15 @@ def rand_bbox(size, ratio): return yl, yh, xl, xh -def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): +def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False): lam = 1. if not disable: lam = np.random.beta(alpha, alpha) if lam != 1: - ratio = math.sqrt(1. - lam) - yl, yh, xl, xh = rand_bbox(input.size(), ratio) + yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam)) input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] + if correct_lam: + lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) target = mixup_target(target, num_classes, lam, smoothing) return input, target @@ -82,9 +92,9 @@ def mix_batch( input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP): mode = _resolve_mode(mode) if mode == MixupMode.CUTMIX: - return mixup_batch(input, target, alpha, num_classes, smoothing, disable) - else: return cutmix_batch(input, target, alpha, num_classes, smoothing, disable) + else: + return mixup_batch(input, target, alpha, num_classes, smoothing, disable) class FastCollateMixup: @@ -99,6 +109,7 @@ class FastCollateMixup: self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode self.mixup_enabled = True self.correct_lam = False # correct lambda based on clipped area for cutmix + self.ratio_minmax = None # (0.2, 0.8) def _do_mix(self, tensor, batch): batch_size = len(batch) @@ -111,7 +122,7 @@ class FastCollateMixup: if _resolve_mode(self.mode) == MixupMode.CUTMIX: mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32) - ratio = math.sqrt(1. - lam) + ratio = calc_ratio(lam, self.ratio_minmax) if lam != 1: yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) @@ -132,7 +143,7 @@ class FastCollateMixup: np.round(mixed_j, out=mixed_j) tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8)) tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8)) - return lam_out + return lam_out.unsqueeze(1) def __call__(self, batch): batch_size = len(batch) @@ -140,7 +151,7 @@ class FastCollateMixup: tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) lam = self._do_mix(tensor, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) - target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu') + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') return tensor, target @@ -157,27 +168,27 @@ class FastCollateMixupElementwise(FastCollateMixup): batch_size = len(batch) lam_out = torch.ones(batch_size) for i in range(batch_size): + j = batch_size - i - 1 lam = 1. if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) if _resolve_mode(self.mode) == MixupMode.CUTMIX: mixed = batch[i][0].astype(np.float32) - ratio = math.sqrt(1. - lam) if lam != 1: + ratio = calc_ratio(lam) yl, yh, xl, xh = rand_bbox(tensor.size(), ratio) - mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) if self.correct_lam: lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) else: lam_out[i] = lam else: - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) lam_out[i] = lam np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) - return lam_out + return lam_out.unsqueeze(1) class FastCollateMixupBatchwise(FastCollateMixup): @@ -191,25 +202,23 @@ class FastCollateMixupBatchwise(FastCollateMixup): def _do_mix(self, tensor, batch): batch_size = len(batch) - lam_out = torch.ones(batch_size) lam = 1. cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) - if cutmix and self.correct_lam: - ratio = math.sqrt(1. - lam) - yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio) - lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + if cutmix: + yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam)) + if self.correct_lam: + lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) for i in range(batch_size): + j = batch_size - i - 1 if cutmix: mixed = batch[i][0].astype(np.float32) if lam != 1: - mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32) - lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1]) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) else: - mixed = batch[i][0].astype(np.float32) * lam + \ - batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) + mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam