|
|
|
@ -96,13 +96,13 @@ class Mixup:
|
|
|
|
|
cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
|
|
|
|
|
prob (float): probability of applying mixup or cutmix per batch or element
|
|
|
|
|
switch_prob (float): probability of switching to cutmix instead of mixup when both are active
|
|
|
|
|
elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
|
|
|
|
|
mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element)
|
|
|
|
|
correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
|
|
|
|
|
label_smoothing (float): apply label smoothing to the mixed target tensor
|
|
|
|
|
num_classes (int): number of classes for target
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
|
|
|
|
|
elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
|
|
|
|
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
|
|
|
|
|
self.mixup_alpha = mixup_alpha
|
|
|
|
|
self.cutmix_alpha = cutmix_alpha
|
|
|
|
|
self.cutmix_minmax = cutmix_minmax
|
|
|
|
@ -114,7 +114,7 @@ class Mixup:
|
|
|
|
|
self.switch_prob = switch_prob
|
|
|
|
|
self.label_smoothing = label_smoothing
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.elementwise = elementwise
|
|
|
|
|
self.mode = mode
|
|
|
|
|
self.correct_lam = correct_lam # correct lambda based on clipped area for cutmix
|
|
|
|
|
self.mixup_enabled = True # set to false to disable mixing (intended tp be set by train loop)
|
|
|
|
|
|
|
|
|
@ -173,6 +173,26 @@ class Mixup:
|
|
|
|
|
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
|
|
|
|
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
def _mix_pair(self, x):
|
|
|
|
|
batch_size = len(x)
|
|
|
|
|
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
|
|
|
|
x_orig = x.clone() # need to keep an unmodified original for mixing source
|
|
|
|
|
for i in range(batch_size // 2):
|
|
|
|
|
j = batch_size - i - 1
|
|
|
|
|
lam = lam_batch[i]
|
|
|
|
|
if lam != 1.:
|
|
|
|
|
if use_cutmix[i]:
|
|
|
|
|
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
|
|
|
|
x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
|
|
|
|
x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
|
|
|
|
|
x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
|
|
|
|
|
lam_batch[i] = lam
|
|
|
|
|
else:
|
|
|
|
|
x[i] = x[i] * lam + x_orig[j] * (1 - lam)
|
|
|
|
|
x[j] = x[j] * lam + x_orig[i] * (1 - lam)
|
|
|
|
|
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
|
|
|
|
return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
def _mix_batch(self, x):
|
|
|
|
|
lam, use_cutmix = self._params_per_batch()
|
|
|
|
|
if lam == 1.:
|
|
|
|
@ -188,7 +208,12 @@ class Mixup:
|
|
|
|
|
|
|
|
|
|
def __call__(self, x, target):
|
|
|
|
|
assert len(x) % 2 == 0, 'Batch size should be even when using this'
|
|
|
|
|
lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
|
|
|
|
|
if self.mode == 'elem':
|
|
|
|
|
lam = self._mix_elem(x)
|
|
|
|
|
elif self.mode == 'pair':
|
|
|
|
|
lam = self._mix_pair(x)
|
|
|
|
|
else:
|
|
|
|
|
lam = self._mix_batch(x)
|
|
|
|
|
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
|
|
|
|
|
return x, target
|
|
|
|
|
|
|
|
|
@ -199,15 +224,18 @@ class FastCollateMixup(Mixup):
|
|
|
|
|
A Mixup impl that's performed while collating the batches.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def _mix_elem_collate(self, output, batch):
|
|
|
|
|
def _mix_elem_collate(self, output, batch, half=False):
|
|
|
|
|
batch_size = len(batch)
|
|
|
|
|
lam_batch, use_cutmix = self._params_per_elem(batch_size)
|
|
|
|
|
for i in range(batch_size):
|
|
|
|
|
num_elem = batch_size // 2 if half else batch_size
|
|
|
|
|
assert len(output) == num_elem
|
|
|
|
|
lam_batch, use_cutmix = self._params_per_elem(num_elem)
|
|
|
|
|
for i in range(num_elem):
|
|
|
|
|
j = batch_size - i - 1
|
|
|
|
|
lam = lam_batch[i]
|
|
|
|
|
mixed = batch[i][0]
|
|
|
|
|
if lam != 1.:
|
|
|
|
|
if use_cutmix[i]:
|
|
|
|
|
if not half:
|
|
|
|
|
mixed = mixed.copy()
|
|
|
|
|
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
|
|
|
|
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
|
|
|
@ -215,9 +243,38 @@ class FastCollateMixup(Mixup):
|
|
|
|
|
lam_batch[i] = lam
|
|
|
|
|
else:
|
|
|
|
|
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
|
|
|
|
lam_batch[i] = lam
|
|
|
|
|
np.round(mixed, out=mixed)
|
|
|
|
|
np.rint(mixed, out=mixed)
|
|
|
|
|
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
|
|
|
|
if half:
|
|
|
|
|
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
|
|
|
|
|
return torch.tensor(lam_batch).unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
def _mix_pair_collate(self, output, batch):
|
|
|
|
|
batch_size = len(batch)
|
|
|
|
|
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
|
|
|
|
|
for i in range(batch_size // 2):
|
|
|
|
|
j = batch_size - i - 1
|
|
|
|
|
lam = lam_batch[i]
|
|
|
|
|
mixed_i = batch[i][0]
|
|
|
|
|
mixed_j = batch[j][0]
|
|
|
|
|
assert 0 <= lam <= 1.0
|
|
|
|
|
if lam < 1.:
|
|
|
|
|
if use_cutmix[i]:
|
|
|
|
|
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
|
|
|
|
|
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
|
|
|
|
|
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
|
|
|
|
|
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
|
|
|
|
|
mixed_j[:, yl:yh, xl:xh] = patch_i
|
|
|
|
|
lam_batch[i] = lam
|
|
|
|
|
else:
|
|
|
|
|
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
|
|
|
|
|
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
|
|
|
|
|
mixed_i = mixed_temp
|
|
|
|
|
np.rint(mixed_j, out=mixed_j)
|
|
|
|
|
np.rint(mixed_i, out=mixed_i)
|
|
|
|
|
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
|
|
|
|
|
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
|
|
|
|
|
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
|
|
|
|
|
return torch.tensor(lam_batch).unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
def _mix_batch_collate(self, output, batch):
|
|
|
|
@ -235,19 +292,25 @@ class FastCollateMixup(Mixup):
|
|
|
|
|
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
|
|
|
|
|
else:
|
|
|
|
|
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
|
|
|
|
|
np.round(mixed, out=mixed)
|
|
|
|
|
np.rint(mixed, out=mixed)
|
|
|
|
|
output[i] += torch.from_numpy(mixed.astype(np.uint8))
|
|
|
|
|
return lam
|
|
|
|
|
|
|
|
|
|
def __call__(self, batch, _=None):
|
|
|
|
|
batch_size = len(batch)
|
|
|
|
|
assert batch_size % 2 == 0, 'Batch size should be even when using this'
|
|
|
|
|
half = 'half' in self.mode
|
|
|
|
|
if half:
|
|
|
|
|
batch_size //= 2
|
|
|
|
|
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
|
|
|
|
if self.elementwise:
|
|
|
|
|
lam = self._mix_elem_collate(output, batch)
|
|
|
|
|
if self.mode == 'elem' or self.mode == 'half':
|
|
|
|
|
lam = self._mix_elem_collate(output, batch, half=half)
|
|
|
|
|
elif self.mode == 'pair':
|
|
|
|
|
lam = self._mix_pair_collate(output, batch)
|
|
|
|
|
else:
|
|
|
|
|
lam = self._mix_batch_collate(output, batch)
|
|
|
|
|
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
|
|
|
|
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
|
|
|
|
|
target = target[:batch_size]
|
|
|
|
|
return output, target
|
|
|
|
|
|
|
|
|
|