diff --git a/timm/data/mixup.py b/timm/data/mixup.py
index 63861bc7..38477548 100644
--- a/timm/data/mixup.py
+++ b/timm/data/mixup.py
@@ -96,13 +96,13 @@ class Mixup:
         cutmix_minmax (List[float]): cutmix min/max image ratio, cutmix is active and uses this vs alpha if not None.
         prob (float): probability of applying mixup or cutmix per batch or element
         switch_prob (float): probability of switching to cutmix instead of mixup when both are active
-        elementwise (bool): apply mixup/cutmix params per batch element instead of per batch
+        mode (str): how to apply mixup/cutmix params (per 'batch', 'pair' (pair of elements), 'elem' (element))
         correct_lam (bool): apply lambda correction when cutmix bbox clipped by image borders
         label_smoothing (float): apply label smoothing to the mixed target tensor
         num_classes (int): number of classes for target
     """
     def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
-                 elementwise=False, correct_lam=True, label_smoothing=0.1, num_classes=1000):
+                 mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
         self.mixup_alpha = mixup_alpha
         self.cutmix_alpha = cutmix_alpha
         self.cutmix_minmax = cutmix_minmax
@@ -114,7 +114,7 @@ class Mixup:
         self.switch_prob = switch_prob
         self.label_smoothing = label_smoothing
         self.num_classes = num_classes
-        self.elementwise = elementwise
+        self.mode = mode
         self.correct_lam = correct_lam  # correct lambda based on clipped area for cutmix
         self.mixup_enabled = True  # set to false to disable mixing (intended to be set by train loop)
 
@@ -173,6 +173,26 @@ class Mixup:
             x[i] = x[i] * lam + x_orig[j] * (1 - lam)
         return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
 
+    def _mix_pair(self, x):
+        batch_size = len(x)
+        lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
+        x_orig = x.clone()  # need to keep an unmodified original for mixing source
+        for i in range(batch_size // 2):
+            j = batch_size - i - 1
+            lam = lam_batch[i]
+            if lam != 1.:
+                if use_cutmix[i]:
+                    (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
+                        x[i].shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
+                    x[i][:, yl:yh, xl:xh] = x_orig[j][:, yl:yh, xl:xh]
+                    x[j][:, yl:yh, xl:xh] = x_orig[i][:, yl:yh, xl:xh]
+                    lam_batch[i] = lam
+                else:
+                    x[i] = x[i] * lam + x_orig[j] * (1 - lam)
+                    x[j] = x[j] * lam + x_orig[i] * (1 - lam)
+        lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
+        return torch.tensor(lam_batch, device=x.device, dtype=x.dtype).unsqueeze(1)
+
     def _mix_batch(self, x):
         lam, use_cutmix = self._params_per_batch()
         if lam == 1.:
@@ -188,7 +208,12 @@ class Mixup:
 
     def __call__(self, x, target):
         assert len(x) % 2 == 0, 'Batch size should be even when using this'
-        lam = self._mix_elem(x) if self.elementwise else self._mix_batch(x)
+        if self.mode == 'elem':
+            lam = self._mix_elem(x)
+        elif self.mode == 'pair':
+            lam = self._mix_pair(x)
+        else:
+            lam = self._mix_batch(x)
         target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
         return x, target
 
@@ -199,25 +224,57 @@ class FastCollateMixup(Mixup):
     A Mixup impl that's performed while collating the batches.
""" - def _mix_elem_collate(self, output, batch): + def _mix_elem_collate(self, output, batch, half=False): batch_size = len(batch) - lam_batch, use_cutmix = self._params_per_elem(batch_size) - for i in range(batch_size): + num_elem = batch_size // 2 if half else batch_size + assert len(output) == num_elem + lam_batch, use_cutmix = self._params_per_elem(num_elem) + for i in range(num_elem): j = batch_size - i - 1 lam = lam_batch[i] mixed = batch[i][0] if lam != 1.: if use_cutmix[i]: - mixed = mixed.copy() + if not half: + mixed = mixed.copy() (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] lam_batch[i] = lam else: mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - lam_batch[i] = lam - np.round(mixed, out=mixed) + np.rint(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) + if half: + lam_batch = np.concatenate((lam_batch, np.ones(num_elem))) + return torch.tensor(lam_batch).unsqueeze(1) + + def _mix_pair_collate(self, output, batch): + batch_size = len(batch) + lam_batch, use_cutmix = self._params_per_elem(batch_size // 2) + for i in range(batch_size // 2): + j = batch_size - i - 1 + lam = lam_batch[i] + mixed_i = batch[i][0] + mixed_j = batch[j][0] + assert 0 <= lam <= 1.0 + if lam < 1.: + if use_cutmix[i]: + (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( + output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) + patch_i = mixed_i[:, yl:yh, xl:xh].copy() + mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh] + mixed_j[:, yl:yh, xl:xh] = patch_i + lam_batch[i] = lam + else: + mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam) + mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam) + mixed_i = mixed_temp + np.rint(mixed_j, out=mixed_j) + np.rint(mixed_i, out=mixed_i) + output[i] += torch.from_numpy(mixed_i.astype(np.uint8)) + output[j] += torch.from_numpy(mixed_j.astype(np.uint8)) + lam_batch = np.concatenate((lam_batch, lam_batch[::-1])) return torch.tensor(lam_batch).unsqueeze(1) def _mix_batch_collate(self, output, batch): @@ -235,19 +292,25 @@ class FastCollateMixup(Mixup): mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.round(mixed, out=mixed) + np.rint(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam def __call__(self, batch, _=None): batch_size = len(batch) assert batch_size % 2 == 0, 'Batch size should be even when using this' + half = 'half' in self.mode + if half: + batch_size //= 2 output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) - if self.elementwise: - lam = self._mix_elem_collate(output, batch) + if self.mode == 'elem' or self.mode == 'half': + lam = self._mix_elem_collate(output, batch, half=half) + elif self.mode == 'pair': + lam = self._mix_pair_collate(output, batch) else: lam = self._mix_batch_collate(output, batch) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') + target = target[:batch_size] return output, target diff --git a/train.py b/train.py index 023a51ca..3a235faf 100755 --- a/train.py +++ b/train.py @@ -176,8 +176,8 @@ parser.add_argument('--mixup-prob', type=float, default=1.0, help='Probability of performing mixup or cutmix 
 parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                     help='Probability of switching to cutmix when both mixup and cutmix enabled')
-parser.add_argument('--mixup-elem', action='store_true', default=False,
-                    help='Apply mixup/cutmix params uniquely per batch element instead of per batch.')
+parser.add_argument('--mixup-mode', type=str, default='batch',
+                    help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
@@ -444,7 +444,7 @@ def main():
     if mixup_active:
         mixup_args = dict(
             mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
-            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, elementwise=args.mixup_elem,
+            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
             label_smoothing=args.smoothing, num_classes=args.num_classes)
         if args.prefetcher:
             assert not num_aug_splits  # collate conflict (need to support deinterleaving in collate mixup)
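
Usage note (not part of the patch): a minimal sketch of how the new 'mode' argument is exercised from user code. The alpha values, batch shape, and class count below are illustrative only.

    import torch
    from timm.data.mixup import Mixup

    # 'batch' mixes the whole batch with a single lambda, 'pair' mixes element i
    # with element (batch_size - i - 1), 'elem' draws a separate lambda per element.
    mixup_fn = Mixup(
        mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0, switch_prob=0.5,
        mode='pair', label_smoothing=0.1, num_classes=10)

    x = torch.rand(8, 3, 224, 224)     # an even batch size is asserted in __call__
    target = torch.randint(0, 10, (8,))
    x, target = mixup_fn(x, target)    # target becomes an (8, 10) soft-label tensor

The 'half' variant is handled only by FastCollateMixup.__call__ above, where it halves the collated output batch and truncates the target accordingly; with train.py the mode is selected via the new --mixup-mode flag.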