import torch.cuda

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .mixup import FastCollateMixup
from .random_erasing import RandomErasing


class PrefetcherCuda:
    """Wrap a data loader and prepare each batch on the GPU one step ahead.

    While the consumer works on the batch that was just yielded, the next
    batch is copied to the CUDA device, cast to float (or half), normalized
    with ``mean``/``std`` and optionally passed through ``RandomErasing`` on a
    separate CUDA stream.

    Args:
        loader: Iterable yielding ``(input, target)`` pairs of tensors. It is
            also expected to expose ``sampler``, ``dataset`` and
            ``collate_fn`` attributes, which the properties below forward to.
        mean: Per-channel mean. Each value is multiplied by 255 before use
            (presumably because inputs arrive in the 0-255 range -- confirm
            against the collate function in use).
        std: Per-channel standard deviation, scaled by 255 like ``mean``.
        fp16: If true, inputs and the normalization constants are converted
            to half precision instead of float32.
        re_prob: When greater than 0, a ``RandomErasing`` instance is built
            with this probability and applied to every input batch.
        re_mode: Forwarded to ``RandomErasing`` as ``mode``.
        re_count: Forwarded to ``RandomErasing`` as ``max_count``.
        re_num_splits: Forwarded to ``RandomErasing`` as ``num_splits``.
    """

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        # Shape the statistics as (1, C, 1, 1) so they broadcast over
        # (N, C, H, W) batches. ``-1`` infers C from len(mean)/len(std), so
        # the default 3-channel case is unchanged while other channel counts
        # are accepted as well.
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, -1, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, -1, 1, 1)
        self.fp16 = fp16
        if fp16:
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob,
                mode=re_mode,
                max_count=re_count,
                num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        """Yield ``(input, target)`` batches that already live on the GPU.

        Each batch is yielded one loader step late: batch ``k`` is handed out
        only after batch ``k + 1`` has been queued on the side stream. An
        empty loader yields nothing.
        """
        stream = torch.cuda.Stream()
        # The most recently prepared batch that has not been yielded yet.
        # ``None`` until the first batch is staged; this also makes an empty
        # loader terminate cleanly instead of hitting unbound locals.
        ready_batch = None

        for next_input, next_target in self.loader:
            # Queue the host-to-device copy and preprocessing on the side
            # stream rather than on the current stream.
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if ready_batch is not None:
                yield ready_batch

            # Make the current stream wait for the side stream's queued work
            # before the freshly prepared batch is exposed to the consumer.
            torch.cuda.current_stream().wait_stream(stream)
            ready_batch = (next_input, next_target)

        # Flush the final staged batch, if the loader produced any.
        if ready_batch is not None:
            yield ready_batch

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.loader)

    @property
    def sampler(self):
        """The wrapped loader's ``sampler``."""
        return self.loader.sampler

    @property
    def dataset(self):
        """The wrapped loader's ``dataset``."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Whether mixup is enabled on the loader's collate function.

        Always ``False`` when the collate function is not a
        ``FastCollateMixup``.
        """
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Deliberately a no-op when the collate function is not a
        # FastCollateMixup: there is nothing to toggle in that case.
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x