def add_ml_decoder_head(model):
    """Swap a model's classification head for an ``MLDecoder`` head, in place.

    The model layout is detected in this order:

    * ``global_pool`` and ``fc`` attributes (most CNN models, like ResNet50):
      the head lives in ``fc``.
    * ``global_pool`` and ``classifier`` attributes (EfficientNet): the head
      lives in ``classifier``.
    * class name (``model._get_name()``) containing ``RegNet`` or ``TResNet``:
      the head lives in ``head``; ``global_pool`` is left untouched.

    For the first two layouts ``model.global_pool`` is replaced with
    ``nn.Identity()``.
    NOTE(review): presumably so that MLDecoder receives un-pooled features --
    confirm against ``MLDecoder.forward``.

    Args:
        model: The network to modify. It must expose ``num_classes`` and
            ``num_features``, which are forwarded to ``MLDecoder``.

    Returns:
        The same ``model`` object, mutated in place.

    Raises:
        NotImplementedError: If the model matches none of the supported
            layouts. The model is not modified in that case. (Earlier
            revisions printed a message and called ``exit(-1)``.)
    """
    has_pool = hasattr(model, 'global_pool')
    if has_pool and hasattr(model, 'fc'):  # most CNN models, like Resnet50
        head_attr, replace_pool = 'fc', True
    elif has_pool and hasattr(model, 'classifier'):  # EfficientNet
        head_attr, replace_pool = 'classifier', True
    else:
        model_name = model._get_name()
        if 'RegNet' in model_name or 'TResNet' in model_name:
            head_attr, replace_pool = 'head', False
        else:
            # Raise instead of print + exit(-1): library code must not
            # terminate the caller's interpreter, and ``exit`` is only
            # injected by the ``site`` module.
            raise NotImplementedError(
                "Model code-writing is not aligned currently with ml-decoder"
                f" (got {model_name})"
            )

    if replace_pool:
        model.global_pool = nn.Identity()

    # Drop the old head before building the new one, then re-create the
    # attribute under the same name so the model's forward keeps working.
    delattr(model, head_attr)
    num_classes = model.num_classes
    num_features = model.num_features
    setattr(
        model,
        head_attr,
        MLDecoder(num_classes=num_classes, initial_num_features=num_features),
    )

    if hasattr(model, 'drop_rate'):  # Ml-Decoder has inner dropout
        model.drop_rate = 0
    return model