support CNNs

pull/1012/head
talrid 3 years ago
parent d6701d8a81
commit c11f4c3218

@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn
def add_ml_decoder_head(model): def add_ml_decoder_head(model):
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50 if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
model.global_pool = nn.Identity() model.global_pool = nn.Identity()
del model.fc del model.fc
num_classes = model.num_classes num_classes = model.num_classes
num_features = model.num_features num_features = model.num_features
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features) model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif hasattr(model, 'head'): # tresnet elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
model.global_pool = nn.Identity()
del model.classifier
num_classes = model.num_classes
num_features = model.num_features
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
del model.head del model.head
num_classes = model.num_classes num_classes = model.num_classes
num_features = model.num_features num_features = model.num_features
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features) model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
else: else:
print("model is not suited for ml-decoder") print("Model code-writing is not aligned currently with ml-decoder")
exit(-1) exit(-1)
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
model.drop_rate = 0
return model return model

Loading…
Cancel
Save