Loader tweaks

Ross Wightman, 5 years ago (pull/1/head)
parent 79f615639e
commit 71afec86d3

@@ -1,5 +1,5 @@
 import torch
-import torch.utils.data as tdata
+import torch.utils.data
 from data.random_erasing import RandomErasingTorch
 from data.transforms import *
@@ -105,15 +105,16 @@ def create_loader(
     # FIXME note, doing this for validation isn't technically correct
     # There currently is no fixed order distributed sampler that corrects
     # for padded entries
-    sampler = tdata.distributed.DistributedSampler(dataset)
+    sampler = torch.utils.data.distributed.DistributedSampler(dataset)
-    loader = tdata.DataLoader(
+    loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
-        collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
+        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
         drop_last=is_training,
     )
     if use_prefetcher:
         loader = PrefetchLoader(
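
For reference, a minimal self-contained sketch of the loader construction this hunk modifies. build_loader and the simplified fast_collate below are illustrative stand-ins, not the repo's actual create_loader / fast_collate (which take more parameters), and the PrefetchLoader wrapping is omitted:

import torch
import torch.utils.data


def fast_collate(batch):
    # Simplified stand-in for the repo's fast_collate: stack per-sample
    # image tensors and targets without any dtype conversion.
    targets = torch.tensor([target for _, target in batch], dtype=torch.int64)
    images = torch.stack([img for img, _ in batch], dim=0)
    return images, targets


def build_loader(dataset, batch_size, is_training=False, distributed=False,
                 num_workers=4, use_prefetcher=False):
    sampler = None
    if distributed:
        # As the FIXME in the diff notes, DistributedSampler pads the dataset
        # so all ranks see equal-sized epochs, which slightly skews validation.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # shuffle and a sampler are mutually exclusive in DataLoader,
        # hence shuffle only when no sampler is set and we are training
        shuffle=sampler is None and is_training,
        num_workers=num_workers,
        sampler=sampler,
        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
        drop_last=is_training,
    )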
