From dc51334cdc05757dc9004583aac8668ebd892b03 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 22 Mar 2022 20:33:01 -0700
Subject: [PATCH] Fix pruned adapt for EfficientNet models that are now using
 BatchNormAct layers

---
 timm/models/helpers.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index eda09680..bbedd7a8 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
-from .layers import Conv2dSame, Linear
+from .layers import Conv2dSame, Linear, BatchNormAct2d
 from .registry import get_pretrained_cfg
 
 
@@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                 groups=g, stride=old_module.stride)
             set_layer(new_module, n, new_conv)
-        if isinstance(old_module, nn.BatchNorm2d):
+        elif isinstance(old_module, BatchNormAct2d):
+            new_bn = BatchNormAct2d(
+                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+                affine=old_module.affine, track_running_stats=True)
+            new_bn.drop = old_module.drop
+            new_bn.act = old_module.act
+            set_layer(new_module, n, new_bn)
+        elif isinstance(old_module, nn.BatchNorm2d):
             new_bn = nn.BatchNorm2d(
                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
-        if isinstance(old_module, nn.Linear):
+        elif isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
             new_fc = Linear(
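
Illustrative check (not part of the patch): a minimal sketch of how the fixed path
can be exercised, assuming this timm build registers a pruned EfficientNet variant
such as 'efficientnet_b1_pruned'. Pruned variants are rebuilt through
adapt_model_from_string(); since BatchNormAct2d subclasses nn.BatchNorm2d, the old
code matched it on the plain BatchNorm2d branch and rebuilt it without its
activation, which this patch avoids by checking BatchNormAct2d first.

    import timm
    from timm.models.layers import BatchNormAct2d

    # Building a pruned variant runs the adapt path even without pretrained weights.
    model = timm.create_model('efficientnet_b1_pruned', pretrained=False)

    # With the fix, BatchNormAct2d layers (including their act/drop submodules)
    # survive adaptation instead of being downgraded to plain nn.BatchNorm2d.
    assert any(isinstance(m, BatchNormAct2d) for m in model.modules())
    print('pruned EfficientNet adapted with BatchNormAct2d layers intact')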