From c43340ddd4e45e728781f8c06d9ebd589894e62c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sun, 11 Dec 2022 03:03:22 -0800 Subject: [PATCH] Davit std (#5) * Update davit.py * Update test_models.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * starting point * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update test_models.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Davit revised (#4) * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py clean up * Update test_models.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update davit.py * Update test_models.py * Update davit.py --- tests/test_models.py | 4 +- timm/models/davit.py | 340 +++++++++++++++++++++++-------------------- 2 files changed, 181 insertions(+), 163 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 97872fde..008d87b7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,7 +40,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', - 'swin*giant*'] + 'swin*giant*', 'davit*giant', 'davit*huge'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] @@ -271,7 +271,7 @@ if 'GITHUB_ACTIONS' not in os.environ: EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point 'vit_large_*', 'vit_huge_*', 'vit_gi*', ] diff --git a/timm/models/davit.py b/timm/models/davit.py index eda928e4..e551cc61 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -12,8 +12,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # All rights reserved. # This source code is licensed under the MIT license +# FIXME remove unused imports + import itertools -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List +from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, overload, Tuple, TypeVar, Union from collections import OrderedDict import torch @@ -32,6 +34,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] + class ConvPosEnc(nn.Module): def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): @@ -50,25 +53,21 @@ class ConvPosEnc(nn.Module): self.norm = nn.LayerNorm(dim) self.activation = nn.GELU() if act else nn.Identity() - def forward(self, x : Tensor, size: Tuple[int, int]): - B, N, C = x.shape - H, W = size + def forward(self, x : Tensor): + B, C, H, W = x.shape - feat = x.transpose(1, 2).view(B, C, H, W) - feat = self.proj(feat) + #feat = x.transpose(1, 2).view(B, C, H, W) + feat = self.proj(x) if self.normtype == 'batch': feat = self.norm(feat).flatten(2).transpose(1, 2) elif self.normtype == 'layer': feat = self.norm(feat.flatten(2).transpose(1, 2)) else: feat = feat.flatten(2).transpose(1, 2) - x = x + self.activation(feat) + x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W) return x - -# reason: dim in control sequence -# FIXME reimplement to allow tracing -@register_notrace_module + class PatchEmbed(nn.Module): """ Size-agnostic implementation of 2D image to patch embedding, allowing input size to be adjusted during model forward operation @@ -76,13 +75,15 @@ class PatchEmbed(nn.Module): def __init__( self, - patch_size=16, + patch_size=4, in_chans=3, embed_dim=96, overlapped=False): super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size + self.in_chans = in_chans + self.embed_dim = embed_dim if patch_size[0] == 4: self.proj = nn.Conv2d( @@ -104,31 +105,20 @@ class PatchEmbed(nn.Module): self.norm = nn.LayerNorm(in_chans) - def forward(self, x : Tensor, size: Tuple[int, int]): - H, W = size - dim = x.dim() - if dim == 3: - B, HW, C = x.shape - x = self.norm(x) - x = x.reshape(B, - H, - W, - C).permute(0, 3, 1, 2).contiguous() - + def forward(self, x : Tensor): B, C, H, W = x.shape - if W % self.patch_size[1] != 0: - x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) - if H % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + if self.norm.normalized_shape[0] == self.in_chans: + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + + x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1])) + x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0])) x = self.proj(x) - newsize = (x.size(2), x.size(3)) - x = x.flatten(2).transpose(1, 2) - if dim == 4: - x = self.norm(x) - return x, newsize - + if self.norm.normalized_shape[0] == self.embed_dim: + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + return x + class ChannelAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False): @@ -153,7 +143,7 @@ class ChannelAttention(nn.Module): x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x - + class ChannelBlock(nn.Module): @@ -162,13 +152,13 @@ class ChannelBlock(nn.Module): ffn=True, cpe_act=False): super().__init__() - self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), - ConvPosEnc(dim=dim, k=3, act=cpe_act)]) + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) self.ffn = ffn self.norm1 = norm_layer(dim) self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -178,17 +168,23 @@ class ChannelBlock(nn.Module): act_layer=act_layer) - def forward(self, x : Tensor, size: Tuple[int, int]): - x = self.cpe[0](x, size) + def forward(self, x : Tensor): + + B, C, H, W = x.shape + + x = self.cpe1(x).flatten(2).transpose(1, 2) + cur = self.norm1(x) cur = self.attn(cur) x = x + self.drop_path(cur) - x = self.cpe[1](x, size) + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) - return x, size - + + x = x.transpose(1, 2).view(B, C, H, W) + + return x def window_partition(x : Tensor, window_size: int): """ @@ -283,9 +279,8 @@ class SpatialBlock(nn.Module): self.num_heads = num_heads self.window_size = window_size self.mlp_ratio = mlp_ratio - self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), - ConvPosEnc(dim=dim, k=3, act=cpe_act)]) - + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) self.norm1 = norm_layer(dim) self.attn = WindowAttention( dim, @@ -294,7 +289,8 @@ class SpatialBlock(nn.Module): qkv_bias=qkv_bias) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() - + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -304,12 +300,11 @@ class SpatialBlock(nn.Module): act_layer=act_layer) - def forward(self, x : Tensor, size: Tuple[int, int]): + def forward(self, x : Tensor): + B, C, H, W = x.shape - H, W = size - B, L, C = x.shape - shortcut = self.cpe[0](x, size) + shortcut = self.cpe1(x).flatten(2).transpose(1, 2) x = self.norm1(shortcut) x = x.view(B, H, W, C) @@ -338,11 +333,92 @@ class SpatialBlock(nn.Module): x = x.view(B, H * W, C) x = shortcut + self.drop_path(x) - x = self.cpe[1](x, size) + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2) if self.ffn: x = x + self.drop_path(self.mlp(self.norm2(x))) - return x, size + + x = x.transpose(1, 2).view(B, C, H, W) + + return x + +class DaViTStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth = 1, + patch_size = 4, + overlapped_patch = False, + attention_types = ('spatial', 'channel'), + num_heads = 3, + window_size = 7, + mlp_ratio = 4, + qkv_bias = True, + drop_path_rates = (0, 0), + norm_layer = nn.LayerNorm, + ffn = True, + cpe_act = False + ): + super().__init__() + + self.grad_checkpointing = False + + # patch embedding layer at the beginning of each stage + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chs, + embed_dim=out_chs, + overlapped=overlapped_patch + ) + ''' + repeating alternating attention blocks in each stage + default: (spatial -> channel) x depth + + potential opportunity to integrate with a more general version of ByobNet/ByoaNet + since the logic is similar + ''' + stage_blocks = [] + for block_idx in range(depth): + + dual_attention_block = [] + + for attention_id, attention_type in enumerate(attention_types): + if attention_type == 'spatial': + dual_attention_block.append(SpatialBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act, + window_size=window_size, + )) + elif attention_type == 'channel': + dual_attention_block.append(ChannelBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act + )) + + stage_blocks.append(nn.Sequential(*dual_attention_block)) + + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x : Tensor): + x = self.patch_embed(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x class DaViT(nn.Module): @@ -392,7 +468,7 @@ class DaViT(nn.Module): self.embed_dims = embed_dims self.num_heads = num_heads self.num_stages = len(self.embed_dims) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))] + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))] assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1) self.num_classes = num_classes @@ -401,51 +477,37 @@ class DaViT(nn.Module): self.grad_checkpointing = False self.feature_info = [] - self.patch_embeds = nn.ModuleList([ - PatchEmbed(patch_size=patch_size if i == 0 else 2, - in_chans=in_chans if i == 0 else self.embed_dims[i - 1], - embed_dim=self.embed_dims[i], - overlapped=overlapped_patch) - for i in range(self.num_stages)]) - - self.stages = nn.ModuleList() - for stage_id, stage_param in enumerate(self.architecture): - layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id]))) - - stage = nn.ModuleList([ - nn.ModuleList([ - ChannelBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], - norm_layer=nn.LayerNorm, - ffn=ffn, - cpe_act=cpe_act - ) if attention_type == 'channel' else - SpatialBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], - norm_layer=nn.LayerNorm, - ffn=ffn, - cpe_act=cpe_act, - window_size=window_size, - ) if attention_type == 'spatial' else None - for attention_id, attention_type in enumerate(attention_types)] - ) for layer_id, item in enumerate(stage_param) - ]) + stages = [] + + for stage_id in range(self.num_stages): + stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])] + + stage = DaViTStage( + in_chans if stage_id == 0 else embed_dims[stage_id - 1], + embed_dims[stage_id], + depth = depths[stage_id], + patch_size = patch_size if stage_id == 0 else 2, + overlapped_patch = overlapped_patch, + attention_types = attention_types, + num_heads = num_heads[stage_id], + window_size = window_size, + mlp_ratio = mlp_ratio, + qkv_bias = qkv_bias, + drop_path_rates = stage_drop_rates, + norm_layer = nn.LayerNorm, + ffn = ffn, + cpe_act = cpe_act + ) - self.stages.add_module(f'stage_{stage_id}', stage) - self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')] - + stages.append(stage) + self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] + + + self.stages = nn.Sequential(*stages) + self.norms = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) - def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -469,46 +531,13 @@ class DaViT(nn.Module): if global_pool is None: global_pool = self.head.global_pool.pool_type self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - - def forward_network(self, x): - size: Tuple[int, int] = (x.size(2), x.size(3)) - features = [x] - sizes = [size] - - for patch_layer, stage in zip(self.patch_embeds, self.stages): - features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) - for _, block in enumerate(stage): - for _, layer in enumerate(block): - if self.grad_checkpointing and not torch.jit.is_scripting(): - features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) - else: - features[-1], sizes[-1] = layer(features[-1], sizes[-1]) - - # don't append outputs of last stage, since they are already there - if(len(features) < self.num_stages): - features.append(features[-1]) - sizes.append(sizes[-1]) - - - # non-normalized pyramid features + corresponding sizes - return features, sizes - - def forward_pyramid_features(self, x) -> List[Tensor]: - x, sizes = self.forward_network(x) - outs = [] - for i, out in enumerate(x): - H, W = sizes[i] - outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) - - return outs def forward_features(self, x): - x, sizes = self.forward_network(x) + x = self.stages(x) # take final feature and norm - x = self.norms(x[-1]) - H, W = sizes[-1] - x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() + x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + #H, W = sizes[-1] + #x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() return x def forward_head(self, x, pre_logits: bool = False): @@ -521,17 +550,6 @@ class DaViT(nn.Module): def forward(self, x): return self.forward_classifier(x) - - -class DaViTFeatures(DaViT): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3))) - - def forward(self, x) -> List[Tensor]: - return self.forward_pyramid_features(x) - def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """ @@ -541,38 +559,36 @@ def checkpoint_filter_fn(state_dict, model): if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] + import re out_dict = {} for k, v in state_dict.items(): - k = k.replace('main_blocks.', 'stages.stage_') + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k) + k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) k = k.replace('head.', 'head.fc.') + k = k.replace('cpe.0', 'cpe1') + k = k.replace('cpe.1', 'cpe2') out_dict[k] = v return out_dict - - + def _create_davit(variant, pretrained=False, **kwargs): - model_cls = DaViT - features_only = False - kwargs_filter = None + + + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) - if kwargs.pop('features_only', False): - model_cls = DaViTFeatures - kwargs_filter = ('num_classes', 'global_pool') - features_only = True + model = build_model_with_cfg( - model_cls, + DaViT, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) - if features_only: - model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) - model.default_cfg = model.pretrained_cfg # backwards compat - return model - + return model + + def _cfg(url='', **kwargs): # not sure how this should be set up return { @@ -580,7 +596,7 @@ def _cfg(url='', **kwargs): # not sure how this should be set up 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bilinear', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', + 'first_conv': 'stages.0.patch_embed.proj', 'classifier': 'head.fc', **kwargs } @@ -594,6 +610,9 @@ default_cfgs = generate_default_cfgs({ url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"), 'davit_base.msft_in1k': _cfg( url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"), + 'davit_large': _cfg(), + 'davit_huge': _cfg(), + 'davit_giant': _cfg(), }) @@ -616,7 +635,7 @@ def davit_base(pretrained=False, **kwargs): num_heads=(4, 8, 16, 32), **kwargs) return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) -''' models without weights + # TODO contact authors to get larger pretrained models @register_model def davit_large(pretrained=False, **kwargs): @@ -635,4 +654,3 @@ def davit_giant(pretrained=False, **kwargs): model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs) return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs) -''' \ No newline at end of file