|
|
|
@ -1,5 +1,5 @@
|
|
|
|
|
import torch
|
|
|
|
|
import torch.utils.data as tdata
|
|
|
|
|
import torch.utils.data
|
|
|
|
|
from data.random_erasing import RandomErasingTorch
|
|
|
|
|
from data.transforms import *
|
|
|
|
|
|
|
|
|
@ -105,15 +105,16 @@ def create_loader(
|
|
|
|
|
# FIXME note, doing this for validation isn't technically correct
|
|
|
|
|
# There currently is no fixed order distributed sampler that corrects
|
|
|
|
|
# for padded entries
|
|
|
|
|
sampler = tdata.distributed.DistributedSampler(dataset)
|
|
|
|
|
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
|
|
|
|
|
|
|
|
|
|
loader = tdata.DataLoader(
|
|
|
|
|
loader = torch.utils.data.DataLoader(
|
|
|
|
|
dataset,
|
|
|
|
|
batch_size=batch_size,
|
|
|
|
|
shuffle=sampler is None and is_training,
|
|
|
|
|
num_workers=num_workers,
|
|
|
|
|
sampler=sampler,
|
|
|
|
|
collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
|
|
|
|
|
collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
|
|
|
|
|
drop_last=is_training,
|
|
|
|
|
)
|
|
|
|
|
if use_prefetcher:
|
|
|
|
|
loader = PrefetchLoader(
|
|
|
|
|