From 7987f0c83d3f296e35eadc5a39c14ae002d0e2af Mon Sep 17 00:00:00 2001
From: szingaro
Date: Mon, 15 Feb 2021 19:55:48 +0100
Subject: [PATCH] Support multiple classifier heads in load_pretrained

---
 timm/models/helpers.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index d9b501da..ca47dfa4 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -200,12 +200,17 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
             _logger.warning(
                 f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
 
-    classifier_name = cfg['classifier']
+    classifiers = cfg['classifier']
     label_offset = cfg.get('label_offset', 0)
     if num_classes != cfg['num_classes']:
         # completely discard fully connected if model num_classes doesn't match pretrained weights
-        del state_dict[classifier_name + '.weight']
-        del state_dict[classifier_name + '.bias']
+        if isinstance(classifiers, str):
+            classifiers = (classifiers,)
+        for classifier_name in classifiers:
+            classifier_weight = classifier_name + '.weight'
+            classifier_bias = classifier_name + '.bias'
+            del state_dict[classifier_weight]
+            del state_dict[classifier_bias]
         strict = False
     elif label_offset > 0:
         # special case for pretrained weights with an extra background class in pretrained weights
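
Note (not part of the patch): a minimal standalone sketch of the behavior the new loop introduces. cfg['classifier'] may now name several heads (a tuple of names) while a plain string keeps working as before; each named head's '.weight' and '.bias' entries are dropped from the checkpoint so the mismatched classifier is re-initialized. The key names below are made up for illustration.

    # Sketch of the tuple-handling logic added above; head names are hypothetical.
    state_dict = {
        'head1.weight': 0, 'head1.bias': 0,
        'head2.weight': 0, 'head2.bias': 0,
        'stem.conv.weight': 0,
    }
    classifiers = ('head1', 'head2')  # cfg['classifier'] given as a tuple of heads
    if isinstance(classifiers, str):  # single-head string configs still work
        classifiers = (classifiers,)
    for classifier_name in classifiers:
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
    assert list(state_dict) == ['stem.conv.weight']  # only backbone keys remain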