|
|
|
@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
|
|
|
|
|
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
|
|
|
|
|
from .fx_features import FeatureGraphNet
|
|
|
|
|
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
|
|
|
|
from .layers import Conv2dSame, Linear
|
|
|
|
|
from .layers import Conv2dSame, Linear, BatchNormAct2d
|
|
|
|
|
from .registry import get_pretrained_cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
|
|
|
|
|
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
|
|
|
|
groups=g, stride=old_module.stride)
|
|
|
|
|
set_layer(new_module, n, new_conv)
|
|
|
|
|
if isinstance(old_module, nn.BatchNorm2d):
|
|
|
|
|
elif isinstance(old_module, BatchNormAct2d):
|
|
|
|
|
new_bn = BatchNormAct2d(
|
|
|
|
|
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
|
|
|
|
affine=old_module.affine, track_running_stats=True)
|
|
|
|
|
new_bn.drop = old_module.drop
|
|
|
|
|
new_bn.act = old_module.act
|
|
|
|
|
set_layer(new_module, n, new_bn)
|
|
|
|
|
elif isinstance(old_module, nn.BatchNorm2d):
|
|
|
|
|
new_bn = nn.BatchNorm2d(
|
|
|
|
|
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
|
|
|
|
affine=old_module.affine, track_running_stats=True)
|
|
|
|
|
set_layer(new_module, n, new_bn)
|
|
|
|
|
if isinstance(old_module, nn.Linear):
|
|
|
|
|
elif isinstance(old_module, nn.Linear):
|
|
|
|
|
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
|
|
|
|
num_features = state_dict[n + '.weight'][1]
|
|
|
|
|
new_fc = Linear(
|
|
|
|
|