|
|
@ -22,12 +22,17 @@ from typing import Tuple
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import torch.nn.functional as F
|
|
|
|
import timm
|
|
|
|
from .helpers import build_model_with_cfg
|
|
|
|
from timm.models.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
|
|
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections import OrderedDict
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
|
|
|
from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -553,6 +558,31 @@ class DaViT(nn.Module):
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
|
|
|
'url': url,
|
|
|
|
|
|
|
|
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
|
|
|
|
|
|
|
'crop_pct': 0.875, 'interpolation': 'bilinear',
|
|
|
|
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
|
|
|
|
|
|
|
'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
|
|
|
|
|
|
|
|
**kwargs
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = generate_default_cfgs({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'davit_tiny.msft_in1k': _cfg(
|
|
|
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"),
|
|
|
|
|
|
|
|
'davit_small.msft_in1k': _cfg(
|
|
|
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"),
|
|
|
|
|
|
|
|
'davit_base.msft_in1k': _cfg(
|
|
|
|
|
|
|
|
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"),
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def davit_tiny(pretrained=False, **kwargs):
|
|
|
|
def davit_tiny(pretrained=False, **kwargs):
|
|
|
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|
model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),
|
|
|
|