From 766b4d32627fc4d1d9d188de81736504215127a0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 28 Jun 2021 15:56:24 -0700 Subject: [PATCH] Fix features for resnetv2_50t --- timm/models/resnetv2.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 4fd3b823..2ff4da8c 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -291,6 +291,10 @@ class ResNetStage(nn.Module): return x +def is_stem_deep(stem_type): + return any([s in stem_type for s in ('deep', 'tiered')]) + + def create_resnetv2_stem( in_chs, out_chs=64, stem_type='', preact=True, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): @@ -298,7 +302,7 @@ def create_resnetv2_stem( assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered') # NOTE conv padding mode can be changed by overriding the conv_layer def - if any([s in stem_type for s in ('deep', 'tiered')]): + if is_stem_deep(stem_type): # A 3 deep 3x3 conv stack as in ResNet V1D models if 'tiered' in stem_type: stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py @@ -350,7 +354,7 @@ class ResNetV2(nn.Module): stem_chs = make_div(stem_chs * wf) self.stem = create_resnetv2_stem( in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) - stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm' + stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) prev_chs = stem_chs