From 2b49ab7a366ca36da608988a08cdf953539dfe18 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 3 Apr 2021 11:18:12 -0700
Subject: [PATCH] Fix ResNetV2 pretrained classifier issue. Fixes #540

---
 tests/test_models.py    | 2 +-
 timm/models/resnetv2.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 639a0534..4fbdc85b 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -132,7 +132,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 def test_model_load_pretrained(model_name, batch_size):
     """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
     in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
-    create_model(model_name, pretrained=True, in_chans=in_chans)
+    create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)
 
 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 80e0943d..0ca6fba9 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
         return self.head.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
 
@@ -393,8 +394,9 @@ class ResNetV2(nn.Module):
             self.stem.conv.weight.copy_(stem_conv_w)
             self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
             self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
-            self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
-            self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
+            if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+                self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
+                self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
             for i, (sname, stage) in enumerate(self.stages.named_children()):
                 for j, (bname, block) in enumerate(stage.blocks.named_children()):
                     convname = 'standardized_conv2d'