|
|
|
@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def add_ml_decoder_head(model):
|
|
|
|
|
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # resnet50
|
|
|
|
|
if hasattr(model, 'global_pool') and hasattr(model, 'fc'): # most CNN models, like Resnet50
|
|
|
|
|
model.global_pool = nn.Identity()
|
|
|
|
|
del model.fc
|
|
|
|
|
num_classes = model.num_classes
|
|
|
|
|
num_features = model.num_features
|
|
|
|
|
model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
|
|
|
|
elif hasattr(model, 'head'): # tresnet
|
|
|
|
|
elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'): # EfficientNet
|
|
|
|
|
model.global_pool = nn.Identity()
|
|
|
|
|
del model.classifier
|
|
|
|
|
num_classes = model.num_classes
|
|
|
|
|
num_features = model.num_features
|
|
|
|
|
model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
|
|
|
|
elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name(): # hasattr(model, 'head')
|
|
|
|
|
del model.head
|
|
|
|
|
num_classes = model.num_classes
|
|
|
|
|
num_features = model.num_features
|
|
|
|
|
model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
|
|
|
|
|
else:
|
|
|
|
|
print("model is not suited for ml-decoder")
|
|
|
|
|
print("Model code-writing is not aligned currently with ml-decoder")
|
|
|
|
|
exit(-1)
|
|
|
|
|
|
|
|
|
|
if hasattr(model, 'drop_rate'): # Ml-Decoder has inner dropout
|
|
|
|
|
model.drop_rate = 0
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|