diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 6b8c0473..3d50b704 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -600,7 +600,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) block_args=decode_arch_def(arch_def), stem_size=32, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -636,7 +636,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) block_args=decode_arch_def(arch_def), stem_size=32, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -665,7 +665,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar block_args=decode_arch_def(arch_def), stem_size=8, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -694,7 +694,7 @@ def _gen_mobilenet_v2( stem_size=32, fix_stem=fix_stem_head, round_chs_fn=round_chs_fn, - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'relu6'), **kwargs ) @@ -725,7 +725,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): stem_size=16, num_features=1984, # paper suggests this, but is not 100% clear round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -760,7 +760,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): block_args=decode_arch_def(arch_def), stem_size=32, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -807,7 +807,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre stem_size=32, round_chs_fn=round_chs_fn, act_layer=resolve_act_layer(kwargs, 'swish'), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs, ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -836,7 +836,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 num_features=round_chs_fn(1280), stem_size=32, round_chs_fn=round_chs_fn, - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'relu'), **kwargs, ) @@ -867,7 +867,7 @@ def _gen_efficientnet_condconv( num_features=round_chs_fn(1280), stem_size=32, round_chs_fn=round_chs_fn, - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'swish'), **kwargs, ) @@ -909,7 +909,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0 fix_stem=True, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), act_layer=resolve_act_layer(kwargs, 'relu6'), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs, ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -937,7 +937,7 @@ def _gen_efficientnetv2_base( num_features=round_chs_fn(1280), stem_size=32, round_chs_fn=round_chs_fn, - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) @@ -976,7 +976,7 @@ def _gen_efficientnetv2_s( num_features=round_chs_fn(num_features), stem_size=24, round_chs_fn=round_chs_fn, - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) @@ -1006,7 +1006,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, num_features=1280, stem_size=24, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) @@ -1036,7 +1036,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, num_features=1280, stem_size=32, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) @@ -1066,7 +1066,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0 num_features=1280, stem_size=32, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), act_layer=resolve_act_layer(kwargs, 'silu'), **kwargs, ) @@ -1100,7 +1100,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): num_features=1536, stem_size=16, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs) @@ -1133,7 +1133,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai num_features=1536, stem_size=24, round_chs_fn=partial(round_channels, multiplier=channel_multiplier), - norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), + norm_layer=kwargs.pop('norm_layer', None) or partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)), **kwargs ) model = _create_effnet(variant, pretrained, **model_kwargs)