Fix pruned adapt for EfficientNet models that are now using BatchNormAct layers

pull/1190/head
Ross Wightman 3 years ago
parent 024fc4d9ab
commit dc51334cdc

@@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
-from .layers import Conv2dSame, Linear
+from .layers import Conv2dSame, Linear, BatchNormAct2d
 from .registry import get_pretrained_cfg
@@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                 groups=g, stride=old_module.stride)
             set_layer(new_module, n, new_conv)
-        if isinstance(old_module, nn.BatchNorm2d):
+        elif isinstance(old_module, BatchNormAct2d):
+            new_bn = BatchNormAct2d(
+                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+                affine=old_module.affine, track_running_stats=True)
+            new_bn.drop = old_module.drop
+            new_bn.act = old_module.act
+            set_layer(new_module, n, new_bn)
+        elif isinstance(old_module, nn.BatchNorm2d):
             new_bn = nn.BatchNorm2d(
                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
-        if isinstance(old_module, nn.Linear):
+        elif isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
             new_fc = Linear(
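
Usage note (not part of the diff): a minimal sketch of how the fix is exercised, assuming timm's create_model path runs the prune adaptation for configs flagged as pruned. The model name 'efficientnet_b1_pruned' is used here only as an example of a pruned EfficientNet variant whose norm layers are BatchNormAct2d.

import timm
from timm.models.layers import BatchNormAct2d

# Building a pruned EfficientNet variant goes through adapt_model_from_string,
# which after this commit rebuilds BatchNormAct2d layers instead of skipping them.
model = timm.create_model('efficientnet_b1_pruned', pretrained=False)

# Sanity check: the adapted model should still contain BatchNormAct2d modules,
# with their act/drop submodules carried over from the original layers.
for name, module in model.named_modules():
    if isinstance(module, BatchNormAct2d):
        print(name, module.num_features, type(module.act).__name__)
        break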
