@@ -12,25 +12,21 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
 # All rights reserved.
 # This source code is licensed under the MIT license
-# FIXME remove unused imports
 import itertools
-from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
-from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-import torch.utils.checkpoint as checkpoint

-from .features import FeatureInfo
-from .fx_features import register_notrace_function, register_notrace_module
-from .helpers import build_model_with_cfg, pretrained_cfg_for_features
-from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
-from .pretrained import generate_default_cfgs
-from .registry import register_model
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, to_2tuple, trunc_normal_, ClassifierHead, Mlp
+from ._builder import build_model_with_cfg
+from ._features import FeatureInfo
+from ._features_fx import register_notrace_function
+from ._manipulate import checkpoint_seq
+from ._pretrained import generate_default_cfgs
+from ._registry import register_model

 __all__ = ['DaViT']
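Note on the import swap above: `torch.utils.checkpoint` is dropped in favour of timm's own `checkpoint_seq` helper, which walks a sequential container and checkpoints each child module. The call site is not part of this hunk, so the snippet below is only a minimal standalone sketch of what `checkpoint_seq` does, using a toy `nn.Sequential` rather than the DaViT stages:

import torch
import torch.nn as nn
from timm.models._manipulate import checkpoint_seq

# toy stand-in for a stack of blocks/stages
blocks = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
x = torch.randn(2, 16, requires_grad=True)
y = checkpoint_seq(blocks, x)   # children are re-run during backward instead of caching activations
y.sum().backward()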
@@ -391,7 +387,7 @@ class DaViTStage(nn.Module):
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
-                        norm_layer=nn.LayerNorm,
+                        norm_layer=norm_layer,
                         ffn=ffn,
                         cpe_act=cpe_act,
                         window_size=window_size,
@@ -403,7 +399,7 @@ class DaViTStage(nn.Module):
                         mlp_ratio=mlp_ratio,
                         qkv_bias=qkv_bias,
                         drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
-                        norm_layer=nn.LayerNorm,
+                        norm_layer=norm_layer,
                         ffn=ffn,
                         cpe_act=cpe_act
                     ))
@@ -476,7 +472,8 @@ class DaViT(nn.Module):
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
         self.feature_info = []

+        self.patch_embed = None
         stages = []

         for stage_id in range(self.num_stages):
@@ -499,6 +496,10 @@ class DaViT(nn.Module):
                 cpe_act = cpe_act
             )

+            if stage_id == 0:
+                self.patch_embed = stage.patch_embed
+                stage.patch_embed = nn.Identity()
+
             stages.append(stage)
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
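The new `stage_id == 0` branch hoists the first stage's patch embedding up to the model as `self.patch_embed` and leaves an `nn.Identity()` in its place, so `self.stages` still runs as one sequential container while the stem gets a stable top-level name (this is what the `first_conv` and `checkpoint_filter_fn` changes further down rely on). A minimal sketch of the pattern with toy modules, not the real DaViT classes:

import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embed = nn.Conv2d(3, 8, kernel_size=4, stride=4)  # stand-in for the real patch embed
        self.blocks = nn.Conv2d(8, 8, kernel_size=3, padding=1)      # stand-in for the attention blocks

    def forward(self, x):
        return self.blocks(self.patch_embed(x))

stage = ToyStage()
patch_embed = stage.patch_embed     # hoist the stem out of the stage...
stage.patch_embed = nn.Identity()   # ...and leave an Identity so the stage forward is unchanged

x = torch.randn(1, 3, 32, 32)
out = stage(patch_embed(x))         # equivalent to the original stage(x)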
@@ -533,6 +534,7 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

     def forward_features(self, x):
+        x = self.patch_embed(x)
         x = self.stages(x)
         # take final feature and norm
         x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -562,8 +564,10 @@ def checkpoint_filter_fn(state_dict, model):
     import re
     out_dict = {}
     for k, v in state_dict.items():
         k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
         k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
+        k = k.replace('stages.0.patch_embed', 'patch_embed')
         k = k.replace('head.', 'head.fc.')
         k = k.replace('cpe.0', 'cpe1')
         k = k.replace('cpe.1', 'cpe2')
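With the stem hoisted out of stage 0, remapped keys of the form `stages.0.patch_embed.*` would now point at an `nn.Identity`, so the added `replace` redirects them to the top-level `patch_embed`. A small standalone check of the remapping chain on illustrative key names (example keys only, not taken from a real checkpoint):

import re

def remap_key(k: str) -> str:
    # same substitutions as checkpoint_filter_fn in the hunk above
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('stages.0.patch_embed', 'patch_embed')
    k = k.replace('head.', 'head.fc.')
    k = k.replace('cpe.0', 'cpe1')
    k = k.replace('cpe.1', 'cpe2')
    return k

assert remap_key('patch_embeds.0.proj.weight') == 'patch_embed.proj.weight'
assert remap_key('patch_embeds.2.proj.weight') == 'stages.2.patch_embed.proj.weight'
assert remap_key('main_blocks.1.0.0.cpe.0.proj.weight') == 'stages.1.blocks.0.0.cpe1.proj.weight'
assert remap_key('head.weight') == 'head.fc.weight'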
@@ -596,12 +600,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }


+# TODO contact authors to get larger pretrained models
 default_cfgs = generate_default_cfgs({
     # official microsoft weights from https://github.com/dingmyu/davit
     'davit_tiny.msft_in1k': _cfg(
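`first_conv` moves from `stages.0.patch_embed.proj` to `patch_embed.proj` to follow the hoisted stem. A quick, hypothetical sanity check that the pointer still resolves on a built model (assumes this branch of timm is installed):

import timm

model = timm.create_model('davit_tiny', pretrained=False)
print(model.get_submodule('patch_embed.proj'))  # the stem conv; raises AttributeError if the pointer is stale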
@@ -635,8 +640,6 @@ def davit_base(pretrained=False, **kwargs):
                         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)


-# TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),