pull/434/head
szingaro 4 years ago
parent 9dec7feef9
commit 7987f0c83d

@ -200,12 +200,17 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
_logger.warning( _logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifier_name = cfg['classifier'] classifiers = cfg['classifier']
label_offset = cfg.get('label_offset', 0) label_offset = cfg.get('label_offset', 0)
if num_classes != cfg['num_classes']: if num_classes != cfg['num_classes']:
# completely discard fully connected if model num_classes doesn't match pretrained weights # completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight'] if isinstance(classifiers, str):
del state_dict[classifier_name + '.bias'] classifiers = (classifiers,)
for classifier_name in classifiers:
classifier_weight = classifier_name + '.weight'
classifier_bias = classifier_name + '.bias'
del state_dict[classifier_weight]
del state_dict[classifier_bias]
strict = False strict = False
elif label_offset > 0: elif label_offset > 0:
# special case for pretrained weights with an extra background class in pretrained weights # special case for pretrained weights with an extra background class in pretrained weights

Loading…
Cancel
Save