You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
155 lines
4.6 KiB
155 lines
4.6 KiB
6 years ago
|
import torch.utils.data
|
||
5 years ago
|
from .transforms import *
|
||
|
from .distributed_sampler import OrderedDistributedSampler
|
||
|
from .mixup import FastCollateMixup
|
||
6 years ago
|
|
||
|
|
||
|
def fast_collate(batch):
    """Collate (image, label) samples into a uint8 image tensor and an int64 label tensor.

    Args:
        batch: indexable sequence of samples. ``sample[0]`` is a NumPy array
            (it is handed to ``torch.from_numpy``) and ``sample[1]`` is the
            label.

    Returns:
        A ``(images, labels)`` tuple. ``images`` is a ``torch.uint8`` tensor
        whose leading dimension is the number of samples and whose remaining
        dimensions are taken from the first sample's array; ``labels`` is a
        ``torch.int64`` tensor with one entry per sample.
    """
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
    num_samples = len(labels)

    # The output buffer is shaped from the first sample, so every sample is
    # expected to carry an array of that same shape.
    first_image = batch[0][0]
    images = torch.zeros((num_samples, *first_image.shape), dtype=torch.uint8)

    # Accumulate each array into its zero-initialised slot of the buffer.
    for index, sample in enumerate(batch):
        images[index] += torch.from_numpy(sample[0])

    return images, labels
|
||
|
|
||
|
|
||
|
class PrefetchLoader:
    """Wrap a loader so batches are copied to the GPU and normalized on a side CUDA stream.

    While the consumer works on one batch, the next batch is already being
    transferred and normalized on a separate ``torch.cuda.Stream``. Requires
    CUDA: the constructor moves the normalization tensors to the GPU and
    iteration creates a CUDA stream.

    Args:
        loader: the wrapped loader; iterating it must produce
            ``(input, target)`` pairs of tensors.
        rand_erase_prob: when greater than 0, a ``RandomErasing`` instance with
            this probability is applied to every normalized batch.
        rand_erase_mode: ``mode`` argument forwarded to ``RandomErasing``.
        mean: per-channel mean; must have three entries (reshaped to
            ``(1, 3, 1, 1)``).
        std: per-channel standard deviation; same layout as ``mean``.
    """

    def __init__(self,
                 loader,
                 rand_erase_prob=0.,
                 rand_erase_mode='const',
                 mean=IMAGENET_DEFAULT_MEAN,
                 std=IMAGENET_DEFAULT_STD):
        self.loader = loader
        # The statistics are multiplied by 255 before use -- presumably because
        # batches arrive as raw 0-255 pixel values (confirm against the
        # collate function in use).
        self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
        self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
        if rand_erase_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=rand_erase_prob, mode=rand_erase_mode)
        else:
            self.random_erasing = None

    def __iter__(self):
        """Yield ``(input, target)`` batches, preparing each one a step ahead.

        Yields nothing when the wrapped loader is empty.
        """
        stream = torch.cuda.Stream()
        # Batch prepared during the previous pass of the loop; None until the
        # first batch has been staged.
        ready = None

        for next_input, next_target in self.loader:
            # Issue the host-to-device copy and the normalization on the side
            # stream so they can overlap with work on the current stream.
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                next_input = next_input.float().sub_(self.mean).div_(self.std)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            # Hand out the previously staged batch while the new one is in flight.
            if ready is not None:
                yield ready

            # Make the current stream wait for the side stream before the
            # freshly staged batch is handed to the consumer.
            torch.cuda.current_stream().wait_stream(stream)
            ready = (next_input, next_target)

        # Flush the last staged batch. Guarded so that an empty loader simply
        # ends the iteration instead of raising UnboundLocalError.
        if ready is not None:
            yield ready

    def __len__(self):
        """Return the length of the wrapped loader."""
        return len(self.loader)

    @property
    def sampler(self):
        """The wrapped loader's sampler."""
        return self.loader.sampler

    @property
    def dataset(self):
        """The wrapped loader's dataset."""
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        """Whether mixup is enabled on the wrapped loader's collate function.

        Always ``False`` when the collate function is not a ``FastCollateMixup``.
        """
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        # Silently ignored unless the collate function is a FastCollateMixup.
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
|
||
|
|
||
6 years ago
|
|
||
|
def create_loader(
        dataset,
        input_size,
        batch_size,
        is_training=False,
        use_prefetcher=True,
        rand_erase_prob=0.,
        rand_erase_mode='const',
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        num_workers=1,
        distributed=False,
        crop_pct=None,
        collate_fn=None,
        tf_preprocessing=False,
):
    """Attach a transform to ``dataset`` and build a data loader for it.

    Note that ``dataset.transform`` is overwritten as a side effect.

    Args:
        dataset: dataset to load from; its ``transform`` attribute is replaced.
        input_size: image size; when a tuple, only its last two entries are used.
        batch_size: batch size passed to the ``DataLoader``.
        is_training: selects the training transform, enables shuffling (when no
            sampler is used) and ``drop_last``, and allows random erasing.
        use_prefetcher: wrap the loader in ``PrefetchLoader`` and default the
            collate function to ``fast_collate``.
        rand_erase_prob: random-erasing probability handed to ``PrefetchLoader``;
            forced to 0 when ``is_training`` is false.
        rand_erase_mode: random-erasing mode handed to ``PrefetchLoader``.
        interpolation: forwarded to the ImageNet transform factories.
        mean: forwarded to the transform factories and to ``PrefetchLoader``.
        std: forwarded to the transform factories and to ``PrefetchLoader``.
        num_workers: worker count passed to the ``DataLoader``.
        distributed: use a distributed sampler instead of plain shuffling.
        crop_pct: forwarded to the evaluation transform only.
        collate_fn: custom collate function; chosen automatically when ``None``.
        tf_preprocessing: use ``TfPreprocessTransform``; only takes effect when
            ``use_prefetcher`` is also true.

    Returns:
        A ``torch.utils.data.DataLoader``, wrapped in a ``PrefetchLoader`` when
        ``use_prefetcher`` is true.
    """
    img_size = input_size[-2:] if isinstance(input_size, tuple) else input_size

    # Pick the per-sample transform.
    if tf_preprocessing and use_prefetcher:
        from timm.data.tf_preprocessing import TfPreprocessTransform
        transform = TfPreprocessTransform(is_training=is_training, size=img_size)
    elif is_training:
        transform = transforms_imagenet_train(
            img_size,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            mean=mean,
            std=std)
    else:
        transform = transforms_imagenet_eval(
            img_size,
            interpolation=interpolation,
            use_prefetcher=use_prefetcher,
            mean=mean,
            std=std,
            crop_pct=crop_pct)
    dataset.transform = transform

    # Pick the sampler (only used for distributed runs).
    if not distributed:
        sampler = None
    elif is_training:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    else:
        # This will add extra duplicate entries to result in equal num
        # of samples per-process, will slightly alter validation results
        sampler = OrderedDistributedSampler(dataset)

    if collate_fn is None:
        if use_prefetcher:
            collate_fn = fast_collate
        else:
            collate_fn = torch.utils.data.dataloader.default_collate

    loader_kwargs = dict(
        batch_size=batch_size,
        # Shuffling is only requested when no sampler is supplied.
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=is_training,
    )
    loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    if not use_prefetcher:
        return loader

    # Random erasing is a training-only augmentation.
    erase_prob = rand_erase_prob if is_training else 0.
    return PrefetchLoader(
        loader,
        rand_erase_prob=erase_prob,
        rand_erase_mode=rand_erase_mode,
        mean=mean,
        std=std)
|