From b41cffaa93e8205bd8bd309f82c33c07c420eefd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 22 Jun 2021 23:16:05 -0700
Subject: [PATCH] Fix a few issues loading pretrained vit/bit npz weights w/
 num_classes=0 __init__ arg. Missed a few other small classifier handling
 detail on Mlp, GhostNet, Levit. Should fix #713

---
 tests/test_models.py              | 17 +++++++++++++++++
 timm/models/ghostnet.py           |  2 +-
 timm/models/levit.py              |  2 +-
 timm/models/mlp_mixer.py          |  2 +-
 timm/models/resnetv2.py           |  3 ++-
 timm/models/visformer.py          |  1 -
 timm/models/vision_transformer.py |  2 +-
 7 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 0a770784..5c8b02db 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -147,6 +147,15 @@ def test_model_default_cfgs(model_name, batch_size):
         # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
+    if 'pruned' not in model_name: # FIXME better pruned model handling
+        # test classifier + global pool deletion via __init__
+        model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
+        outputs = model.forward(input_tensor)
+        assert len(outputs.shape) == 4
+        if not isinstance(model, timm.models.MobileNetV3) and not isinstance(model, timm.models.GhostNet):
+            # FIXME mobilenetv3/ghostnet forward_features vs removed pooling differ
+            assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):
@@ -193,6 +202,13 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     assert len(outputs.shape) == 2
     assert outputs.shape[1] == model.num_features
 
+    model = create_model(model_name, pretrained=False, num_classes=0).eval()
+    outputs = model.forward(input_tensor)
+    if isinstance(outputs, tuple):
+        outputs = outputs[0]
+    assert len(outputs.shape) == 2
+    assert outputs.shape[1] == model.num_features
+
     # check classifier name matches default_cfg
     classifier = cfg['classifier']
     if not isinstance(classifier, (tuple, list)):
@@ -217,6 +233,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
         """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
         in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
         create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
+        create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=0)
 
     @pytest.mark.timeout(120)
     @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py
index a73047c5..3b6f90a4 100644
--- a/timm/models/ghostnet.py
+++ b/timm/models/ghostnet.py
@@ -182,7 +182,7 @@ class GhostNet(nn.Module):
         self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
         self.act2 = nn.ReLU(inplace=True)
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
-        self.classifier = Linear(out_chs, num_classes)
+        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
 
     def get_classifier(self):
         return self.classifier
diff --git a/timm/models/levit.py b/timm/models/levit.py
index fa35f41f..9987e4ba 100644
--- a/timm/models/levit.py
+++ b/timm/models/levit.py
@@ -542,7 +542,7 @@ def checkpoint_filter_fn(state_dict, model):
         state_dict = state_dict['model']
     D = model.state_dict()
     for k in state_dict.keys():
-        if D[k].ndim == 4 and state_dict[k].ndim == 2:
+        if k in D and D[k].ndim == 4 and state_dict[k].ndim == 2:
             state_dict[k] = state_dict[k][:, :, None, None]
     return state_dict
 
diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py
index 7a87eb36..c51e61e3 100644
--- a/timm/models/mlp_mixer.py
+++ b/timm/models/mlp_mixer.py
@@ -266,7 +266,7 @@ class MlpMixer(nn.Module):
                 act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
             for _ in range(num_blocks)])
         self.norm = norm_layer(embed_dim)
-        self.head = nn.Linear(embed_dim, self.num_classes)  # zero init
+        self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
 
         self.init_weights(nlhb=nlhb)
 
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index a3c89532..8110fcca 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -424,7 +424,8 @@ def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/
     model.stem.conv.weight.copy_(stem_conv_w)
     model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
     model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
-    if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+    if isinstance(model.head.fc, nn.Conv2d) and \
+            model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
         model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
         model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
     for i, (sname, stage) in enumerate(model.stages.named_children()):
diff --git a/timm/models/visformer.py b/timm/models/visformer.py
index 16631027..7740f381 100644
--- a/timm/models/visformer.py
+++ b/timm/models/visformer.py
@@ -237,7 +237,6 @@ class Visformer(nn.Module):
         self.num_features = embed_dim if self.vit_stem else embed_dim * 2
         self.norm = norm_layer(self.num_features)
         self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
-        self.head = nn.Linear(self.num_features, num_classes)
 
         # weights init
         if self.pos_embed:
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 89fba7de..0a960987 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -448,7 +448,7 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
+    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
         model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
         model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
     for i, block in enumerate(model.blocks.children()):
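
For reference, the use case from #713 that the new tests exercise is creating a pretrained model as a headless feature extractor via the num_classes=0 __init__ arg. Before this patch the npz loaders accessed model.head.bias / model.head.fc.weight unconditionally, which breaks once the head has been replaced with nn.Identity(). A rough usage sketch follows; the model name is only an illustrative choice, any pretrained ViT/BiT/Mixer variant goes through the fixed loading path:

import torch
import timm

# num_classes=0 swaps the classifier for nn.Identity(), so the model returns
# pooled features instead of logits; pretrained=True triggers the npz/checkpoint
# load that this patch guards against a missing head.
model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))

print(features.shape)  # (1, model.num_features), e.g. torch.Size([1, 768]) for vit_base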