diff --git a/timm/data/loader.py b/timm/data/loader.py
index 6be8a3c2..8b950ce6 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -140,6 +140,7 @@ def create_loader(
         pin_memory=False,
         fp16=False,
         tf_preprocessing=False,
+        use_multi_epochs_loader=False
 ):
     re_num_splits = 0
     if re_split:
@@ -175,7 +176,12 @@ def create_loader(
     if collate_fn is None:
         collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
 
-    loader = torch.utils.data.DataLoader(
+    loader_class = torch.utils.data.DataLoader
+
+    if use_multi_epochs_loader:
+        loader_class = MultiEpochsDataLoader
+
+    loader = loader_class(
         dataset,
         batch_size=batch_size,
         shuffle=sampler is None and is_training,
@@ -198,3 +204,35 @@ def create_loader(
         )
 
     return loader
+
+
+class MultiEpochsDataLoader(torch.utils.data.DataLoader):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._DataLoader__initialized = False
+        self.batch_sampler = _RepeatSampler(self.batch_sampler)
+        self._DataLoader__initialized = True
+        self.iterator = super().__iter__()
+
+    def __len__(self):
+        return len(self.batch_sampler.sampler)
+
+    def __iter__(self):
+        for i in range(len(self)):
+            yield next(self.iterator)
+
+
+class _RepeatSampler(object):
+    """ Sampler that repeats forever.
+
+    Args:
+        sampler (Sampler)
+    """
+
+    def __init__(self, sampler):
+        self.sampler = sampler
+
+    def __iter__(self):
+        while True:
+            yield from iter(self.sampler)
diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py
index 4d8e75fd..f5f3c242 100644
--- a/timm/models/hrnet.py
+++ b/timm/models/hrnet.py
@@ -675,7 +675,10 @@ class HighResolutionNet(nn.Module):
         self.num_classes = num_classes
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         num_features = self.num_features * self.global_pool.feat_mult()
-        self.classifier = nn.Linear(num_features, num_classes) if num_classes else nn.Identity()
+        if num_classes:
+            self.classifier = nn.Linear(num_features, num_classes)
+        else:
+            self.classifier = nn.Identity()
 
     def forward_features(self, x):
         x = self.conv1(x)
diff --git a/train.py b/train.py
index e88640a7..899c6984 100755
--- a/train.py
+++ b/train.py
@@ -198,6 +198,8 @@ parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_MET
 parser.add_argument('--tta', type=int, default=0, metavar='N',
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)
+parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
+                    help='use the multi-epochs-loader to save time at the beginning of every epoch')
 
 
 def _parse_args():
@@ -391,6 +393,7 @@ def main():
         distributed=args.distributed,
         collate_fn=collate_fn,
         pin_memory=args.pin_mem,
+        use_multi_epochs_loader=args.use_multi_epochs_loader
     )
 
     eval_dir = os.path.join(args.data, 'val')