From 9a53c3f727c37a7e45a5394fc2db5ea851c2cdf8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 27 Jan 2023 13:54:04 -0800 Subject: [PATCH] Finalize DaViT, some formatting and modelling simplifications (separate PatchEmbed to Stem + Downsample, weights on HF hub. --- tests/test_models.py | 1 - timm/models/davit.py | 555 ++++++++++++++++++++++--------------------- 2 files changed, 279 insertions(+), 277 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index eb470d5f..fdededc7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -29,7 +29,6 @@ NON_STD_FILTERS = [ 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', 'eva_*', 'flexivit*' ] -#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', ' NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/davit.py b/timm/models/davit.py index f57cc5ae..8b9e67b4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -11,9 +11,10 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # Copyright (c) 2022 Mingyu Ding # All rights reserved. # This source code is licensed under the MIT license - -from collections import OrderedDict import itertools +from collections import OrderedDict +from functools import partial +from typing import Tuple import torch import torch.nn as nn @@ -21,9 +22,8 @@ import torch.nn.functional as F from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp # ClassifierHead +from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer from ._builder import build_model_with_cfg -from ._features import FeatureInfo from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq from ._pretrained import generate_default_cfgs @@ -33,89 +33,83 @@ __all__ = ['DaViT'] class ConvPosEnc(nn.Module): - def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): - + def __init__(self, dim: int, k: int = 3, act: bool = False): super(ConvPosEnc, self).__init__() - self.proj = nn.Conv2d(dim, - dim, - to_2tuple(k), - to_2tuple(1), - to_2tuple(k // 2), - groups=dim) - self.normtype = normtype - self.norm = nn.Identity() - if self.normtype == 'batch': - self.norm = nn.BatchNorm2d(dim) - elif self.normtype == 'layer': - self.norm = nn.LayerNorm(dim) - self.activation = nn.GELU() if act else nn.Identity() - - def forward(self, x : Tensor): - B, C, H, W = x.shape - #feat = x.transpose(1, 2).view(B, C, H, W) + self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim) + self.act = nn.GELU() if act else nn.Identity() + + def forward(self, x: Tensor): feat = self.proj(x) - if self.normtype == 'batch': - feat = self.norm(feat).flatten(2).transpose(1, 2) - elif self.normtype == 'layer': - feat = self.norm(feat.flatten(2).transpose(1, 2)) - else: - feat = feat.flatten(2).transpose(1, 2) - x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W) + x = x + self.act(feat) return x -class PatchEmbed(nn.Module): +class Stem(nn.Module): """ Size-agnostic implementation of 2D image to patch embedding, allowing input size to be adjusted during model forward operation """ def __init__( self, - patch_size=4, - in_chans=3, - embed_dim=96, - overlapped=False): + in_chs=3, + out_chs=96, + stride=4, + norm_layer=LayerNorm2d, + ): super().__init__() - patch_size = 
to_2tuple(patch_size) - self.patch_size = patch_size - self.in_chans = in_chans - self.embed_dim = embed_dim - - if patch_size[0] == 4: - self.proj = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=(7, 7), - stride=patch_size, - padding=(3, 3)) - self.norm = nn.LayerNorm(embed_dim) - if patch_size[0] == 2: - kernel = 3 if overlapped else 2 - pad = 1 if overlapped else 0 - self.proj = nn.Conv2d( - in_chans, - embed_dim, - kernel_size=to_2tuple(kernel), - stride=patch_size, - padding=to_2tuple(pad)) - self.norm = nn.LayerNorm(in_chans) - - - def forward(self, x : Tensor): + stride = to_2tuple(stride) + self.stride = stride + self.in_chs = in_chs + self.out_chs = out_chs + assert stride[0] == 4 # only setup for stride==4 + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=7, + stride=stride, + padding=3, + ) + self.norm = norm_layer(out_chs) + + def forward(self, x: Tensor): B, C, H, W = x.shape - if self.norm.normalized_shape[0] == self.in_chans: - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - - x = F.pad(x, (0, (self.patch_size[1] - W % self.patch_size[1]) % self.patch_size[1])) - x = F.pad(x, (0, 0, 0, (self.patch_size[0] - H % self.patch_size[0]) % self.patch_size[0])) + x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1])) + x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0])) + x = self.conv(x) + x = self.norm(x) + return x - x = self.proj(x) - if self.norm.normalized_shape[0] == self.embed_dim: - x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) +class Downsample(nn.Module): + def __init__( + self, + in_chs, + out_chs, + norm_layer=LayerNorm2d, + ): + super().__init__() + self.in_chs = in_chs + self.out_chs = out_chs + + self.norm = norm_layer(in_chs) + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=2, + stride=2, + padding=0, + ) + + def forward(self, x: Tensor): + B, C, H, W = x.shape + x = self.norm(x) + x = F.pad(x, (0, (2 - W % 2) % 2)) + x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2)) + x = self.conv(x) return x - + + class ChannelAttention(nn.Module): def __init__(self, dim, num_heads=8, qkv_bias=False): @@ -127,11 +121,11 @@ class ChannelAttention(nn.Module): self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.proj = nn.Linear(dim, dim) - def forward(self, x : Tensor): + def forward(self, x: Tensor): B, N, C = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] + q, k, v = qkv.unbind(0) k = k * self.scale attention = k.transpose(-1, -2) @ v @@ -140,50 +134,64 @@ class ChannelAttention(nn.Module): x = x.transpose(1, 2).reshape(B, N, C) x = self.proj(x) return x - + class ChannelBlock(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - ffn=True, cpe_act=False): + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ffn=True, + cpe_act=False, + ): super().__init__() self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) self.ffn = ffn self.norm1 = norm_layer(dim) self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) - + if self.ffn: self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp( in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer) - - - def forward(self, x : Tensor): + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + else: + self.norm2 = None + self.mlp = None + self.drop_path2 = None + def forward(self, x: Tensor): B, C, H, W = x.shape x = self.cpe1(x).flatten(2).transpose(1, 2) - + cur = self.norm1(x) cur = self.attn(cur) - x = x + self.drop_path(cur) + x = x + self.drop_path1(cur) + + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)) + + if self.mlp is not None: + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).view(B, C, H, W) - x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2) - if self.ffn: - x = x + self.drop_path(self.mlp(self.norm2(x))) - - x = x.transpose(1, 2).view(B, C, H, W) - return x -def window_partition(x : Tensor, window_size: int): + +def window_partition(x: Tensor, window_size: Tuple[int, int]): """ Args: x: (B, H, W, C) @@ -192,12 +200,13 @@ def window_partition(x : Tensor, window_size: int): windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) return windows -@register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows : Tensor, window_size: int, H: int, W: int): + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int): """ Args: windows: (num_windows*B, window_size, window_size, C) @@ -207,9 +216,8 @@ def window_reverse(windows : Tensor, window_size: int, H: int, W: int): Returns: x: (B, H, W, C) """ - - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -225,7 +233,6 @@ class WindowAttention(nn.Module): """ def __init__(self, dim, window_size, num_heads, qkv_bias=True): - super().__init__() self.dim = dim self.window_size = window_size @@ -238,11 +245,11 @@ class WindowAttention(nn.Module): self.softmax = nn.Softmax(dim=-1) - def forward(self, x : Tensor): + def forward(self, x: Tensor): B_, N, C = x.shape - + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] + q, k, v = qkv.unbind(0) q = q * self.scale attn = (q @ k.transpose(-2, -1)) @@ -266,108 +273,119 @@ class SpatialBlock(nn.Module): norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm
     """
 
-    def __init__(self, dim, num_heads, window_size=7,
-                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
-                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
-                 ffn=True, cpe_act=False):
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            window_size=7,
+            mlp_ratio=4.,
+            qkv_bias=True,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            ffn=True,
+            cpe_act=False,
+    ):
         super().__init__()
         self.dim = dim
         self.ffn = ffn
         self.num_heads = num_heads
-        self.window_size = window_size
+        self.window_size = to_2tuple(window_size)
         self.mlp_ratio = mlp_ratio
-
+
         self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
         self.norm1 = norm_layer(dim)
         self.attn = WindowAttention(
             dim,
-            window_size=to_2tuple(self.window_size),
+            self.window_size,
             num_heads=num_heads,
-            qkv_bias=qkv_bias)
+            qkv_bias=qkv_bias,
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
-        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
-
         if self.ffn:
             self.norm2 = norm_layer(dim)
             mlp_hidden_dim = int(dim * mlp_ratio)
             self.mlp = Mlp(
                 in_features=dim,
                 hidden_features=mlp_hidden_dim,
-                act_layer=act_layer)
-
+                act_layer=act_layer,
+            )
+            self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        else:
+            self.norm2 = None
+            self.mlp = None
+            self.drop_path2 = None
 
-    def forward(self, x : Tensor):
+    def forward(self, x: Tensor):
         B, C, H, W = x.shape
-
         shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
+
         x = self.norm1(shortcut)
 
         x = x.view(B, H, W, C)
 
         pad_l = pad_t = 0
-        pad_r = (self.window_size - W % self.window_size) % self.window_size
-        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+        pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
         x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
         _, Hp, Wp, _ = x.shape
 
         x_windows = window_partition(x, self.window_size)
-        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
 
         # W-MSA/SW-MSA
         attn_windows = self.attn(x_windows)
 
         # merge windows
-        attn_windows = attn_windows.view(-1,
-                                         self.window_size,
-                                         self.window_size,
-                                         C)
+        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
 
-        #if pad_r > 0 or pad_b > 0:
+        # if pad_r > 0 or pad_b > 0:
         x = x[:, :H, :W, :].contiguous()
 
         x = x.view(B, H * W, C)
-        x = shortcut + self.drop_path(x)
+        x = shortcut + self.drop_path1(x)
+
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
+
+        if self.mlp is not None:
+            x = x.flatten(2).transpose(1, 2)
+            x = x + self.drop_path2(self.mlp(self.norm2(x)))
+            x = x.transpose(1, 2).view(B, C, H, W)
 
-        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
-        if self.ffn:
-            x = x + self.drop_path(self.mlp(self.norm2(x)))
-
-        x = x.transpose(1, 2).view(B, C, H, W)
-
         return x
-
+
 
 class DaViTStage(nn.Module):
     def __init__(
-        self,
-        in_chs,
-        out_chs,
-        depth = 1,
-        patch_size = 4,
-        overlapped_patch = False,
-        attention_types = ('spatial', 'channel'),
-        num_heads = 3,
-        window_size = 7,
-        mlp_ratio = 4,
-        qkv_bias = True,
-        drop_path_rates = (0, 0),
-        norm_layer = nn.LayerNorm,
-        ffn = True,
-        cpe_act = False
+            self,
+            in_chs,
+            out_chs,
+            depth=1,
+            downsample=True,
+            attn_types=('spatial', 'channel'),
+            num_heads=3,
+            window_size=7,
+            mlp_ratio=4,
+            qkv_bias=True,
+            
drop_path_rates=(0, 0), + norm_layer=LayerNorm2d, + norm_layer_cl=nn.LayerNorm, + ffn=True, + cpe_act=False ): super().__init__() self.grad_checkpointing = False - - # patch embedding layer at the beginning of each stage - self.patch_embed = PatchEmbed( - patch_size=patch_size, - in_chans=in_chs, - embed_dim=out_chs, - overlapped=overlapped_patch - ) + + # downsample embedding layer at the beginning of each stage + if downsample: + self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer) + else: + self.downsample = nn.Identity() + ''' repeating alternating attention blocks in each stage default: (spatial -> channel) x depth @@ -377,44 +395,40 @@ class DaViTStage(nn.Module): ''' stage_blocks = [] for block_idx in range(depth): - dual_attention_block = [] - - for attention_id, attention_type in enumerate(attention_types): - if attention_type == 'spatial': + for attn_idx, attn_type in enumerate(attn_types): + if attn_type == 'spatial': dual_attention_block.append(SpatialBlock( dim=out_chs, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=norm_layer, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, ffn=ffn, cpe_act=cpe_act, window_size=window_size, )) - elif attention_type == 'channel': + elif attn_type == 'channel': dual_attention_block.append(ChannelBlock( dim=out_chs, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, - drop_path=drop_path_rates[len(attention_types) * block_idx + attention_id], - norm_layer=norm_layer, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, ffn=ffn, cpe_act=cpe_act )) - stage_blocks.append(nn.Sequential(*dual_attention_block)) - self.blocks = nn.Sequential(*stage_blocks) - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable - - def forward(self, x : Tensor): - x = self.patch_embed(x) + + def forward(self, x: Tensor): + x = self.downsample(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -431,7 +445,6 @@ class DaViT(nn.Module): in_chans (int): Number of input image channels. Default: 3 num_classes (int): Number of classes for classification head. Default: 1000 depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1) - patch_size (int | tuple(int)): Patch size. Default: 4 embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768) num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24) window_size (int): Window size. 
Default: 7 @@ -442,75 +455,67 @@ class DaViT(nn.Module): """ def __init__( - self, - in_chans=3, - depths=(1, 1, 3, 1), - patch_size=4, - embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), - window_size=7, - mlp_ratio=4., - qkv_bias=True, - drop_path_rate=0.1, - norm_layer=nn.LayerNorm, - attention_types=('spatial', 'channel'), - ffn=True, - overlapped_patch=False, - cpe_act=False, - drop_rate=0., - attn_drop_rate=0., - num_classes=1000, - global_pool='avg', - head_norm_first=False, + self, + in_chans=3, + depths=(1, 1, 3, 1), + embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4, + qkv_bias=True, + norm_layer='layernorm2d', + norm_layer_cl='layernorm', + norm_eps=1e-5, + attn_types=('spatial', 'channel'), + ffn=True, + cpe_act=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_classes=1000, + global_pool='avg', + head_norm_first=False, ): super().__init__() - - architecture = [[index] * item for index, item in enumerate(depths)] - self.architecture = architecture - self.embed_dims = embed_dims - self.num_heads = num_heads - self.num_stages = len(self.embed_dims) - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, len(attention_types) * len(list(itertools.chain(*self.architecture))))] - assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1) - + num_stages = len(embed_dims) + assert num_stages == len(num_heads) == len(depths) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) self.num_classes = num_classes self.num_features = embed_dims[-1] - self.drop_rate=drop_rate + self.drop_rate = drop_rate self.grad_checkpointing = False self.feature_info = [] - - self.patch_embed = None - stages = [] - - for stage_id in range(self.num_stages): - stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])] + self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer) + in_chs = embed_dims[0] + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + for stage_idx in range(num_stages): + out_chs = embed_dims[stage_idx] stage = DaViTStage( - in_chans if stage_id == 0 else embed_dims[stage_id - 1], - embed_dims[stage_id], - depth = depths[stage_id], - patch_size = patch_size if stage_id == 0 else 2, - overlapped_patch = overlapped_patch, - attention_types = attention_types, - num_heads = num_heads[stage_id], - window_size = window_size, - mlp_ratio = mlp_ratio, - qkv_bias = qkv_bias, - drop_path_rates = stage_drop_rates, - norm_layer = nn.LayerNorm, - ffn = ffn, - cpe_act = cpe_act + in_chs, + out_chs, + depth=depths[stage_idx], + downsample=stage_idx > 0, + attn_types=attn_types, + num_heads=num_heads[stage_idx], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rates=dpr[stage_idx], + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act, ) - - if stage_id == 0: - self.patch_embed = stage.patch_embed - stage.patch_embed = nn.Identity() - + in_chs = out_chs stages.append(stage) - self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] - + self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')] + self.stages = nn.Sequential(*stages) - + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool 
-> norm -> fc, the default DaViT order, similar to ConvNeXt # FIXME generalize this structure to ClassifierHead @@ -521,28 +526,25 @@ class DaViT(nn.Module): ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(self.drop_rate)), ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) - + self.apply(self._init_weights) - + def _init_weights(self, m): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - + @torch.jit.ignore def set_grad_checkpointing(self, enable=True): self.grad_checkpointing = enable for stage in self.stages: stage.set_grad_checkpointing(enable=enable) - + @torch.jit.ignore def get_classifier(self): return self.head.fc - + def reset_classifier(self, num_classes, global_pool=None): if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) @@ -550,21 +552,21 @@ class DaViT(nn.Module): self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): - x = self.patch_embed(x) + x = self.stem(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.stages, x) else: x = self.stages(x) - x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.norm_pre(x) return x - + def forward_head(self, x, pre_logits: bool = False): x = self.head.global_pool(x) - x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + x = self.head.norm(x) x = self.head.flatten(x) x = self.head.drop(x) return x if pre_logits else self.head.fc(x) - + def forward(self, x): x = self.forward_features(x) x = self.forward_head(x) @@ -573,29 +575,28 @@ class DaViT(nn.Module): def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """ - if 'head' in state_dict: + if 'head.fc.weight' in state_dict: return state_dict # non-MSFT checkpoint - + if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] import re out_dict = {} for k, v in state_dict.items(): - - k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k) + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k) k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) - k = k.replace('stages.0.patch_embed', 'patch_embed') + k = k.replace('downsample.proj', 'downsample.conv') + k = k.replace('stages.0.downsample', 'stem') k = k.replace('head.', 'head.fc.') k = k.replace('norms.', 'head.norm.') k = k.replace('cpe.0', 'cpe1') k = k.replace('cpe.1', 'cpe2') out_dict[k] = v return out_dict - -def _create_davit(variant, pretrained=False, **kwargs): +def _create_davit(variant, pretrained=False, **kwargs): default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) @@ -608,69 +609,71 @@ def _create_davit(variant, pretrained=False, **kwargs): **kwargs) return model - - + def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.850, 'interpolation': 'bicubic', + 'crop_pct': 0.95, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head.fc', + 'first_conv': 'stem.conv', 'classifier': 'head.fc', **kwargs } - # TODO contact authors to get larger pretrained models 
default_cfgs = generate_default_cfgs({ # official microsoft weights from https://github.com/dingmyu/davit 'davit_tiny.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"), + hf_hub_id='timm/'), 'davit_small.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"), + hf_hub_id='timm/'), 'davit_base.msft_in1k': _cfg( - url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"), + hf_hub_id='timm/'), 'davit_large': _cfg(), 'davit_huge': _cfg(), 'davit_giant': _cfg(), }) - @register_model def davit_tiny(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), **kwargs) + model_kwargs = dict( + depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs) return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs) - + + @register_model def davit_small(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), - num_heads=(3, 6, 12, 24), **kwargs) + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs) return _create_davit('davit_small', pretrained=pretrained, **model_kwargs) - + + @register_model def davit_base(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), - num_heads=(4, 8, 16, 32), **kwargs) + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs) return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) + @register_model def davit_large(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), - num_heads=(6, 12, 24, 48), **kwargs) + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs) return _create_davit('davit_large', pretrained=pretrained, **model_kwargs) - + + @register_model def davit_huge(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), - num_heads=(8, 16, 32, 64), **kwargs) + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs) return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs) - + + @register_model def davit_giant(pretrained=False, **kwargs): - model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), - num_heads=(12, 24, 48, 96), **kwargs) + model_kwargs = dict( + depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs) return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
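
Usage sketch, not part of the patch: a minimal example of how the refactored models are expected to be created through the existing timm factory API. It assumes this patch is applied, timm is installed with the registered davit_* entrypoints, and the 'timm/' HF hub weights referenced in default_cfgs above are reachable (set pretrained=False to skip the download); the shape comments follow from the stem stride of 4, the three stride-2 Downsample stages, and embed_dims=(96, 192, 384, 768) for davit_tiny.

import timm
import torch

# Classification model; 'davit_tiny' resolves to the 'davit_tiny.msft_in1k' cfg above.
model = timm.create_model('davit_tiny', pretrained=True)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)                  # (1, 1000) class logits
    feats = model.forward_features(x)  # (1, 768, 7, 7) NCHW map after norm_pre

# Feature-map extraction via the per-stage feature_info entries registered in DaViT.
backbone = timm.create_model('davit_tiny', pretrained=False, features_only=True)
with torch.no_grad():
    fmaps = backbone(x)
print([f.shape for f in fmaps])  # four maps with 96, 192, 384, 768 channels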