From d7930c70bd443bd5e7fa43f8d22ca31b30bb8e9e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 17:00:15 -0800 Subject: [PATCH] update --- timm/models/__init__.py | 1 + timm/models/davit.py | 34 ++++++++++++++++++++++++++++++++-- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 301186dd..b1f82789 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -8,6 +8,7 @@ from .convmixer import * from .convnext import * from .crossvit import * from .cspnet import * +from .davit import * from .deit import * from .densenet import * from .dla import * diff --git a/timm/models/davit.py b/timm/models/davit.py index 95dff9b3..af70c4d4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -22,12 +22,17 @@ from typing import Tuple import torch import torch.nn as nn import torch.nn.functional as F -import timm -from timm.models.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp +from .helpers import build_model_with_cfg +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp from collections import OrderedDict import torch.utils.checkpoint as checkpoint +from .pretrained import generate_default_cfgs +from .registry import register_model +__all__ = ['DaViT'] + @@ -553,6 +558,31 @@ class DaViT(nn.Module): return model + + def _cfg(url='', **kwargs): # not sure how this should be set up + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', + **kwargs + } + + + + default_cfgs = generate_default_cfgs({ + + 'davit_tiny.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_tiny_ed28dd55.pth.tar"), + 'davit_small.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_small_d1ecf281.pth.tar"), + 'davit_base.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/untagged-b2178bcf50f43d660d99/davit_base_67d9ac26.pth.tar"), +}) + + + @register_model def davit_tiny(pretrained=False, **kwargs): model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768),