Some cutmix/mixup cleanup/fixes

pull/218/head
Ross Wightman 4 years ago
parent b3cb5f3275
commit 670c61b28f

@@ -14,6 +14,7 @@ Hacked together by Ross Wightman
 import numpy as np
 import torch
 import math
+import numbers
 from enum import IntEnum
 
 
@@ -49,9 +50,17 @@ def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disab
     return input, target
 
 
+def calc_ratio(lam, minmax=None):
+    ratio = math.sqrt(1 - lam)
+    if minmax is not None:
+        if isinstance(minmax, numbers.Number):
+            minmax = (minmax, 1 - minmax)
+        ratio = np.clip(ratio, minmax[0], minmax[1])
+    return ratio
+
+
 def rand_bbox(size, ratio):
     H, W = size[-2:]
-    ratio = max(min(ratio, 0.8), 0.2)
     cut_h, cut_w = int(H * ratio), int(W * ratio)
     cy, cx = np.random.randint(H), np.random.randint(W)
     yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
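Note: the new calc_ratio helper replaces both the inline math.sqrt(1. - lam) at each call site and the hard clamp that rand_bbox used to apply. Since ratio = sqrt(1 - lam), the cut covers roughly a (1 - lam) fraction of the image area, preserving the usual mixup meaning of lam. A minimal sketch of how the pieces compose (values illustrative):

    lam = 0.75                            # e.g. drawn from np.random.beta(alpha, alpha)
    ratio = calc_ratio(lam)               # sqrt(0.25) = 0.5
    ratio = calc_ratio(lam, minmax=0.2)   # scalar expands to (0.2, 0.8), the old clamp
    # cut area / image area = ratio**2 = 1 - lam (before border clipping)
    yl, yh, xl, xh = rand_bbox((1, 3, 224, 224), ratio)   # ~112x112 box, clipped to bounds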
@@ -59,14 +68,15 @@ def rand_bbox(size, ratio):
     return yl, yh, xl, xh
 
 
-def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
+def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, correct_lam=False):
     lam = 1.
     if not disable:
         lam = np.random.beta(alpha, alpha)
     if lam != 1:
-        ratio = math.sqrt(1. - lam)
-        yl, yh, xl, xh = rand_bbox(input.size(), ratio)
+        yl, yh, xl, xh = rand_bbox(input.size(), calc_ratio(lam))
         input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
+        if correct_lam:
+            lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1])
     target = mixup_target(target, num_classes, lam, smoothing)
     return input, target
 
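The new correct_lam flag addresses a bias in the uncorrected path: rand_bbox clips the box to the image borders, so the pasted region can cover less than the (1 - lam) fraction the sampled lam implies, leaving the label mix out of step with the pixel mix. With correction enabled, lam is recomputed from the box actually used. Hypothetical numbers for a 224x224 input:

    lam = 0.5                                       # sampled; implies ~50% of pixels replaced
    yl, yh, xl, xh = 0, 80, 0, 80                   # box clipped against the top-left corner
    lam = 1 - (yh - yl) * (xh - xl) / (224 * 224)   # ~0.87, matching actual coverage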
@@ -82,9 +92,9 @@ def mix_batch(
         input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
     mode = _resolve_mode(mode)
     if mode == MixupMode.CUTMIX:
-        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
-    else:
         return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
+    else:
+        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
 
 
 class FastCollateMixup:
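The mix_batch hunk above also fixes an inverted dispatch: MixupMode.CUTMIX previously routed to mixup_batch and every other mode to cutmix_batch, silently swapping the two strategies. After the fix, mode selects the matching function (sketch, assuming an NCHW float batch x and integer labels y):

    x_mixed, y_soft = mix_batch(x, y, alpha=0.2, num_classes=1000, mode=MixupMode.CUTMIX)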
@@ -99,6 +109,7 @@ class FastCollateMixup:
         self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
         self.mixup_enabled = True
         self.correct_lam = False  # correct lambda based on clipped area for cutmix
+        self.ratio_minmax = None  # (0.2, 0.8)
 
     def _do_mix(self, tensor, batch):
         batch_size = len(batch)
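ratio_minmax defaults to None, i.e. no clipping; the commented (0.2, 0.8) reproduces the clamp removed from rand_bbox. A sketch, assuming the constructor mirrors the attribute names used here:

    collate_fn = FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)
    collate_fn.ratio_minmax = (0.2, 0.8)   # or a scalar such as 0.2; calc_ratio expands it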
@@ -111,7 +122,7 @@ class FastCollateMixup:
             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                 mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
-                ratio = math.sqrt(1. - lam)
+                ratio = calc_ratio(lam, self.ratio_minmax)
                 if lam != 1:
                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
                     mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
@@ -132,7 +143,7 @@ class FastCollateMixup:
                 np.round(mixed_j, out=mixed_j)
             tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
             tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
-        return lam_out
+        return lam_out.unsqueeze(1)
 
     def __call__(self, batch):
         batch_size = len(batch)
@@ -140,7 +151,7 @@ class FastCollateMixup:
         tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
         lam = self._do_mix(tensor, batch)
         target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
-        target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
+        target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
         return tensor, target
 
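Moving the unsqueeze into _do_mix gives lam a broadcast-ready shape at the single point where its layout is known: a (batch_size, 1) column for the per-element variants, a plain scalar for the batchwise one. Assuming mixup_target blends two smoothed one-hot tensors as y1 * lam + y2 * (1 - lam), the column shape lets one expression handle per-sample coefficients:

    import torch
    y1 = torch.eye(10)[torch.tensor([0, 1, 2, 3])]           # (4, 10) one-hot stand-ins
    y2 = y1.flip(0)                                          # mirror-paired partners
    lam = torch.tensor([1.0, 0.9, 0.6, 1.0]).unsqueeze(1)    # (4, 1) broadcasts over classes
    target = y1 * lam + y2 * (1 - lam)                       # per-sample blend, shape (4, 10)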
@@ -157,27 +168,27 @@ class FastCollateMixupElementwise(FastCollateMixup):
         batch_size = len(batch)
         lam_out = torch.ones(batch_size)
         for i in range(batch_size):
+            j = batch_size - i - 1
             lam = 1.
             if self.mixup_enabled:
                 lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
             if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                 mixed = batch[i][0].astype(np.float32)
-                ratio = math.sqrt(1. - lam)
                 if lam != 1:
+                    ratio = calc_ratio(lam)
                     yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
-                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                     if self.correct_lam:
                         lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
                     else:
                         lam_out[i] = lam
             else:
-                mixed = batch[i][0].astype(np.float32) * lam + \
-                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                 lam_out[i] = lam
             np.round(mixed, out=mixed)
             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
-        return lam_out
+        return lam_out.unsqueeze(1)
 
 
 class FastCollateMixupBatchwise(FastCollateMixup):
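Hoisting j = batch_size - i - 1 names the mirror pairing both collate variants rely on: sample 0 mixes with the last sample, 1 with the second-to-last, and so on:

    batch_size = 6
    pairs = [(i, batch_size - i - 1) for i in range(batch_size)]
    # [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1), (5, 0)]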
@@ -191,25 +202,23 @@ class FastCollateMixupBatchwise(FastCollateMixup):
     def _do_mix(self, tensor, batch):
         batch_size = len(batch)
-        lam_out = torch.ones(batch_size)
         lam = 1.
         cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
         if self.mixup_enabled:
             lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
-            if cutmix and self.correct_lam:
-                ratio = math.sqrt(1. - lam)
-                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
-                lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+            if cutmix:
+                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, calc_ratio(lam))
+                if self.correct_lam:
+                    lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
 
         for i in range(batch_size):
+            j = batch_size - i - 1
             if cutmix:
                 mixed = batch[i][0].astype(np.float32)
                 if lam != 1:
-                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
-                    lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
+                    mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
             else:
-                mixed = batch[i][0].astype(np.float32) * lam + \
-                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
+                mixed = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
             np.round(mixed, out=mixed)
             tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
         return lam
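Since the batchwise variant shares one box and one lam across the whole batch, the per-sample lam_out buffer was dead weight; the correction now happens once on the scalar lam before the loop, and the scalar return broadcasts in mixup_target just as the (batch_size, 1) column does.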
