Davit std (#6)

Separate patch_embed module
pull/1583/head
Fredo Guan authored this commit 2 years ago; committed by GitHub
parent 546590c5f5
commit 10b3f696b4

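In short, this change hoists the first stage's patch embedding out of DaViTStage and onto the DaViT module itself, leaving an nn.Identity() behind so the stage still composes; forward_features then runs the stem once before the stage stack. A minimal stand-alone sketch of the pattern (TinyStage/TinyModel are made-up names, not the real timm classes):

import torch
import torch.nn as nn

class TinyStage(nn.Module):
    # stand-in for DaViTStage: each stage owns a downsampling patch embed
    def __init__(self, in_chs, out_chs):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2)
        self.blocks = nn.Sequential(nn.Conv2d(out_chs, out_chs, 3, padding=1), nn.GELU())

    def forward(self, x):
        return self.blocks(self.patch_embed(x))

class TinyModel(nn.Module):
    # stand-in for DaViT: stage 0's patch embed is hoisted up to the model
    def __init__(self, dims=(32, 64)):
        super().__init__()
        self.patch_embed = None
        stages = []
        in_chs = 3
        for i, out_chs in enumerate(dims):
            stage = TinyStage(in_chs, out_chs)
            if i == 0:
                self.patch_embed = stage.patch_embed  # take ownership of the stem
                stage.patch_embed = nn.Identity()     # stage keeps a placeholder
            stages.append(stage)
            in_chs = out_chs
        self.stages = nn.Sequential(*stages)

    def forward(self, x):
        x = self.patch_embed(x)  # stem runs once, outside the stage stack
        return self.stages(x)

print(TinyModel()(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 64, 8, 8])

With this layout the stem's state_dict keys start with 'patch_embed.' rather than 'stages.0.patch_embed.', which is why checkpoint_filter_fn and the first_conv cfg entry are updated further down.
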
@@ -12,25 +12,21 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
# FIXME remove unused imports
import itertools
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import torch.utils.checkpoint as checkpoint
from .features import FeatureInfo
from .fx_features import register_notrace_function, register_notrace_module
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
from .pretrained import generate_default_cfgs
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, ClassifierHead, Mlp
from ._builder import build_model_with_cfg
from ._features import FeatureInfo
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['DaViT']
@@ -391,7 +387,7 @@ class DaViTStage(nn.Module):
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
- norm_layer=nn.LayerNorm,
+ norm_layer=norm_layer,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
@@ -403,7 +399,7 @@
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
- norm_layer=nn.LayerNorm,
+ norm_layer=norm_layer,
ffn=ffn,
cpe_act=cpe_act
))
@@ -476,7 +472,8 @@ class DaViT(nn.Module):
self.drop_rate=drop_rate
self.grad_checkpointing = False
self.feature_info = []
+ self.patch_embed = None
stages = []
for stage_id in range(self.num_stages):
@@ -499,6 +496,10 @@
cpe_act = cpe_act
)
+ if stage_id == 0:
+     self.patch_embed = stage.patch_embed
+     stage.patch_embed = nn.Identity()
stages.append(stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
@@ -533,6 +534,7 @@
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
+ x = self.patch_embed(x)
x = self.stages(x)
# take final feature and norm
x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
@@ -562,8 +564,10 @@ def checkpoint_filter_fn(state_dict, model):
import re
out_dict = {}
for k, v in state_dict.items():
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
+ k = k.replace('stages.0.patch_embed', 'patch_embed')
k = k.replace('head.', 'head.fc.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
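As a quick illustration of what these substitutions do end to end, a throwaway snippet (the checkpoint key names below are invented for the example, not taken from a real DaViT state dict):

import re

keys = [
    'patch_embeds.0.proj.weight',         # stem -> top-level patch_embed
    'patch_embeds.2.proj.weight',         # later downsamples stay inside their stage
    'main_blocks.1.0.0.attn.qkv.weight',  # main_blocks.N -> stages.N.blocks
    'head.weight',                        # head -> head.fc
]
for k in keys:
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('stages.0.patch_embed', 'patch_embed')
    k = k.replace('head.', 'head.fc.')
    print(k)
# patch_embed.proj.weight
# stages.2.patch_embed.proj.weight
# stages.1.blocks.0.0.attn.qkv.weight
# head.fc.weight
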
@@ -596,12 +600,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
- 'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs
}
+ # TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
@@ -635,8 +640,6 @@ def davit_base(pretrained=False, **kwargs):
num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
- # TODO contact authors to get larger pretrained models
@register_model
def davit_large(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536),

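For completeness, a hedged usage sketch against this branch (assumes pull/1583/head is installed as timm; no pretrained download, just shape checks):

import timm
import torch

model = timm.create_model('davit_base', pretrained=False)  # registered in this file via @register_model
model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(x)  # patch_embed now runs here, before self.stages
    out = model(x)
print(feats.shape)  # expected (1, 1024, 7, 7) for davit_base, matching pool_size=(7, 7) above
print(out.shape)    # torch.Size([1, 1000])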