Fix mixed prec issues with new mixup code

pull/218/head
Ross Wightman 5 years ago
parent f471c17c9d
commit cd23f55397

@ -72,7 +72,7 @@ def cutmix_bbox_and_lam(img_shape, lam, ratio_minmax=None, correct_lam=True, cou
yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count) yl, yu, xl, xu = rand_bbox(img_shape, lam, count=count)
if correct_lam or ratio_minmax is not None: if correct_lam or ratio_minmax is not None:
bbox_area = (yu - yl) * (xu - xl) bbox_area = (yu - yl) * (xu - xl)
lam = 1. - bbox_area / (img_shape[-2] * img_shape[-1]) lam = 1. - bbox_area / float(img_shape[-2] * img_shape[-1])
return (yl, yu, xl, xu), lam return (yl, yu, xl, xu), lam
@ -84,7 +84,7 @@ def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disa
yl, yh, xl, xh = rand_bbox(input.size(), lam) yl, yh, xl, xh = rand_bbox(input.size(), lam)
input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh] input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
if correct_lam: if correct_lam:
lam = 1 - (yh - yl) * (xh - xl) / (input.shape[-2] * input.shape[-1]) lam = 1. - (yh - yl) * (xh - xl) / float(input.shape[-2] * input.shape[-1])
target = mixup_target(target, num_classes, lam, smoothing) target = mixup_target(target, num_classes, lam, smoothing)
return input, target return input, target
@ -139,7 +139,7 @@ class FastCollateMixup:
def _mix_elem(self, output, batch): def _mix_elem(self, output, batch):
batch_size = len(batch) batch_size = len(batch)
lam_out = np.ones(batch_size) lam_out = np.ones(batch_size, dtype=np.float32)
use_cutmix = np.zeros(batch_size).astype(np.bool) use_cutmix = np.zeros(batch_size).astype(np.bool)
if self.mixup_enabled: if self.mixup_enabled:
if self.mixup_alpha > 0. and self.cutmix_alpha > 0.: if self.mixup_alpha > 0. and self.cutmix_alpha > 0.:
@ -155,22 +155,23 @@ class FastCollateMixup:
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size) lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha, size=batch_size)
else: else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix, lam_out) lam_out = np.where(np.random.rand(batch_size) < self.prob, lam_mix.astype(np.float32), lam_out)
for i in range(batch_size): for i in range(batch_size):
j = batch_size - i - 1 j = batch_size - i - 1
lam = lam_out[i] lam = lam_out[i]
mixed = batch[i][0].astype(np.float32) mixed = batch[i][0]
if lam != 1.: if lam != 1.:
if use_cutmix[i]: if use_cutmix[i]:
mixed = mixed.copy()
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam( (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam) output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_out[i] = lam lam_out[i] = lam
else: else:
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
lam_out[i] = lam lam_out[i] = lam
np.round(mixed, out=mixed) np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8)) output[i] += torch.from_numpy(mixed.astype(np.uint8))
return torch.tensor(lam_out).unsqueeze(1) return torch.tensor(lam_out).unsqueeze(1)
@ -190,7 +191,7 @@ class FastCollateMixup:
lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha) lam_mix = np.random.beta(self.cutmix_alpha, self.cutmix_alpha)
else: else:
assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true." assert False, "One of mixup_alpha > 0., cutmix_alpha > 0., cutmix_minmax not None should be true."
lam = lam_mix lam = float(lam_mix)
if use_cutmix: if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam( (yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
@ -198,13 +199,14 @@ class FastCollateMixup:
for i in range(batch_size): for i in range(batch_size):
j = batch_size - i - 1 j = batch_size - i - 1
mixed = batch[i][0].astype(np.float32) mixed = batch[i][0]
if lam != 1.: if lam != 1.:
if use_cutmix: if use_cutmix:
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32) mixed = mixed.copy()
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else: else:
mixed = mixed * lam + batch[j][0].astype(np.float32) * (1 - lam) mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.round(mixed, out=mixed) np.round(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8)) output[i] += torch.from_numpy(mixed.astype(np.uint8))
return lam return lam

Loading…
Cancel
Save