diff --git a/timm/data/mixup.py b/timm/data/mixup.py index 38477548..7e382c52 100644 --- a/timm/data/mixup.py +++ b/timm/data/mixup.py @@ -214,7 +214,7 @@ class Mixup: lam = self._mix_pair(x) else: lam = self._mix_batch(x) - target = mixup_target(target, self.num_classes, lam, self.label_smoothing) + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, x.device) return x, target