Fix features for resnetv2_50t

pull/729/head
Ross Wightman 3 years ago
parent e8045e712f
commit 766b4d3262

@ -291,6 +291,10 @@ class ResNetStage(nn.Module):
return x
def is_stem_deep(stem_type):
return any([s in stem_type for s in ('deep', 'tiered')])
def create_resnetv2_stem(
in_chs, out_chs=64, stem_type='', preact=True,
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
@ -298,7 +302,7 @@ def create_resnetv2_stem(
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
# NOTE conv padding mode can be changed by overriding the conv_layer def
if any([s in stem_type for s in ('deep', 'tiered')]):
if is_stem_deep(stem_type):
# A 3 deep 3x3 conv stack as in ResNet V1D models
if 'tiered' in stem_type:
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
@ -350,7 +354,7 @@ class ResNetV2(nn.Module):
stem_chs = make_div(stem_chs * wf)
self.stem = create_resnetv2_stem(
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
prev_chs = stem_chs

Loading…
Cancel
Save