Loader tweaks

commit 71afec86d3 (pull/1/head)
parent 79f615639e
Ross Wightman, 6 years ago

@@ -1,5 +1,5 @@
 import torch
-import torch.utils.data as tdata
+import torch.utils.data
 from data.random_erasing import RandomErasingTorch
 from data.transforms import *
@@ -105,15 +105,16 @@ def create_loader(
         # FIXME note, doing this for validation isn't technically correct
         # There currently is no fixed order distributed sampler that corrects
         # for padded entries
-        sampler = tdata.distributed.DistributedSampler(dataset)
-    loader = tdata.DataLoader(
+        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+    loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
         num_workers=num_workers,
         sampler=sampler,
-        collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
+        collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate,
+        drop_last=is_training,
     )
     if use_prefetcher:
         loader = PrefetchLoader(
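What the tweak amounts to: the `tdata` alias is dropped in favor of explicit `torch.utils.data` paths, and `drop_last=is_training` is added so training discards a short final batch while validation still sees every sample. Below is a minimal, self-contained sketch of the resulting loader behavior (the `ToyDataset` and `make_loader` names are illustrative, not from this repo; the prefetcher and collate details are simplified to the standard PyTorch API):

import torch
import torch.utils.data


class ToyDataset(torch.utils.data.Dataset):
    # Stand-in dataset: 10 samples of (tensor, label).
    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.randn(3, 8, 8), idx % 2


def make_loader(dataset, batch_size, is_training, distributed=False):
    sampler = None
    if distributed:
        # DistributedSampler pads/duplicates entries so every process gets an
        # equal share -- the reason for the FIXME above about validation.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        # shuffle and sampler are mutually exclusive in DataLoader, hence
        # the `sampler is None and is_training` guard.
        shuffle=sampler is None and is_training,
        sampler=sampler,
        # Training drops the ragged final batch; eval keeps it.
        drop_last=is_training,
    )


train_loader = make_loader(ToyDataset(), batch_size=4, is_training=True)
eval_loader = make_loader(ToyDataset(), batch_size=4, is_training=False)
print(len(train_loader), len(eval_loader))  # 2 3 -- the partial train batch is dropped

A common motivation for `drop_last=is_training`: a short final batch can skew batch-norm statistics and trip up code (such as a fixed-size CUDA prefetch pipeline) that assumes every training batch has the same shape, while dropping samples at eval time would silently change the reported metrics.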
