From 71afec86d372f64bbbad065826007a3014bf13d6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 13 Apr 2019 14:52:38 -0700 Subject: [PATCH] Loader tweaks --- data/loader.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/data/loader.py b/data/loader.py index 32fa9595..af900c86 100644 --- a/data/loader.py +++ b/data/loader.py @@ -1,5 +1,5 @@ import torch -import torch.utils.data as tdata +import torch.utils.data from data.random_erasing import RandomErasingTorch from data.transforms import * @@ -105,15 +105,16 @@ def create_loader( # FIXME note, doing this for validation isn't technically correct # There currently is no fixed order distributed sampler that corrects # for padded entries - sampler = tdata.distributed.DistributedSampler(dataset) + sampler = torch.utils.data.distributed.DistributedSampler(dataset) - loader = tdata.DataLoader( + loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, shuffle=sampler is None and is_training, num_workers=num_workers, sampler=sampler, - collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate, + collate_fn=fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate, + drop_last=is_training, ) if use_prefetcher: loader = PrefetchLoader(