Davit std (#5)

* Update davit.py

* Update test_models.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* starting point

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update test_models.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Davit revised (#4)

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

clean up

* Update test_models.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update davit.py

* Update test_models.py

* Update davit.py
pull/1583/head
Fredo Guan 2 years ago committed by GitHub
parent edea013dd1
commit c43340ddd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,7 +40,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*']
'swin*giant*', 'davit*giant', 'davit*huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
@ -271,7 +271,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point
'vit_large_*', 'vit_huge_*', 'vit_gi*',
]

@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
# FIXME remove unused imports
import itertools
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union
from collections import OrderedDict
import torch
@ -32,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']
class ConvPosEnc(nn.Module):
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@ -50,25 +53,21 @@ class ConvPosEnc(nn.Module):
self.norm = nn.LayerNorm(dim)
self.activation = nn.GELU() if act else nn.Identity()
def forward(self, x : Tensor, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
def forward(self, x : Tensor):
B, C, H, W = x.shape
feat = x.transpose(1, 2).view(B, C, H, W)
feat = self.proj(feat)
#feat = x.transpose(1, 2).view(B, C, H, W)
feat = self.proj(x)
if self.normtype == 'batch':
feat = self.norm(feat).flatten(2).transpose(1, 2)
elif self.normtype == 'layer':
feat = self.norm(feat.flatten(2).transpose(1, 2))
else:
feat = feat.flatten(2).transpose(1, 2)
x = x + self.activation(feat)
x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
return x
# reason: dim in control sequence
# FIXME reimplement to allow tracing
@register_notrace_module
class PatchEmbed(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
@ -76,13 +75,15 @@ class PatchEmbed(nn.Module):
def __init__(
self,
patch_size=16,
patch_size=4,
in_chans=3,
embed_dim=96,
overlapped=False):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
if patch_size[0] == 4:
self.proj = nn.Conv2d(
@ -104,30 +105,19 @@ class PatchEmbed(nn.Module):
self.norm = nn.LayerNorm(in_chans)
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
dim = x.dim()
if dim == 3:
B, HW, C = x.shape
x = self.norm(x)
x = x.reshape(B,
H,
W,
C).permute(0, 3, 1, 2).contiguous()
def forward(self, x : Tensor):
B, C, H, W = x.shape
if W % self.patch_size[1] != 0:
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
if H % self.patch_size[0] != 0:
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
if self.norm.normalized_shape[0] == self.in_chans:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1]))
x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0]))
x = self.proj(x)
newsize = (x.size(2), x.size(3))
x = x.flatten(2).transpose(1, 2)
if dim == 4:
x = self.norm(x)
return x, newsize
if self.norm.normalized_shape[0] == self.embed_dim:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class ChannelAttention(nn.Module):
@ -162,12 +152,12 @@ class ChannelBlock(nn.Module):
ffn=True, cpe_act=False):
super().__init__()
self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
ConvPosEnc(dim=dim, k=3, act=cpe_act)])
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
@ -178,17 +168,23 @@ class ChannelBlock(nn.Module):
act_layer=act_layer)
def forward(self, x : Tensor, size: Tuple[int, int]):
x = self.cpe[0](x, size)
def forward(self, x : Tensor):
B, C, H, W = x.shape
x = self.cpe1(x).flatten(2).transpose(1, 2)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path(cur)
x = self.cpe[1](x, size)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, size
x = x.transpose(1, 2).view(B, C, H, W)
return x
def window_partition(x : Tensor, window_size: int):
"""
@ -283,9 +279,8 @@ class SpatialBlock(nn.Module):
self.num_heads = num_heads
self.window_size = window_size
self.mlp_ratio = mlp_ratio
self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),
ConvPosEnc(dim=dim, k=3, act=cpe_act)])
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
@ -294,6 +289,7 @@ class SpatialBlock(nn.Module):
qkv_bias=qkv_bias)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
@ -304,12 +300,11 @@ class SpatialBlock(nn.Module):
act_layer=act_layer)
def forward(self, x : Tensor, size: Tuple[int, int]):
def forward(self, x : Tensor):
B, C, H, W = x.shape
H, W = size
B, L, C = x.shape
shortcut = self.cpe[0](x, size)
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
@ -338,11 +333,92 @@ class SpatialBlock(nn.Module):
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
x = self.cpe[1](x, size)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
if self.ffn:
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x, size
x = x.transpose(1, 2).view(B, C, H, W)
return x
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth = 1,
patch_size = 4,
overlapped_patch = False,
attention_types = ('spatial', 'channel'),
num_heads = 3,
window_size = 7,
mlp_ratio = 4,
qkv_bias = True,
drop_path_rates = (0, 0),
norm_layer = nn.LayerNorm,
ffn = True,
cpe_act = False
):
super().__init__()
self.grad_checkpointing = False
# patch embedding layer at the beginning of each stage
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chs,
embed_dim=out_chs,
overlapped=overlapped_patch
)
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
potential opportunity to integrate with a more general version of ByobNet/ByoaNet
since the logic is similar
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attention_id, attention_type in enumerate(attention_types):
if attention_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
elif attention_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x : Tensor):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class DaViT(nn.Module):
@ -392,7 +468,7 @@ class DaViT(nn.Module):
self.embed_dims = embed_dims
self.num_heads = num_heads
self.num_stages = len(self.embed_dims)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))]
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))]
assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1)
self.num_classes = num_classes
@ -401,52 +477,38 @@ class DaViT(nn.Module):
self.grad_checkpointing = False
self.feature_info = []
self.patch_embeds = nn.ModuleList([
PatchEmbed(patch_size=patch_size if i == 0 else 2,
in_chans=in_chans if i == 0 else self.embed_dims[i - 1],
embed_dim=self.embed_dims[i],
overlapped=overlapped_patch)
for i in range(self.num_stages)])
self.stages = nn.ModuleList()
for stage_id, stage_param in enumerate(self.architecture):
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
stage = nn.ModuleList([
nn.ModuleList([
ChannelBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act
) if attention_type == 'channel' else
SpatialBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id],
norm_layer=nn.LayerNorm,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
) if attention_type == 'spatial' else None
for attention_id, attention_type in enumerate(attention_types)]
) for layer_id, item in enumerate(stage_param)
])
stages = []
self.stages.add_module(f'stage_{stage_id}', stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
for stage_id in range(self.num_stages):
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
stage = DaViTStage(
in_chans if stage_id == 0 else embed_dims[stage_id - 1],
embed_dims[stage_id],
depth = depths[stage_id],
patch_size = patch_size if stage_id == 0 else 2,
overlapped_patch = overlapped_patch,
attention_types = attention_types,
num_heads = num_heads[stage_id],
window_size = window_size,
mlp_ratio = mlp_ratio,
qkv_bias = qkv_bias,
drop_path_rates = stage_drop_rates,
norm_layer = nn.LayerNorm,
ffn = ffn,
cpe_act = cpe_act
)
stages.append(stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.stages = nn.Sequential(*stages)
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
@ -470,45 +532,12 @@ class DaViT(nn.Module):
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_network(self, x):
size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x]
sizes = [size]
for patch_layer, stage in zip(self.patch_embeds, self.stages):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
for _, block in enumerate(stage):
for _, layer in enumerate(block):
if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
else:
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
# don't append outputs of last stage, since they are already there
if(len(features) < self.num_stages):
features.append(features[-1])
sizes.append(sizes[-1])
# non-normalized pyramid features + corresponding sizes
return features, sizes
def forward_pyramid_features(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x)
outs = []
for i, out in enumerate(x):
H, W = sizes[i]
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs
def forward_features(self, x):
x, sizes = self.forward_network(x)
x = self.stages(x)
# take final feature and norm
x = self.norms(x[-1])
H, W = sizes[-1]
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#H, W = sizes[-1]
#x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x
def forward_head(self, x, pre_logits: bool = False):
@ -522,17 +551,6 @@ class DaViT(nn.Module):
def forward(self, x):
return self.forward_classifier(x)
class DaViTFeatures(DaViT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(x)
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.norm.weight' in state_dict:
@ -541,35 +559,33 @@ def checkpoint_filter_fn(state_dict, model):
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
k = k.replace('main_blocks.', 'stages.stage_')
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('head.', 'head.fc.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
return out_dict
def _create_davit(variant, pretrained=False, **kwargs):
model_cls = DaViT
features_only = False
kwargs_filter = None
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
if kwargs.pop('features_only', False):
model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
model = build_model_with_cfg(
model_cls,
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
if features_only:
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
model.default_cfg = model.pretrained_cfg # backwards compat
return model
@ -580,7 +596,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc',
'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc',
**kwargs
}
@ -594,6 +610,9 @@ default_cfgs = generate_default_cfgs({
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
'davit_base.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
})
@ -616,7 +635,7 @@ def davit_base(pretrained=False, **kwargs):
num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
''' models without weights
# TODO contact authors to get larger pretrained models
@register_model
def davit_large(pretrained=False, **kwargs):
@ -635,4 +654,3 @@ def davit_giant(pretrained=False, **kwargs):
model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
num_heads=(12, 24, 48, 96), **kwargs)
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
'''
Loading…
Cancel
Save