From cd23f553974f0cda6b07761ba2d4a88e82236966 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 11 Aug 2020 12:17:43 -0700 Subject: [PATCH] Fix mixed prec issues with new mixup code --- timm/data/mixup.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/timm/data/mixup.py b/timm/data/mixup.py index de19c616..a018ea07 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) if correct_lam or ratio_minmax is not None: bbox_area = (yu - yl) * (xu - xl) - lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1]) + lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1]) return (yl, yu, xl, xu), lam @@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa yl, yh, xl, xh = rand_bbox(input.size(), lam) input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] if correct_lam: - lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) + lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1]) target = mixup_target(target, num_classes, lam, smoothing) return input, target @@ -139,7 +139,7 @@ class FastCollateMixup: def _mix_elem(self, output, batch): batch_size = len(batch) - lam_out = np.ones(batch_size) + lam_out = np.ones(batch_size, dtype=np.float32) use_cutmix = np.zeros(batch_size).astype(np.bool) if self.mixup_enabled: if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: @@ -155,22 +155,23 @@ class FastCollateMixup: lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out) + lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out) for i in range(batch_size): j = batch_size - i - 1 lam = lam_out[i] - mixed = batch[i][0].astype(np.float32) + mixed = batch[i][0] if lam != 1.: if use_cutmix[i]: + mixed = mixed.copy() (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] lam_out[i] = lam else: - mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) lam_out[i] = lam - np.round(mixed, out=mixed) + np.round(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) return torch.tensor(lam_out).unsqueeze(1) @@ -190,7 +191,7 @@ class FastCollateMixup: lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) else: assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." - lam = lam_mix + lam = float(lam_mix) if use_cutmix: (yl, yh, xl, xh), lam = cutmix_bbox_and_lam( @@ -198,13 +199,14 @@ class FastCollateMixup: for i in range(batch_size): j = batch_size - i - 1 - mixed = batch[i][0].astype(np.float32) + mixed = batch[i][0] if lam != 1.: if use_cutmix: - mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) + mixed = mixed.copy() + mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh] else: - mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) - np.round(mixed, out=mixed) + mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam) + np.round(mixed, out=mixed) output[i] += torch.from_numpy(mixed.astype(np.uint8)) return lam