@ -214,7 +214,7 @@ class Mixup:
lam = self._mix_pair(x)
else:
lam = self._mix_batch(x)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device)
return x, target