diff --git a/tests/test_models.py b/tests/test_models.py
index 87d75cbd..97872fde 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -27,7 +27,9 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
+    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
+]
 NUM_NON_STD = len(NON_STD_FILTERS)
diff --git a/timm/models/davit.py b/timm/models/davit.py
index 444f21f3..eda928e4 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -7,47 +7,34 @@
 attention in each block. The attention mechanisms used are linear in complexity.

 DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
-
-
-
-
 """
 # Copyright (c) 2022 Mingyu Ding
 # All rights reserved.
 # This source code is licensed under the MIT license

 import itertools
-from typing import Tuple
+from typing import List, Tuple
+from collections import OrderedDict

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .helpers import build_model_with_cfg
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
-from collections import OrderedDict
+from torch import Tensor
 import torch.utils.checkpoint as checkpoint
+
+from .features import FeatureInfo
+from .fx_features import register_notrace_function, register_notrace_module
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
+from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
 from .pretrained import generate_default_cfgs
 from .registry import register_model
-
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

 __all__ = ['DaViT']

-
-
-
-class MySequential(nn.Sequential):
-    def forward(self, *inputs):
-        for module in self._modules.values():
-            if type(inputs) == tuple:
-                inputs = module(*inputs)
-            else:
-                inputs = module(inputs)
-        return inputs
-
-
 class ConvPosEnc(nn.Module):
-    def __init__(self, dim, k=3, act=False, normtype=False):
+    def __init__(self, dim: int, k: int = 3, act: bool = False, normtype: str = 'none'):
+        super().__init__()
         self.proj = nn.Conv2d(dim,
                               dim,
@@ -56,16 +43,16 @@ class ConvPosEnc(nn.Module):
                               to_2tuple(k // 2),
                               groups=dim)
         self.normtype = normtype
+        self.norm = nn.Identity()
         if self.normtype == 'batch':
             self.norm = nn.BatchNorm2d(dim)
         elif self.normtype == 'layer':
             self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()

-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        assert N == H * W

         feat = x.transpose(1, 2).view(B, C, H, W)
         feat = self.proj(feat)
@@ -77,8 +64,11 @@ class ConvPosEnc(nn.Module):
         feat = feat.flatten(2).transpose(1, 2)
         x = x + self.activation(feat)
         return x
+
-
+# reason: tensor dim is used in control flow
+# FIXME reimplement to allow tracing
+@register_notrace_module
 class PatchEmbed(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
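
Note: with the assert dropped and `self.norm` defaulting to `nn.Identity()`, `ConvPosEnc` above is a purely shape-preserving residual module. A minimal sanity-check sketch, assuming the class as defined in this diff (the `dim`/`size` values are illustrative):

    import torch

    cpe = ConvPosEnc(dim=96, k=3)            # depthwise 3x3 conv positional encoding
    tokens = torch.randn(2, 56 * 56, 96)     # (B, N, C) with N = H * W
    out = cpe(tokens, size=(56, 56))
    assert out.shape == tokens.shape         # residual add keeps the token shape
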
@@ -113,9 +103,10 @@ class PatchEmbed(nn.Module):
             padding=to_2tuple(pad))
         self.norm = nn.LayerNorm(in_chans)

-    def forward(self, x, size):
+
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
-        dim = len(x.shape)
+        dim = x.dim()
         if dim == 3:
             B, HW, C = x.shape
             x = self.norm(x)
@@ -149,7 +140,7 @@ class ChannelAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)

-    def forward(self, x):
+    def forward(self, x: Tensor):
         B, N, C = x.shape

         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -186,7 +177,8 @@ class ChannelBlock(nn.Module):
                        hidden_features=mlp_hidden_dim,
                        act_layer=act_layer)

-    def forward(self, x, size):
+
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -198,7 +190,7 @@ class ChannelBlock(nn.Module):
         return x, size


-def window_partition(x, window_size: int):
+def window_partition(x: Tensor, window_size: int):
     """
     Args:
         x: (B, H, W, C)
@@ -211,8 +203,8 @@ def window_partition(x, window_size: int):
     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
     return windows

-
-def window_reverse(windows, window_size: int, H: int, W: int):
+@register_notrace_function  # reason: int argument is a Proxy
+def window_reverse(windows: Tensor, window_size: int, H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -222,6 +214,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
     Returns:
         x: (B, H, W, C)
     """
+
     B = int(windows.shape[0] / (H * W / window_size / window_size))
     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
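
Note: `window_partition` and `window_reverse` above are exact inverses whenever `H` and `W` are multiples of `window_size` (the padding in `SpatialBlock` below guarantees this); `window_reverse` needs `register_notrace_function` because it does Python int arithmetic on `H`/`W`, which become Proxies under FX tracing. A quick round-trip check, assuming both functions as defined above:

    import torch

    B, H, W, C, ws = 2, 14, 14, 32, 7
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, ws)        # (B * (H // ws) * (W // ws), ws, ws, C)
    assert windows.shape == (8, 7, 7, 32)
    assert torch.equal(window_reverse(windows, ws, H, W), x)
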
@@ -252,7 +245,7 @@ class WindowAttention(nn.Module):

         self.softmax = nn.Softmax(dim=-1)

-    def forward(self, x):
+    def forward(self, x: Tensor):
         B_, N, C = x.shape

         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -310,10 +303,11 @@ class SpatialBlock(nn.Module):
                        hidden_features=mlp_hidden_dim,
                        act_layer=act_layer)

-    def forward(self, x, size):
+
+    def forward(self, x: Tensor, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
-        assert L == H * W, "input feature has wrong size"

         shortcut = self.cpe[0](x, size)
         x = self.norm1(shortcut)
@@ -338,8 +332,8 @@ class SpatialBlock(nn.Module):
                                    C)
         x = window_reverse(attn_windows, self.window_size, Hp, Wp)

-        if pad_r > 0 or pad_b > 0:
-            x = x[:, :H, :W, :].contiguous()
+        # slice back to input resolution; unconditional so FX tracing doesn't branch on pad sizes
+        x = x[:, :H, :W, :].contiguous()

         x = x.view(B, H * W, C)
         x = shortcut + self.drop_path(x)
@@ -352,12 +346,17 @@ class SpatialBlock(nn.Module):


 class DaViT(nn.Module):
-    r""" Dual Attention Transformer
+    r""" DaViT
+        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
+        Supports arbitrary input sizes and pyramid feature extraction
+
     Args:
-        patch_size (int | tuple(int)): Patch size. Default: 4
         in_chans (int): Number of input image channels. Default: 3
-        embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
-        num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
+        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
         window_size (int): Window size. Default: 7
         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
         qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
@@ -383,7 +382,6 @@ class DaViT(nn.Module):
             cpe_act=False,
             drop_rate=0.,
             attn_drop_rate=0.,
-            img_size=224,
             num_classes=1000,
             global_pool='avg'
     ):
@@ -401,7 +399,7 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
-
+        self.feature_info = []

         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
@@ -410,12 +408,12 @@ class DaViT(nn.Module):
             overlapped=overlapped_patch)
             for i in range(self.num_stages)])

-        main_blocks = []
-        for block_id, block_param in enumerate(self.architecture):
-            layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
+        self.stages = nn.ModuleList()
+        for stage_id, stage_param in enumerate(self.architecture):
+            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))

-            block = nn.ModuleList([
-                MySequential(*[
+            stage = nn.ModuleList([
+                nn.ModuleList([
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
@@ -438,27 +436,17 @@ class DaViT(nn.Module):
                         window_size=window_size,
                     ) if attention_type == 'spatial' else None
                     for attention_id, attention_type in enumerate(attention_types)]
-                ) for layer_id, item in enumerate(block_param)
+                ) for layer_id, item in enumerate(stage_param)
             ])
-            main_blocks.append(block)
-        self.main_blocks = nn.ModuleList(main_blocks)
-
-        '''
-        # layer norms for pyramid feature extraction
-        #
-        # TODO implement pyramid feature extraction
-        #
-        # davit should be a good transformer candidate, since the only official implementation
-        # is for segmentation and detection
-        for i_layer in range(self.num_stages):
-            layer = norm_layer(self.embed_dims[i_layer])
-            layer_name = f'norm{i_layer}'
-            self.add_module(layer_name, layer)
-        '''
+
+            self.stages.add_module(f'stage_{stage_id}', stage)
+            # reduction is the total stride vs the input: patch_size at stage 0, then x2 per stage
+            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=patch_size * 2 ** stage_id, module=f'stages.stage_{stage_id}')]
+
         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         self.apply(self._init_weights)
-
+
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
@@ -467,9 +455,7 @@ class DaViT(nn.Module):
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
-
-
-
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
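
Note: with checkpointing toggled by `set_grad_checkpointing` above, each block in `forward_network` below is recomputed during backward instead of storing its activations. A hedged usage sketch (assumes this branch of timm, with the `davit_tiny` variant registered at the bottom of this file):

    import timm
    import torch

    model = timm.create_model('davit_tiny')
    model.set_grad_checkpointing(True)       # trade activation memory for recompute
    loss = model(torch.randn(2, 3, 224, 224)).sum()
    loss.backward()
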
@@ -485,55 +471,67 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    def forward_features_full(self, x):
-        x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
+    def forward_network(self, x):
+        size: Tuple[int, int] = (x.size(2), x.size(3))
         features = [x]
         sizes = [size]
-        branches = [0]
-
-        for block_index, block_param in enumerate(self.architecture):
-            branch_ids = sorted(set(block_param))
-            for branch_id in branch_ids:
-                if branch_id not in branches:
-                    x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
-                    features.append(x)
-                    sizes.append(size)
-                    branches.append(branch_id)
-            for layer_index, branch_id in enumerate(block_param):
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
-                else:
-                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
-        '''
-        # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
+
+        for patch_layer, stage in zip(self.patch_embeds, self.stages):
+            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
+            for block in stage:
+                for layer in block:
+                    if self.grad_checkpointing and not torch.jit.is_scripting():
+                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
+                    else:
+                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])
+
+            # don't append outputs of last stage, since they are already there
+            if len(features) < self.num_stages:
+                features.append(features[-1])
+                sizes.append(sizes[-1])
+
+        # non-normalized pyramid features + corresponding sizes
+        return features, sizes
+
+    def forward_pyramid_features(self, x) -> List[Tensor]:
+        x, sizes = self.forward_network(x)
         outs = []
-        for i in range(self.num_stages):
-            norm_layer = getattr(self, f'norm{i}')
-            x_out = norm_layer(features[i])
+        for i, out in enumerate(x):
             H, W = sizes[i]
-            out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
-            outs.append(out)
-        '''
-        # non-normalized pyramid features + corresponding sizes
-        return tuple(features), tuple(sizes)
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+        return outs

     def forward_features(self, x):
-        x, sizes = self.forward_features_full(x)
+        x, sizes = self.forward_network(x)
         # take final feature and norm
         x = self.norms(x[-1])
         H, W = sizes[-1]
         x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
-        #print(x.shape)
         return x

     def forward_head(self, x, pre_logits: bool = False):
-
         return self.head(x, pre_logits=pre_logits)

-    def forward(self, x):
+    def forward_classifier(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
+
+    def forward(self, x):
+        return self.forward_classifier(x)
+
+
+class DaViTFeatures(DaViT):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
+
+    def forward(self, x) -> List[Tensor]:
+        return self.forward_pyramid_features(x)
+

 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
@@ -542,11 +540,10 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
-
+
     out_dict = {}
-    import re
     for k, v in state_dict.items():
-
+        k = k.replace('main_blocks.', 'stages.stage_')
         k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
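
Note: a toy illustration of the key remapping `checkpoint_filter_fn` above performs (hypothetical key names; the `model` argument is unused by this logic):

    sd = {'main_blocks.0.0.0.attn.qkv.weight': 0, 'head.weight': 1}
    out = checkpoint_filter_fn(sd, model=None)
    # {'stages.stage_0.0.0.attn.qkv.weight': 0, 'head.fc.weight': 1}
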
@@ -554,8 +551,25 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_davit(variant, pretrained=False, **kwargs):
-    model = build_model_with_cfg(DaViT, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
+    model_cls = DaViT
+    features_only = False
+    kwargs_filter = None
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+    if kwargs.pop('features_only', False):
+        model_cls = DaViTFeatures
+        kwargs_filter = ('num_classes', 'global_pool')
+        features_only = True
+    model = build_model_with_cfg(
+        model_cls,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        kwargs_filter=kwargs_filter,
+        **kwargs)
+    if features_only:
+        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
+        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
@@ -573,13 +587,13 @@ def _cfg(url='', **kwargs):


 # not sure how this should be set up
 default_cfgs = generate_default_cfgs({
-
-'davit_tiny.msft_in1k': _cfg(
-    url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
-'davit_small.msft_in1k': _cfg(
-    url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
-'davit_base.msft_in1k': _cfg(
-    url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
+    # official microsoft weights from https://github.com/dingmyu/davit
+    'davit_tiny.msft_in1k': _cfg(
+        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
+    'davit_small.msft_in1k': _cfg(
+        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
+    'davit_base.msft_in1k': _cfg(
+        url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
 })
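
Note: taken together, the `DaViTFeatures` subclass and the `features_only` branch in `_create_davit` expose the usual timm pyramid-features interface. A usage sketch (output shapes assume `davit_tiny`, a 224x224 input, and the default `out_indices`):

    import timm
    import torch

    backbone = timm.create_model('davit_tiny', features_only=True)
    feats = backbone(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]
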