diff --git a/timm/data/prefetcher_cuda.py b/timm/data/prefetcher_cuda.py index 0b36c027..3ae52dc2 100644 --- a/timm/data/prefetcher_cuda.py +++ b/timm/data/prefetcher_cuda.py @@ -34,7 +34,7 @@ class PrefetcherCuda: self.std = None if re_prob > 0.: self.random_erasing = RandomErasing( - probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits, device=device) + probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits) else: self.random_erasing = None