diff --git a/timm/models/davit.py b/timm/models/davit.py index 7834546c..e0bb0aa0 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -382,8 +382,8 @@ class DaViT(nn.Module): attn_drop_rate=0., img_size=224, num_classes=1000, - global_pool='avg', - features_only = False + global_pool='avg'#, + #features_only = False ): super().__init__() @@ -399,7 +399,8 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - self._features_only = False + self._features_only = features_only + self.feature_info = [] self.patch_embeds = nn.ModuleList([ PatchEmbed(patch_size=patch_size if i == 0 else 2, @@ -438,7 +439,10 @@ class DaViT(nn.Module): for attention_id, attention_type in enumerate(attention_types)] ) for layer_id, item in enumerate(block_param) ]) - main_blocks.append(block) + + main_blocks.append((f'block.{block_id}', block)) + + self.feature_info += [dict(num_ch=self.embed_dims[block_id], reduction = 2, module=f'block.{block_id}')] self.main_blocks = nn.ModuleList(main_blocks) ''' @@ -457,7 +461,7 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) - self._update_forward_fn() + #self._update_forward_fn() #self.forward = self._get_forward_fn() ''' @@ -482,6 +486,7 @@ class DaViT(nn.Module): else: return self.forward_classification ''' + @torch.jit.ignore def _update_forward_fn(self): if self._features_only == True: