|
|
@ -291,6 +291,10 @@ class ResNetStage(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_stem_deep(stem_type):
|
|
|
|
|
|
|
|
return any([s in stem_type for s in ('deep', 'tiered')])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_resnetv2_stem(
|
|
|
|
def create_resnetv2_stem(
|
|
|
|
in_chs, out_chs=64, stem_type='', preact=True,
|
|
|
|
in_chs, out_chs=64, stem_type='', preact=True,
|
|
|
|
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
|
|
|
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
|
|
|
@ -298,7 +302,7 @@ def create_resnetv2_stem(
|
|
|
|
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
|
|
|
|
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
|
|
|
# NOTE conv padding mode can be changed by overriding the conv_layer def
|
|
|
|
if any([s in stem_type for s in ('deep', 'tiered')]):
|
|
|
|
if is_stem_deep(stem_type):
|
|
|
|
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
|
|
|
# A 3 deep 3x3 conv stack as in ResNet V1D models
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
if 'tiered' in stem_type:
|
|
|
|
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
|
|
|
|
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
|
|
|
@ -350,7 +354,7 @@ class ResNetV2(nn.Module):
|
|
|
|
stem_chs = make_div(stem_chs * wf)
|
|
|
|
stem_chs = make_div(stem_chs * wf)
|
|
|
|
self.stem = create_resnetv2_stem(
|
|
|
|
self.stem = create_resnetv2_stem(
|
|
|
|
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
|
|
|
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
|
|
|
|
stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
|
|
|
|
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
|
|
|
|
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
|
|
|
|
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
|
|
|
|
|
|
|
|
|
|
|
|
prev_chs = stem_chs
|
|
|
|
prev_chs = stem_chs
|
|
|
|