""" Mixup Paper: `mixup: Beyond Empirical Risk Minimization` - https://arxiv.org/abs/1710.09412 Hacked together by / Copyright 2020 Ross Wightman """ import numpy as np import torch def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'): x = x.long().view(-1, 1) return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value) def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'): off_value = smoothing / num_classes on_value = 1. - smoothing + off_value y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device) y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device) return lam*y1 + (1. - lam)*y2 def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False): lam = 1. if not disable: lam = np.random.beta(alpha, alpha) input = input.mul(lam).add_(1 - lam, input.flip(0)) target = mixup_target(target, num_classes, lam, smoothing) return input, target class FastCollateMixup: def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000): self.mixup_alpha = mixup_alpha self.label_smoothing = label_smoothing self.num_classes = num_classes self.mixup_enabled = True def __call__(self, batch): batch_size = len(batch) lam = 1. if self.mixup_enabled: lam = np.random.beta(self.mixup_alpha, self.mixup_alpha) target = torch.tensor([b[1] for b in batch], dtype=torch.int64) target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu') tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) for i in range(batch_size): mixed = batch[i][0].astype(np.float32) * lam + \ batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam) np.round(mixed, out=mixed) tensor[i] += torch.from_numpy(mixed.astype(np.uint8)) return tensor, target