You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.5 KiB
79 lines
2.5 KiB
import torch.cuda
|
|
|
|
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
from .mixup import FastCollateMixup
|
|
from .random_erasing import RandomErasing
|
|
|
|
|
|
class PrefetcherCuda:
    """Wrap a loader so batches are moved to the GPU one step ahead on a side CUDA stream.

    While the caller works on batch ``k``, batch ``k + 1`` is copied to the
    device, normalized with ``mean``/``std`` and (optionally) random-erased
    inside a separate ``torch.cuda.Stream``. This lets the host-to-device
    transfer overlap with work queued on the current stream.

    Args:
        loader: Iterable yielding ``(input, target)`` tensor pairs. The
            ``sampler``, ``dataset`` and ``mixup_enabled`` properties forward to
            ``loader.sampler``, ``loader.dataset`` and ``loader.collate_fn``.
        mean: Per-channel mean, multiplied by 255 before use. NOTE(review):
            this implies inputs arrive on a 0-255 scale; confirm against the
            collate fn. Any number of channels is accepted; the image tensor
            is presumably NCHW with ``C == len(mean)``.
        std: Per-channel std, multiplied by 255 before use (same caveats as
            ``mean``).
        fp16: If True, inputs are converted to half precision before
            normalization and yielded as half; otherwise they are float32.
        re_prob: Random-erasing probability. Erasing is disabled unless > 0.
        re_mode: Passed to ``RandomErasing`` as ``mode``.
        re_count: Passed to ``RandomErasing`` as ``max_count``.
        re_num_splits: Passed to ``RandomErasing`` as ``num_splits``.
    """

    def __init__(self,
                 loader,
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD,
                 fp16=False,
                 re_prob=0.,
                 re_mode='const',
                 re_count=1,
                 re_num_splits=0):
        self.loader = loader
        # Shape (1, C, 1, 1) so the stats broadcast over batch and spatial dims.
        # `-1` infers C from len(mean), instead of hard-coding 3 channels.
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, -1, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, -1, 1, 1)
        self.fp16 = fp16
        if fp16:
            # Match the input dtype so in-place sub_/div_ need no type promotion.
            self.mean = self.mean.half()
            self.std = self.std.half()
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        """Yield ``(input, target)`` CUDA tensors, prefetching one batch ahead.

        An empty underlying loader yields nothing.
        """
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            # Queue the copy and preprocessing of the *next* batch on the side
            # stream so it can overlap with the consumer's work.
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                if self.fp16:
                    next_input = next_input.half().sub_(self.mean).div_(self.std)
                else:
                    next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield batch_input, batch_target
            else:
                first = False

            # Work later queued on the current stream must wait for the side
            # stream's copy/preprocessing before the prefetched batch is used.
            torch.cuda.current_stream().wait_stream(stream)
            batch_input = next_input
            batch_target = next_target

        # Flush the final prefetched batch. Without this guard, an empty loader
        # would reach this point with the locals unassigned (UnboundLocalError).
        if not first:
            yield batch_input, batch_target

    def __len__(self):
        """Return the length of the wrapped loader."""
        return len(self.loader)

    @property
    def sampler(self):
        """The wrapped loader's ``sampler`` attribute."""
        return self.loader.sampler

    @property
    def dataset(self):
        """The wrapped loader's ``dataset`` attribute."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Whether mixup is enabled on the loader's ``FastCollateMixup`` collate fn.

        Returns False when the collate fn is not a ``FastCollateMixup``.
        """
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Deliberately a no-op when the collate fn is not a FastCollateMixup.
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x