diff --git a/data/loader.py b/data/loader.py index 71ff3c9f..97c57af0 100644 --- a/data/loader.py +++ b/data/loader.py @@ -34,6 +34,7 @@ class PrefetchLoader: def __iter__(self): stream = torch.cuda.Stream() first = True + curr_input = None for next_input, next_target in self.loader: with torch.cuda.stream(stream): @@ -44,15 +45,15 @@ class PrefetchLoader: next_input = self.random_erasing(next_input) if not first: - yield input, target + yield curr_input, target else: first = False torch.cuda.current_stream().wait_stream(stream) - input = next_input + curr_input = next_input target = next_target - yield input, target + yield curr_input, target def __len__(self): return len(self.loader)