Fix ResNetV2 pretrained classifier issue. Fixes #540

pull/555/head
Ross Wightman 4 years ago
parent de9dff933a
commit 2b49ab7a36

@@ -132,7 +132,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
 def test_model_load_pretrained(model_name, batch_size):
     """Create that pretrained weights load, verify support for in_chans != 3 while doing so."""
     in_chans = 3 if 'pruned' in model_name else 1  # pruning not currently supported with in_chans change
-    create_model(model_name, pretrained=True, in_chans=in_chans)
+    create_model(model_name, pretrained=True, in_chans=in_chans, num_classes=5)

 @pytest.mark.timeout(120)
 @pytest.mark.parametrize('model_name', list_models(pretrained=True, exclude_filters=NON_STD_FILTERS))
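
The test above now also passes num_classes=5, so loading pretrained weights is exercised with a classifier size that differs from the checkpoint's. A minimal sketch of what that covers outside the test harness, using timm's create_model as in the test; the model name below is only an illustrative pretrained ResNetV2 variant, not something fixed by this commit:

import torch
from timm import create_model

# pretrained weights, non-default input channels, and a 5-class head
model = create_model('resnetv2_50x1_bitm', pretrained=True, in_chans=1, num_classes=5)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 1, 224, 224))
assert out.shape == (1, 5)  # the head is sized for the requested 5 classes, not the checkpoint's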

@@ -365,6 +365,7 @@ class ResNetV2(nn.Module):
         return self.head.fc

     def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
         self.head = ClassifierHead(
             self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)

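
Recording self.num_classes in reset_classifier keeps the model's bookkeeping in sync with its actual head after the classifier is swapped. A rough sketch of the expected behaviour; the model name is again just an illustrative ResNetV2 variant, and no pretrained download is needed:

from timm import create_model

model = create_model('resnetv2_50x1_bitm', pretrained=False, num_classes=1000)
model.reset_classifier(10)
assert model.num_classes == 10                    # kept in sync by the added line
assert model.get_classifier().out_channels == 10  # head.fc is a 1x1 Conv2d (use_conv=True)
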
@@ -393,6 +394,7 @@ class ResNetV2(nn.Module):
             self.stem.conv.weight.copy_(stem_conv_w)
             self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
             self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
-            self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
-            self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
+            if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
+                self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel']))
+                self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
             for i, (sname, stage) in enumerate(self.stages.named_children()):
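
The added guard copies the checkpoint's classifier kernel and bias only when the model's head has the same number of outputs as the checkpoint; if a different num_classes was requested, the freshly initialized head is left untouched instead of raising a shape-mismatch error on copy_. A small self-contained sketch of the same pattern in plain PyTorch, with hypothetical sizes and names rather than the actual BiT .npz loading code:

import torch
import torch.nn as nn

head = nn.Conv2d(32, 5, kernel_size=1)       # head built for num_classes=5
ckpt_kernel = torch.randn(1, 1, 32, 1000)    # TF-format HWIO kernel from a 1000-class checkpoint
ckpt_bias = torch.randn(1000)

with torch.no_grad():
    if head.weight.shape[0] == ckpt_kernel.shape[-1]:
        # shapes agree -> safe to copy the pretrained classifier (HWIO -> OIHW)
        head.weight.copy_(ckpt_kernel.permute(3, 2, 0, 1))
        head.bias.copy_(ckpt_bias)
    # otherwise the randomly initialized 5-class head is kept as-is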
