From b4b8d1ec1860124970a4ec480684b96f66951da6 Mon Sep 17 00:00:00 2001
From: KAI ZHAO
Date: Tue, 14 Dec 2021 17:22:54 +0800
Subject: [PATCH 1/6] fix hard-coded strides

---
 timm/models/visformer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 6e832cd0..37284c9d 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -167,14 +167,14 @@ class Visformer(nn.Module):
             self.patch_embed1 = PatchEmbed(
                 img_size=img_size, patch_size=patch_size, in_chans=in_chans,
                 embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
-            img_size = [x // 16 for x in img_size]
+            img_size = [x // patch_size for x in img_size]
         else:
             if self.init_channels is None:
                 self.stem = None
                 self.patch_embed1 = PatchEmbed(
                     img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
                     embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
-                img_size = [x // 8 for x in img_size]
+                img_size = [x // (patch_size // 2) for x in img_size]
             else:
                 self.stem = nn.Sequential(
                     nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
@@ -185,7 +185,7 @@ class Visformer(nn.Module):
                 self.patch_embed1 = PatchEmbed(
                     img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
                     embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
-                img_size = [x // 4 for x in img_size]
+                img_size = [x // (patch_size // 4) for x in img_size]

         if self.pos_embed:
             if self.vit_stem:
@@ -207,7 +207,7 @@ class Visformer(nn.Module):
         self.patch_embed2 = PatchEmbed(
             img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
             embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
-        img_size = [x // 2 for x in img_size]
+        img_size = [x // (patch_size // 8) for x in img_size]
         if self.pos_embed:
             self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
         self.stage2 = nn.ModuleList([
@@ -224,7 +224,7 @@ class Visformer(nn.Module):
         self.patch_embed3 = PatchEmbed(
             img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
             embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
-        img_size = [x // 2 for x in img_size]
+        img_size = [x // (patch_size // 8) for x in img_size]
         if self.pos_embed:
             self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
         self.stage3 = nn.ModuleList([

From 31bcd36e4658c7a418797058ce5fcc57da0d87fc Mon Sep 17 00:00:00 2001
From: Rahul Somani
Date: Tue, 14 Dec 2021 19:34:04 +0530
Subject: [PATCH 2/6] add tinynet models

---
 timm/models/efficientnet.py | 94 +++++++++++++++++++++++++++++++++++++
 1 file changed, 94 insertions(+)

diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index 3d50b704..ec7e17c7 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -23,6 +23,10 @@ An implementation of EfficienNet that covers variety of related models with effi
 * Single-Path NAS Pixel1
   - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877

+* TinyNet
+  - Model Rubik's Cube: Twisting Resolution, Depth and Width for TinyNets - https://arxiv.org/abs/2010.14819
+  - Definitions & weights borrowed from https://github.com/huawei-noah/CV-Backbones/tree/master/tinynet_pytorch
+
 * And likely more...

 The majority of the above models (EfficientNet*, MixNet, MnasNet) and original weights were made available
@@ -407,6 +411,22 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'),
     'tf_mixnet_l': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
+
+    "tinynet_a": _cfg(
+        input_size=(3, 192, 192), # int(224 * 0.86)
+        url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth'),
+    "tinynet_b": _cfg(
+        input_size=(3, 188, 188), # int(224 * 0.84)
+        url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth'),
+    "tinynet_c": _cfg(
+        input_size=(3, 184, 184), # int(224 * 0.825)
+        url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth'),
+    "tinynet_d": _cfg(
+        input_size=(3, 152, 152), # int(224 * 0.68)
+        url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth'),
+    "tinynet_e": _cfg(
+        input_size=(3, 106, 106), # int(224 * 0.475)
+        url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth'),
 }


@@ -1140,6 +1160,50 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
     return model


+def _gen_tinynet(
+    variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs
+):
+    """Creates a TinyNet model.
+    """
+    arch_def = [
+        ['ds_r1_k3_s1_e1_c16_se0.25'], ['ir_r2_k3_s2_e6_c24_se0.25'],
+        ['ir_r2_k5_s2_e6_c40_se0.25'], ['ir_r3_k3_s2_e6_c80_se0.25'],
+        ['ir_r3_k5_s1_e6_c112_se0.25'], ['ir_r4_k5_s2_e6_c192_se0.25'],
+        ['ir_r1_k3_s1_e6_c320_se0.25'],
+    ]
+    model_kwargs = dict(
+        block_args = decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+        num_features = max(1280, round_channels(1280, model_width, 8, None)),
+        stem_size = 32,
+        fix_stem = True,
+        round_chs_fn=partial(round_channels, multiplier=model_width),
+        act_layer = resolve_act_layer(kwargs, 'swish'),
+        norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
+        **kwargs,
+    )
+
+    features_only = False
+    model_cls = EfficientNet
+    kwargs_filter = None
+
+    if kwargs.pop('features_only', False):
+        features_only = True
+        # kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
+        kwargs_filter = ('num_classes', 'num_features', 'conv_head', 'global_pool')
+        model_cls = EfficientNetFeatures
+
+    model = build_model_with_cfg(
+        model_cls, variant, pretrained,
+        default_cfg=default_cfgs[variant],
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
+        **model_kwargs)
+    if features_only:
+        model.default_cfg = default_cfg_for_features(model.default_cfg)
+
+    return model
+
+
 @register_model
 def mnasnet_050(pretrained=False, **kwargs):
     """ MNASNet B1, depth multiplier of 0.5. """
@@ -2209,3 +2273,33 @@ def tf_mixnet_l(pretrained=False, **kwargs):
     model = _gen_mixnet_m(
         'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs)
     return model
+
+
+@register_model
+def tinynet_a(pretrained=False, **kwargs):
+    model = _gen_tinynet('tinynet_a', 1.0, 1.2, **kwargs)
+    return model
+
+
+@register_model
+def tinynet_b(pretrained=False, **kwargs):
+    model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tinynet_c(pretrained=False, **kwargs):
+    model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tinynet_d(pretrained=False, **kwargs):
+    model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained, **kwargs)
+    return model
+
+
+@register_model
+def tinynet_e(pretrained=False, **kwargs):
+    model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained, **kwargs)
+    return model

From 450ac6a0f5bfd2db4d30355c99a23b91d9ab63e4 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 21 Dec 2021 23:51:54 -0800
Subject: [PATCH 3/6] Post merge tinynet fixes for pool_size, feature
 extraction

---
 timm/models/efficientnet.py | 51 ++++++++++++-------------------------
 1 file changed, 16 insertions(+), 35 deletions(-)

diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py
index ec7e17c7..b1c570b2 100644
--- a/timm/models/efficientnet.py
+++ b/timm/models/efficientnet.py
@@ -413,19 +413,19 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),

     "tinynet_a": _cfg(
-        input_size=(3, 192, 192), # int(224 * 0.86)
+        input_size=(3, 192, 192), pool_size=(6, 6), # int(224 * 0.86)
         url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_a.pth'),
     "tinynet_b": _cfg(
-        input_size=(3, 188, 188), # int(224 * 0.84)
+        input_size=(3, 188, 188), pool_size=(6, 6), # int(224 * 0.84)
         url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_b.pth'),
     "tinynet_c": _cfg(
-        input_size=(3, 184, 184), # int(224 * 0.825)
+        input_size=(3, 184, 184), pool_size=(6, 6), # int(224 * 0.825)
         url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_c.pth'),
     "tinynet_d": _cfg(
-        input_size=(3, 152, 152), # int(224 * 0.68)
+        input_size=(3, 152, 152), pool_size=(5, 5), # int(224 * 0.68)
         url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_d.pth'),
     "tinynet_e": _cfg(
-        input_size=(3, 106, 106), # int(224 * 0.475)
+        input_size=(3, 106, 106), pool_size=(4, 4), # int(224 * 0.475)
         url='https://github.com/huawei-noah/CV-Backbones/releases/download/v1.2.0/tinynet_e.pth'),
 }

@@ -1172,35 +1172,16 @@ def _gen_tinynet(
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
     model_kwargs = dict(
-        block_args = decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
-        num_features = max(1280, round_channels(1280, model_width, 8, None)),
-        stem_size = 32,
-        fix_stem = True,
+        block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'),
+        num_features=max(1280, round_channels(1280, model_width, 8, None)),
+        stem_size=32,
+        fix_stem=True,
         round_chs_fn=partial(round_channels, multiplier=model_width),
-        act_layer = resolve_act_layer(kwargs, 'swish'),
+        act_layer=resolve_act_layer(kwargs, 'swish'),
         norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
         **kwargs,
     )
-
-    features_only = False
-    model_cls = EfficientNet
-    kwargs_filter = None
-
-    if kwargs.pop('features_only', False):
-        features_only = True
-        # kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
-        kwargs_filter = ('num_classes', 'num_features', 'conv_head', 'global_pool')
-        model_cls = EfficientNetFeatures
-
-    model = build_model_with_cfg(
-        model_cls, variant, pretrained,
-        default_cfg=default_cfgs[variant],
-        pretrained_strict=not features_only,
-        kwargs_filter=kwargs_filter,
-        **model_kwargs)
-    if features_only:
-        model.default_cfg = default_cfg_for_features(model.default_cfg)
-
+    model = _create_effnet(variant, pretrained, **model_kwargs)
     return model

@@ -2277,29 +2258,29 @@ def tf_mixnet_l(pretrained=False, **kwargs):

 @register_model
 def tinynet_a(pretrained=False, **kwargs):
-    model = _gen_tinynet('tinynet_a', 1.0, 1.2, **kwargs)
+    model = _gen_tinynet('tinynet_a', 1.0, 1.2, pretrained=pretrained, **kwargs)
     return model


 @register_model
 def tinynet_b(pretrained=False, **kwargs):
-    model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained, **kwargs)
+    model = _gen_tinynet('tinynet_b', 0.75, 1.1, pretrained=pretrained, **kwargs)
     return model


 @register_model
 def tinynet_c(pretrained=False, **kwargs):
-    model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained, **kwargs)
+    model = _gen_tinynet('tinynet_c', 0.54, 0.85, pretrained=pretrained, **kwargs)
     return model


 @register_model
 def tinynet_d(pretrained=False, **kwargs):
-    model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained, **kwargs)
+    model = _gen_tinynet('tinynet_d', 0.54, 0.695, pretrained=pretrained, **kwargs)
     return model


 @register_model
 def tinynet_e(pretrained=False, **kwargs):
-    model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained, **kwargs)
+    model = _gen_tinynet('tinynet_e', 0.51, 0.6, pretrained=pretrained, **kwargs)
     return model

From ca4fb7d1540acd47d7bef6f83f379e5724c7bb9a Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sat, 1 Jan 2022 15:31:38 +0900
Subject: [PATCH 4/6] feature: boost speed of pytest

---
 .github/workflows/tests.yml | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 9e0a4aac..908b1ec8 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -30,7 +30,7 @@ jobs:
     - name: Install testing dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pytest pytest-timeout expecttest
+        pip install pytest pytest-timeout expecttest pytest-xdist
     - name: Install torch on mac
       if: startsWith(matrix.os, 'macOS')
       run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@@ -48,4 +48,5 @@
       env:
         LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
       run: |
-        pytest -vv --durations=0 ./tests
+        export PYTHONDONTWRITEBYTECODE=1
+        pytest -vv --durations=0 -n auto ./tests

From df5afabc92eae6edc7575730f1063875ae52373f Mon Sep 17 00:00:00 2001
From: kozistr
Date: Sun, 2 Jan 2022 14:39:17 +0900
Subject: [PATCH 5/6] update: rollback pytest-xdist

---
 .github/workflows/tests.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 908b1ec8..bcb42fee 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -30,7 +30,7 @@ jobs:
     - name: Install testing dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install pytest pytest-timeout expecttest pytest-xdist
+        pip install pytest pytest-timeout expecttest
     - name: Install torch on mac
       if: startsWith(matrix.os, 'macOS')
       run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@@ -49,4 +49,4 @@ jobs:
         LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
       run: |
         export PYTHONDONTWRITEBYTECODE=1
-        pytest -vv --durations=0 -n auto ./tests
+        pytest -vv --durations=0 ./tests

From a0b26574976c5ec1b6c498a9336e46707e0d1571 Mon Sep 17 00:00:00 2001
From: Hyeongchan Kim
Date: Mon, 3 Jan 2022 07:01:06 +0900
Subject: [PATCH 6/6] Use `torch.repeat_interleave()` to generate repeated
 indices faster (#1058)

* update: use numpy to generate repeated indices faster

* update: use torch.repeat_interleave() instead of np.repeat()

* refactor: remove unused import, numpy

* refactor: torch.range to torch.arange

* update: tensor to list before appending the extra samples

* update: concatenate the paddings with torch.cat
---
 timm/data/distributed_sampler.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/timm/data/distributed_sampler.py b/timm/data/distributed_sampler.py
index fa403d0a..16090189 100644
--- a/timm/data/distributed_sampler.py
+++ b/timm/data/distributed_sampler.py
@@ -103,15 +103,16 @@ class RepeatAugSampler(Sampler):
         g = torch.Generator()
         g.manual_seed(self.epoch)
         if self.shuffle:
-            indices = torch.randperm(len(self.dataset), generator=g).tolist()
+            indices = torch.randperm(len(self.dataset), generator=g)
         else:
-            indices = list(range(len(self.dataset)))
+            indices = torch.arange(start=0, end=len(self.dataset))

         # produce repeats e.g. [0, 0, 0, 1, 1, 1, 2, 2, 2....]
-        indices = [x for x in indices for _ in range(self.num_repeats)]
+        indices = torch.repeat_interleave(indices, repeats=self.num_repeats, dim=0)
         # add extra samples to make it evenly divisible
         padding_size = self.total_size - len(indices)
-        indices += indices[:padding_size]
+        if padding_size > 0:
+            indices = torch.cat([indices, indices[:padding_size]], dim=0)
         assert len(indices) == self.total_size

         # subsample per rank
@@ -125,4 +126,4 @@
         return self.num_selected_samples

     def set_epoch(self, epoch):
-        self.epoch = epoch
\ No newline at end of file
+        self.epoch = epoch
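
A note on the stride arithmetic in PATCH 1/6: the old divisors (16, 8, 4, 2, 2) were only correct for the default patch_size=16. A minimal sanity check, with illustrative values that are not part of the patch itself:

    # Hypothetical check: the generalized divisors reduce to the old
    # hard-coded ones when patch_size == 16, the only case the old code
    # handled correctly.
    patch_size = 16
    assert patch_size == 16        # replaces hard-coded 16 (vit_stem path)
    assert patch_size // 2 == 8    # replaces hard-coded 8 (stem-less path)
    assert patch_size // 4 == 4    # replaces hard-coded 4 (conv stem path)
    assert patch_size // 8 == 2    # replaces hard-coded 2 (stages 2 and 3)

    # For any other patch_size the new expressions keep the tracked stage
    # resolutions consistent with what PatchEmbed actually produces, e.g.:
    img_size = [256, 256]
    patch_size = 32
    print([x // (patch_size // 4) for x in img_size])  # [32, 32]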
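
In _gen_tinynet (PATCH 2/6), num_features = max(1280, round_channels(1280, model_width, 8, None)) pins the head width at 1280 for every variant, since all TinyNet width multipliers are <= 1.0. A sketch of the divisible-channel rounding this relies on (generic make_divisible logic under that assumption, not copied from timm's round_channels):

    def make_divisible(v, divisor=8, min_value=None):
        # Round a channel count to the nearest multiple of `divisor`,
        # never shrinking by more than ~10%.
        min_value = min_value or divisor
        new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
        if new_v < 0.9 * v:
            new_v += divisor
        return new_v

    for width in (1.0, 0.75, 0.54, 0.51):  # tinynet_a / b / c,d / e widths
        rounded = make_divisible(1280 * width)
        print(width, rounded, max(1280, rounded))
    # e.g. 1280 * 0.75 = 960 -> rounded 960 -> max() keeps num_features at 1280.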
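
With PATCH 3/6 applied, the registered variants should behave like any other timm EfficientNet; a minimal usage sketch (assuming a timm build containing these patches and network access for the pretrained weights):

    import timm
    import torch

    model = timm.create_model('tinynet_e', pretrained=True).eval()

    # default_cfg carries the native resolution and pool size set above,
    # e.g. input_size=(3, 106, 106), pool_size=(4, 4) for tinynet_e.
    x = torch.randn(1, *model.default_cfg['input_size'])
    with torch.no_grad():
        print(model(x).shape)  # torch.Size([1, 1000])

    # Feature extraction now routes through _create_effnet and the
    # EfficientNetFeatures path instead of the hand-rolled branch.
    features = timm.create_model('tinynet_e', pretrained=True, features_only=True)
    print([f.shape[1] for f in features(x)])  # per-stage channel counts

The pool_size values added in this patch are just the final feature-map sizes at the native resolution: the network has five stride-2 stages (overall stride 32), so 192 // 32 = 6 for tinynet_a, while tinynet_e's 106 shrinks as 106 -> 53 -> 27 -> 14 -> 7 -> 4 under same-padding convolutions.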
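
The sampler change in PATCH 6/6 works because torch.repeat_interleave yields exactly the [0, 0, 0, 1, 1, 1, ...] pattern the old list comprehension built in Python, while keeping everything on-tensor (the torch.range -> torch.arange refactor in the change-notes matters because torch.range is deprecated and includes its endpoint). A small equivalence sketch with toy sizes, not the sampler's real inputs:

    import torch

    num_repeats, total_size = 3, 14
    indices = torch.arange(start=0, end=4)

    # Old: pure-Python expansion of each index.
    as_list = [x for x in indices.tolist() for _ in range(num_repeats)]
    # New: one vectorized op.
    as_tensor = torch.repeat_interleave(indices, repeats=num_repeats, dim=0)
    assert as_list == as_tensor.tolist()  # [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]

    # Padding to an even total with torch.cat, guarded for padding_size == 0,
    # mirroring the patched __iter__.
    padding_size = total_size - len(as_tensor)
    if padding_size > 0:
        as_tensor = torch.cat([as_tensor, as_tensor[:padding_size]], dim=0)
    assert len(as_tensor) == total_size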