From 41559247e9f282e0a1df9dca1f6173b7f8e86399 Mon Sep 17 00:00:00 2001
From: talrid
Date: Mon, 22 Nov 2021 17:50:39 +0200
Subject: [PATCH 1/4] use_ml_decoder_head

---
 timm/models/factory.py           |   5 ++
 timm/models/layers/ml_decoder.py | 149 +++++++++++++++++++++++++++++++
 train.py                         |   4 +-
 3 files changed, 157 insertions(+), 1 deletion(-)
 create mode 100644 timm/models/layers/ml_decoder.py

diff --git a/timm/models/factory.py b/timm/models/factory.py
index d040a9ff..40453380 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -29,6 +29,7 @@ def create_model(
         scriptable=None,
         exportable=None,
         no_jit=None,
+        use_ml_decoder_head=False,
         **kwargs):
     """Create a model
 
@@ -80,6 +81,10 @@ def create_model(
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
         model = create_fn(pretrained=pretrained, **kwargs)
 
+    if use_ml_decoder_head:
+        from timm.models.layers.ml_decoder import add_ml_decoder_head
+        model = add_ml_decoder_head(model)
+
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
 
diff --git a/timm/models/layers/ml_decoder.py b/timm/models/layers/ml_decoder.py
new file mode 100644
index 00000000..87815aaa
--- /dev/null
+++ b/timm/models/layers/ml_decoder.py
@@ -0,0 +1,149 @@
+from typing import Optional
+
+import torch
+from torch import nn
+from torch import nn, Tensor
+from torch.nn.modules.transformer import _get_activation_fn
+
+
+def add_ml_decoder_head(model):
+    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # resnet50
+        model.global_pool = nn.Identity()
+        del model.fc
+        num_classes = model.num_classes
+        num_features = model.num_features
+        model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+    elif hasattr(model, 'head'):  # tresnet
+        del model.head
+        num_classes = model.num_classes
+        num_features = model.num_features
+        model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+    else:
+        print("model is not suited for ml-decoder")
+        exit(-1)
+
+    return model
+
+
+class TransformerDecoderLayerOptimal(nn.Module):
+    def __init__(self, d_model, nhead=8, dim_feedforward=2048, dropout=0.1, activation="relu",
+                 layer_norm_eps=1e-5) -> None:
+        super(TransformerDecoderLayerOptimal, self).__init__()
+        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.dropout = nn.Dropout(dropout)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+
+        # Implementation of Feedforward model
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+
+        self.activation = _get_activation_fn(activation)
+
+    def __setstate__(self, state):
+        if 'activation' not in state:
+            state['activation'] = torch.nn.functional.relu
+        super(TransformerDecoderLayerOptimal, self).__setstate__(state)
+
+    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask: Optional[Tensor] = None,
+                memory_mask: Optional[Tensor] = None,
+                tgt_key_padding_mask: Optional[Tensor] = None,
+                memory_key_padding_mask: Optional[Tensor] = None) -> Tensor:
+        tgt = tgt + self.dropout1(tgt)
+        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(tgt, memory, memory)[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        return tgt
+
+
+# @torch.jit.script
+# class ExtrapClasses(object):
+#     def __init__(self, num_queries: int, group_size: int):
+#         self.num_queries = num_queries
+#         self.group_size = group_size
+#
+#     def __call__(self, h: torch.Tensor, class_embed_w: torch.Tensor, class_embed_b: torch.Tensor, out_extrap:
+#     torch.Tensor):
+#         # h = h.unsqueeze(-1).expand(-1, -1, -1, self.group_size)
+#         h = h[..., None].repeat(1, 1, 1, self.group_size)  # torch.Size([bs, 5, 768, groups])
+#         w = class_embed_w.view((self.num_queries, h.shape[2], self.group_size))
+#         out = (h * w).sum(dim=2) + class_embed_b
+#         out = out.view((h.shape[0], self.group_size * self.num_queries))
+#         return out
+
+@torch.jit.script
+class GroupFC(object):
+    def __init__(self, embed_len_decoder: int):
+        self.embed_len_decoder = embed_len_decoder
+
+    def __call__(self, h: torch.Tensor, duplicate_pooling: torch.Tensor, out_extrap: torch.Tensor):
+        for i in range(self.embed_len_decoder):
+            h_i = h[:, i, :]
+            w_i = duplicate_pooling[i, :, :]
+            out_extrap[:, i, :] = torch.matmul(h_i, w_i)
+
+
+class MLDecoder(nn.Module):
+    def __init__(self, num_classes, num_of_groups=-1, decoder_embedding=768, initial_num_features=2048):
+        super(MLDecoder, self).__init__()
+        embed_len_decoder = 100 if num_of_groups < 0 else num_of_groups
+        if embed_len_decoder > num_classes:
+            embed_len_decoder = num_classes
+
+        # switching to 768 initial embeddings
+        decoder_embedding = 768 if decoder_embedding < 0 else decoder_embedding
+        self.embed_standart = nn.Linear(initial_num_features, decoder_embedding)
+
+        # decoder
+        decoder_dropout = 0.1
+        num_layers_decoder = 1
+        dim_feedforward = 2048
+        layer_decode = TransformerDecoderLayerOptimal(d_model=decoder_embedding,
+                                                      dim_feedforward=dim_feedforward, dropout=decoder_dropout)
+        self.decoder = nn.TransformerDecoder(layer_decode, num_layers=num_layers_decoder)
+
+        # non-learnable queries
+        self.query_embed = nn.Embedding(embed_len_decoder, decoder_embedding)
+        self.query_embed.requires_grad_(False)
+
+        # group fully-connected
+        self.num_classes = num_classes
+        self.duplicate_factor = int(num_classes / embed_len_decoder + 0.999)
+        self.duplicate_pooling = torch.nn.Parameter(
+            torch.Tensor(embed_len_decoder, decoder_embedding, self.duplicate_factor))
+        self.duplicate_pooling_bias = torch.nn.Parameter(torch.Tensor(num_classes))
+        torch.nn.init.xavier_normal_(self.duplicate_pooling)
+        torch.nn.init.constant_(self.duplicate_pooling_bias, 0)
+        self.group_fc = GroupFC(embed_len_decoder)
+
+    def forward(self, x):
+        if len(x.shape) == 4:  # [bs,2048, 7,7]
+            embedding_spatial = x.flatten(2).transpose(1, 2)
+        else:  # [bs, 197,468]
+            embedding_spatial = x
+        embedding_spatial_786 = self.embed_standart(embedding_spatial)
+        embedding_spatial_786 = torch.nn.functional.relu(embedding_spatial_786, inplace=True)
+
+        bs = embedding_spatial_786.shape[0]
+        query_embed = self.query_embed.weight
+        # tgt = query_embed.unsqueeze(1).repeat(1, bs, 1)
+        tgt = query_embed.unsqueeze(1).expand(-1, bs, -1)  # no allocation of memory with expand
+        h = self.decoder(tgt, embedding_spatial_786.transpose(0, 1))  # [embed_len_decoder, batch, 768]
+        h = h.transpose(0, 1)
+
+        out_extrap = torch.zeros(h.shape[0], h.shape[1], self.duplicate_factor, device=h.device, dtype=h.dtype)
+        self.group_fc(h, self.duplicate_pooling, out_extrap)
+        h_out = out_extrap.flatten(1)[:, :self.num_classes]
+        h_out += self.duplicate_pooling_bias
+        logits = h_out
+        return logits
diff --git a/train.py b/train.py
index 10d839be..44a0e292 100755
--- a/train.py
+++ b/train.py
@@ -115,6 +115,7 @@ parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                     help='input batch size for training (default: 128)')
 parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                     help='validation batch size override (default: None)')
+parser.add_argument('--use-ml-decoder-head', type=int, default=0)
 
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -379,7 +380,8 @@ def main():
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
-        checkpoint_path=args.initial_checkpoint)
+        checkpoint_path=args.initial_checkpoint,
+        ml_decoder_head=args.use_ml_decoder_head)
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

From b6c180ef19714ae33a186b1070e991e7403d81cd Mon Sep 17 00:00:00 2001
From: talrid
Date: Sun, 28 Nov 2021 13:56:55 +0200
Subject: [PATCH 2/4] use_ml_decoder_head

---
 train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train.py b/train.py
index 44a0e292..42985e12 100755
--- a/train.py
+++ b/train.py
@@ -381,7 +381,7 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint,
-        ml_decoder_head=args.use_ml_decoder_head)
+        use_ml_decoder_head=args.use_ml_decoder_head)
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

From c11f4c3218958599d990fdc91c00d4560aa1c5bb Mon Sep 17 00:00:00 2001
From: talrid
Date: Tue, 30 Nov 2021 08:48:08 +0200
Subject: [PATCH 3/4] support CNNs

---
 timm/models/layers/ml_decoder.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/timm/models/layers/ml_decoder.py b/timm/models/layers/ml_decoder.py
index 87815aaa..3f828c6d 100644
--- a/timm/models/layers/ml_decoder.py
+++ b/timm/models/layers/ml_decoder.py
@@ -7,21 +7,28 @@ from torch.nn.modules.transformer import _get_activation_fn
 
 
 def add_ml_decoder_head(model):
-    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # resnet50
+    if hasattr(model, 'global_pool') and hasattr(model, 'fc'):  # most CNN models, like Resnet50
         model.global_pool = nn.Identity()
         del model.fc
         num_classes = model.num_classes
         num_features = model.num_features
         model.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
-    elif hasattr(model, 'head'):  # tresnet
+    elif hasattr(model, 'global_pool') and hasattr(model, 'classifier'):  # EfficientNet
+        model.global_pool = nn.Identity()
+        del model.classifier
+        num_classes = model.num_classes
+        num_features = model.num_features
+        model.classifier = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
+    elif 'RegNet' in model._get_name() or 'TResNet' in model._get_name():  # hasattr(model, 'head')
         del model.head
         num_classes = model.num_classes
         num_features = model.num_features
         model.head = MLDecoder(num_classes=num_classes, initial_num_features=num_features)
     else:
-        print("model is not suited for ml-decoder")
+        print("Model code-writing is not aligned currently with ml-decoder")
         exit(-1)
-
+    if hasattr(model, 'drop_rate'):  # Ml-Decoder has inner dropout
+        model.drop_rate = 0
     return model
 
 

From d98aa47d12d27d941fe019fb1ab9b52cde670056 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 21 Mar 2022 12:29:02 -0700
Subject: [PATCH 4/4] Revert ml-decoder changes to model factory and train script

---
 timm/models/factory.py | 5 -----
 train.py               | 4 +---
 2 files changed, 1 insertion(+), 8 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index 40453380..d040a9ff 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -29,7 +29,6 @@ def create_model(
         scriptable=None,
         exportable=None,
         no_jit=None,
-        use_ml_decoder_head=False,
         **kwargs):
     """Create a model
 
@@ -81,10 +80,6 @@ def create_model(
     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
         model = create_fn(pretrained=pretrained, **kwargs)
 
-    if use_ml_decoder_head:
-        from timm.models.layers.ml_decoder import add_ml_decoder_head
-        model = add_ml_decoder_head(model)
-
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
 
diff --git a/train.py b/train.py
index 42985e12..10d839be 100755
--- a/train.py
+++ b/train.py
@@ -115,7 +115,6 @@ parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                     help='input batch size for training (default: 128)')
 parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                     help='validation batch size override (default: None)')
-parser.add_argument('--use-ml-decoder-head', type=int, default=0)
 
 # Optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -380,8 +379,7 @@ def main():
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
-        checkpoint_path=args.initial_checkpoint,
-        use_ml_decoder_head=args.use_ml_decoder_head)
+        checkpoint_path=args.initial_checkpoint)
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
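
Usage note (editor's addition, not part of the patch series): patch 4/4 removes the `use_ml_decoder_head` hook from the model factory and the `--use-ml-decoder-head` flag from the train script, but the `timm/models/layers/ml_decoder.py` module added in patches 1/4-3/4 is left in place. A minimal sketch of attaching the head manually after `create_model`, assuming patches 1-3 are applied on top of a timm checkout; the model name (`resnet50`) and class count (80) are illustrative only:

    import torch
    import timm
    from timm.models.layers.ml_decoder import add_ml_decoder_head

    # resnet50 exposes global_pool + fc, so add_ml_decoder_head() swaps
    # global_pool for nn.Identity and fc for an MLDecoder head.
    model = timm.create_model('resnet50', pretrained=False, num_classes=80)
    model = add_ml_decoder_head(model)
    model.eval()

    # MLDecoder.forward accepts the unpooled [bs, 2048, 7, 7] feature map.
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: torch.Size([2, 80])

This mirrors what the reverted factory hook did internally (`model = add_ml_decoder_head(model)` after `create_fn(...)`), just invoked from user code instead of `create_model`.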