|
|
@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|
|
|
|
|
|
|
|
|
|
|
classifier_name = cfg['classifier']
|
|
|
|
classifier_name = cfg['classifier']
|
|
|
|
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
|
|
|
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
|
|
|
|
|
|
|
# FIXME this special case is problematic as number of pretrained weight sources increases
|
|
|
|
# special case for imagenet trained models with extra background class in pretrained weights
|
|
|
|
# special case for imagenet trained models with extra background class in pretrained weights
|
|
|
|
classifier_weight = state_dict[classifier_name + '.weight']
|
|
|
|
classifier_weight = state_dict[classifier_name + '.weight']
|
|
|
|
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
|
|
|
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
|
|
|