From a7f570c9b72369fd75e15733d5645d09039d5f9e Mon Sep 17 00:00:00 2001 From: "AFLALO, Jonathan Isaac" Date: Tue, 5 May 2020 14:32:10 +0300 Subject: [PATCH 1/2] added MultiEpochsDataLoader --- timm/data/loader.py | 40 +++++++++++++++++++++++++++++++++++++++- train.py | 3 +++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/timm/data/loader.py b/timm/data/loader.py index f3faf7b9..4dc0d697 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -140,6 +140,7 @@ def create_loader( pin_memory=False, fp16=False, tf_preprocessing=False, + use_multi_epochs_loader=False ): re_num_splits = 0 if re_split: @@ -175,7 +176,12 @@ def create_loader( if collate_fn is None: collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate - loader = torch.utils.data.DataLoader( + loader_class = torch.utils.data.DataLoader + + if use_multi_epochs_loader: + loader_class = MultiEpochsDataLoader + + loader = loader_class( dataset, batch_size=batch_size, shuffle=sampler is None and is_training, @@ -198,3 +204,35 @@ def create_loader( ) return loader + + +class MultiEpochsDataLoader(torch.utils.data.DataLoader): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._DataLoader__initialized = False + self.batch_sampler = _RepeatSampler(self.batch_sampler) + self._DataLoader__initialized = True + self.iterator = super().__iter__() + + def __len__(self): + return len(self.batch_sampler.sampler) + + def __iter__(self): + for i in range(len(self)): + yield next(self.iterator) + + +class _RepeatSampler(object): + """ Sampler that repeats forever. + + Args: + sampler (Sampler) + """ + + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + while True: + yield from iter(self.sampler) diff --git a/train.py b/train.py index e88640a7..899c6984 100755 --- a/train.py +++ b/train.py @@ -198,6 +198,8 @@ parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_MET parser.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) +parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, + help='use the multi-epochs-loader to save time at the beginning of every epoch') def _parse_args(): @@ -391,6 +393,7 @@ def main(): distributed=args.distributed, collate_fn=collate_fn, pin_memory=args.pin_mem, + use_multi_epochs_loader=args.use_multi_epochs_loader ) eval_dir = os.path.join(args.data, 'val') From f0eb021620d2f6025226f0198e26dfef59a1ddaf Mon Sep 17 00:00:00 2001 From: Vyacheslav Shults Date: Tue, 5 May 2020 21:09:35 +0300 Subject: [PATCH 2/2] Replace all None by nn.Identity() in HRNet modules --- timm/models/hrnet.py | 54 ++++++++++++++++++++------------------------ 1 file changed, 25 insertions(+), 29 deletions(-) diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 16df5bc1..06327c65 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -13,20 +13,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import logging -import functools -import torch import torch.nn as nn -import torch._utils import torch.nn.functional as F -from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE -from .registry import register_model +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import SelectAdaptivePool2d -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .registry import register_model +from .resnet import BasicBlock, Bottleneck # leveraging ResNet blocks w/ additional features like SE _BN_MOMENTUM = 0.1 logger = logging.getLogger(__name__) @@ -101,7 +97,7 @@ cfg_cls = dict( ), ), - hrnet_w18_small_v2 = dict( + hrnet_w18_small_v2=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -137,7 +133,7 @@ cfg_cls = dict( ), ), - hrnet_w18 = dict( + hrnet_w18=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -173,7 +169,7 @@ cfg_cls = dict( ), ), - hrnet_w30 = dict( + hrnet_w30=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -209,7 +205,7 @@ cfg_cls = dict( ), ), - hrnet_w32 = dict( + hrnet_w32=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -245,7 +241,7 @@ cfg_cls = dict( ), ), - hrnet_w40 = dict( + hrnet_w40=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -281,7 +277,7 @@ cfg_cls = dict( ), ), - hrnet_w44 = dict( + hrnet_w44=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -317,7 +313,7 @@ cfg_cls = dict( ), ), - hrnet_w48 = dict( + hrnet_w48=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -353,7 +349,7 @@ cfg_cls = dict( ), ), - hrnet_w64 = dict( + hrnet_w64=dict( STEM_WIDTH=64, STAGE1=dict( NUM_MODULES=1, @@ -456,7 +452,7 @@ class HighResolutionModule(nn.Module): def _make_fuse_layers(self): if self.num_branches == 1: - return None + return nn.Identity() num_branches = self.num_branches num_inchannels = self.num_inchannels @@ -470,7 +466,7 @@ class HighResolutionModule(nn.Module): nn.BatchNorm2d(num_inchannels[i], momentum=_BN_MOMENTUM), nn.Upsample(scale_factor=2 ** (j - i), mode='nearest'))) elif j == i: - fuse_layer.append(None) + fuse_layer.append(nn.Identity()) else: conv3x3s = [] for k in range(i - j): @@ -619,7 +615,7 @@ class HighResolutionNet(nn.Module): nn.BatchNorm2d(num_channels_cur_layer[i], momentum=_BN_MOMENTUM), nn.ReLU(inplace=True))) else: - transition_layers.append(None) + transition_layers.append(nn.Identity()) else: conv3x3s = [] for j in range(i + 1 - num_branches_pre): @@ -686,8 +682,11 @@ class HighResolutionNet(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.classifier = nn.Linear( - self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None + num_features = self.num_features * self.global_pool.feat_mult() + if num_classes: + self.classifier = nn.Linear(num_features, num_classes) + else: + self.classifier = nn.Identity() def forward_features(self, x): x = self.conv1(x) @@ -699,24 +698,21 @@ class HighResolutionNet(nn.Module): x = self.layer1(x) x_list = [] - for i in range(self.stage2_cfg['NUM_BRANCHES']): - if self.transition1[i] is not None: - x_list.append(self.transition1[i](x)) - else: - x_list.append(x) + for i in range(len(self.transition1)): + x_list.append(self.transition1[i](x)) y_list = self.stage2(x_list) x_list = [] - for i in range(self.stage3_cfg['NUM_BRANCHES']): - if self.transition2[i] is not None: + for i in range(len(self.transition2)): + if not isinstance(self.transition2[i], nn.Identity): x_list.append(self.transition2[i](y_list[-1])) else: x_list.append(y_list[i]) y_list = self.stage3(x_list) x_list = [] - for i in range(self.stage4_cfg['NUM_BRANCHES']): - if self.transition3[i] is not None: + for i in range(len(self.transition3)): + if not isinstance(self.transition3[i], nn.Identity): x_list.append(self.transition3[i](y_list[-1])) else: x_list.append(y_list[i])