diff --git a/timm/models/davit.py b/timm/models/davit.py index e551cc61..0ccd2ae0 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,25 +12,21 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # All rights reserved. # This source code is licensed under the MIT license -# FIXME remove unused imports - import itertools -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union -from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F from torch import Tensor -import torch.utils.checkpoint as checkpoint - -from .features import FeatureInfo -from .fx_features import register_notrace_function, register_notrace_module -from .helpers import build_model_with_cfg, pretrained_cfg_for_features -from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp -from .pretrained import generate_default_cfgs -from .registry import register_model + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, to_2tuple, trunc_normal_, ClassifierHead, Mlp +from ._builder import build_model_with_cfg +from ._features import FeatureInfo +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model __all__ = ['DaViT'] @@ -391,7 +387,7 @@ class DaViTStage(nn.Module): mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, ffn=ffn, cpe_act=cpe_act, window_size=window_size, @@ -403,7 +399,7 @@ class DaViTStage(nn.Module): mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=nn.LayerNorm, + norm_layer=norm_layer, ffn=ffn, cpe_act=cpe_act )) @@ -476,7 +472,8 @@ class DaViT(nn.Module): self.drop_rate=drop_rate self.grad_checkpointing = False self.feature_info = [] - + + self.patch_embed = None stages = [] for stage_id in range(self.num_stages): @@ -499,6 +496,10 @@ class DaViT(nn.Module): cpe_act = cpe_act ) + if stage_id == 0: + self.patch_embed = stage.patch_embed + stage.patch_embed = nn.Identity() + stages.append(stage) self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] @@ -533,6 +534,7 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) def forward_features(self, x): + x = self.patch_embed(x) x = self.stages(x) # take final feature and norm x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) @@ -562,8 +564,10 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k) k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) + k = k.replace('stages.0.patch_embed', 'patch_embed') k = k.replace('head.', 'head.fc.') k = k.replace('cpe.0', 'cpe1') k = k.replace('cpe.1', 'cpe2') @@ -596,12 +600,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc', + 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', **kwargs } +# TODO contact authors to get larger pretrained models default_cfgs = generate_default_cfgs({ # official microsoft weights from https://github.com/dingmyu/davit 'davit_tiny.msft_in1k': _cfg( @@ -635,8 +640,6 @@ def davit_base(pretrained=False, **kwargs): num_heads=(4, 8, 16, 32), **kwargs) return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) - -# TODO contact authors to get larger pretrained models @register_model def davit_large(pretrained=False, **kwargs): model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),