diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 301186dd..b1f82789 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -8,6 +8,7 @@ from .convmixer import *
 from .convnext import *
 from .crossvit import *
 from .cspnet import *
+from .davit import *
 from .deit import *
 from .densenet import *
 from .dla import *
diff --git a/timm/models/davit.py b/timm/models/davit.py
new file mode 100644
index 00000000..444f21f3
--- /dev/null
+++ b/timm/models/davit.py
@@ -0,0 +1,624 @@
+""" DaViT: Dual Attention Vision Transformers
+
+As described in https://arxiv.org/abs/2204.03645
+
+Input size invariant transformer architecture that combines channel and spatial
+attention in each block. The attention mechanisms used are linear in complexity.
+
+DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
+"""
+# Copyright (c) 2022 Mingyu Ding
+# All rights reserved.
+# This source code is licensed under the MIT license
+
+import itertools
+from collections import OrderedDict
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from .helpers import build_model_with_cfg
+from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
+from .pretrained import generate_default_cfgs
+from .registry import register_model
+
+
+__all__ = ['DaViT']
+
+
+class MySequential(nn.Sequential):
+    """ Sequential container that unpacks tuple outputs into the next module's inputs. """
+
+    def forward(self, *inputs):
+        for module in self._modules.values():
+            if isinstance(inputs, tuple):
+                inputs = module(*inputs)
+            else:
+                inputs = module(inputs)
+        return inputs
+
+
+class ConvPosEnc(nn.Module):
+    """ Depthwise convolutional position encoding applied to a flattened (B, N, C) token sequence. """
+
+    def __init__(self, dim, k=3, act=False, normtype=False):
+        super(ConvPosEnc, self).__init__()
+        self.proj = nn.Conv2d(dim,
+                              dim,
+                              to_2tuple(k),
+                              to_2tuple(1),
+                              to_2tuple(k // 2),
+                              groups=dim)
+        self.normtype = normtype
+        if self.normtype == 'batch':
+            self.norm = nn.BatchNorm2d(dim)
+        elif self.normtype == 'layer':
+            self.norm = nn.LayerNorm(dim)
+        self.activation = nn.GELU() if act else nn.Identity()
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+        H, W = size
+        assert N == H * W
+
+        feat = x.transpose(1, 2).view(B, C, H, W)
+        feat = self.proj(feat)
+        if self.normtype == 'batch':
+            feat = self.norm(feat).flatten(2).transpose(1, 2)
+        elif self.normtype == 'layer':
+            feat = self.norm(feat.flatten(2).transpose(1, 2))
+        else:
+            feat = feat.flatten(2).transpose(1, 2)
+        x = x + self.activation(feat)
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """ Size-agnostic implementation of 2D image to patch embedding,
+        allowing input size to be adjusted during model forward operation
+    """
+
+    def __init__(
+            self,
+            patch_size=16,
+            in_chans=3,
+            embed_dim=96,
+            overlapped=False):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+
+        if patch_size[0] == 4:
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=(7, 7),
+                stride=patch_size,
+                padding=(3, 3))
+            self.norm = nn.LayerNorm(embed_dim)
+        if patch_size[0] == 2:
+            kernel = 3 if overlapped else 2
+            pad = 1 if overlapped else 0
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=to_2tuple(kernel),
+                stride=patch_size,
+                padding=to_2tuple(pad))
+            self.norm = nn.LayerNorm(in_chans)
+
+    def forward(self, x, size):
+        H, W = size
+        dim = len(x.shape)
+        if dim == 3:
+            B, HW, C = x.shape
+            x = self.norm(x)
+            x = 
x.reshape(B, + H, + W, + C).permute(0, 3, 1, 2).contiguous() + + B, C, H, W = x.shape + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) + newsize = (x.size(2), x.size(3)) + x = x.flatten(2).transpose(1, 2) + if dim == 4: + x = self.norm(x) + return x, newsize + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + k = k * self.scale + attention = k.transpose(-1, -2) @ v + attention = attention.softmax(dim=-1) + x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class ChannelBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + ffn=True, cpe_act=False): + super().__init__() + + self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), + ConvPosEnc(dim=dim, k=3, act=cpe_act)]) + self.ffn = ffn + self.norm1 = norm_layer(dim) + self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if self.ffn: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer) + + def forward(self, x, size): + x = self.cpe[0](x, size) + cur = self.norm1(x) + cur = self.attn(cur) + x = x + self.drop_path(cur) + + x = self.cpe[1](x, size) + if self.ffn: + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x, size + + +def window_partition(x, window_size: int): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size: int, H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True): + + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x): + B_, N, C = x.shape + + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x + + +class SpatialBlock(nn.Module): + r""" Windows Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=7, + mlp_ratio=4., qkv_bias=True, drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, + ffn=True, cpe_act=False): + super().__init__() + self.dim = dim + self.ffn = ffn + self.num_heads = num_heads + self.window_size = window_size + self.mlp_ratio = mlp_ratio + self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act), + ConvPosEnc(dim=dim, k=3, act=cpe_act)]) + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if self.ffn: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer) + + def forward(self, x, size): + H, W = size + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = self.cpe[0](x, size) + x = self.norm1(shortcut) + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x_windows = window_partition(x, self.window_size) + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) + + # merge windows + attn_windows = attn_windows.view(-1, + self.window_size, + self.window_size, + C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(x) + + x = self.cpe[1](x, size) + if self.ffn: + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x, size + + + +class DaViT(nn.Module): + r""" Dual Attention Transformer + Args: + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256) + num_heads (tuple(int)): Number of attention heads in different layers. 
Default: (4, 8, 12, 16) + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__( + self, + in_chans=3, + depths=(1, 1, 3, 1), + patch_size=4, + embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + attention_types=('spatial', 'channel'), + ffn=True, + overlapped_patch=False, + cpe_act=False, + drop_rate=0., + attn_drop_rate=0., + img_size=224, + num_classes=1000, + global_pool='avg' + ): + super().__init__() + + architecture = [[index] * item for index, item in enumerate(depths)] + self.architecture = architecture + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_stages = len(self.embed_dims) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, 2 * len(list(itertools.chain(*self.architecture))))] + assert self.num_stages == len(self.num_heads) == (sorted(list(itertools.chain(*self.architecture)))[-1] + 1) + + self.num_classes = num_classes + self.num_features = embed_dims[-1] + self.drop_rate=drop_rate + self.grad_checkpointing = False + + + self.patch_embeds = nn.ModuleList([ + PatchEmbed(patch_size=patch_size if i == 0 else 2, + in_chans=in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + overlapped=overlapped_patch) + for i in range(self.num_stages)]) + + main_blocks = [] + for block_id, block_param in enumerate(self.architecture): + layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id]))) + + block = nn.ModuleList([ + MySequential(*[ + ChannelBlock( + dim=self.embed_dims[item], + num_heads=self.num_heads[item], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act + ) if attention_type == 'channel' else + SpatialBlock( + dim=self.embed_dims[item], + num_heads=self.num_heads[item], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act, + window_size=window_size, + ) if attention_type == 'spatial' else None + for attention_id, attention_type in enumerate(attention_types)] + ) for layer_id, item in enumerate(block_param) + ]) + main_blocks.append(block) + self.main_blocks = nn.ModuleList(main_blocks) + + ''' + # layer norms for pyramid feature extraction + # + # TODO implement pyramid feature extraction + # + # davit should be a good transformer candidate, since the only official implementation + # is for segmentation and detection + for i_layer in range(self.num_stages): + layer = norm_layer(self.embed_dims[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + ''' + self.norms = norm_layer(self.num_features) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + + + @torch.jit.ignore + def 
set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + + def forward_features_full(self, x): + x, size = self.patch_embeds[0](x, (x.size(2), x.size(3))) + features = [x] + sizes = [size] + branches = [0] + + for block_index, block_param in enumerate(self.architecture): + branch_ids = sorted(set(block_param)) + for branch_id in branch_ids: + if branch_id not in branches: + x, size = self.patch_embeds[branch_id](features[-1], sizes[-1]) + features.append(x) + sizes.append(size) + branches.append(branch_id) + for layer_index, branch_id in enumerate(block_param): + if self.grad_checkpointing and not torch.jit.is_scripting(): + features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id]) + else: + features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id]) + ''' + # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model + outs = [] + for i in range(self.num_stages): + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(features[i]) + H, W = sizes[i] + out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous() + outs.append(out) + ''' + # non-normalized pyramid features + corresponding sizes + return tuple(features), tuple(sizes) + + def forward_features(self, x): + x, sizes = self.forward_features_full(x) + # take final feature and norm + x = self.norms(x[-1]) + H, W = sizes[-1] + x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() + #print(x.shape) + return x + + def forward_head(self, x, pre_logits: bool = False): + + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + +def checkpoint_filter_fn(state_dict, model): + """ Remap MSFT checkpoints -> timm """ + if 'head.norm.weight' in state_dict: + return state_dict # non-MSFT checkpoint + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + + k = k.replace('head.', 'head.fc.') + out_dict[k] = v + return out_dict + + + +def _create_davit(variant, pretrained=False, **kwargs): + model = build_model_with_cfg(DaViT, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) + return model + + + +def _cfg(url='', **kwargs): # not sure how this should be set up + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bilinear', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head.fc', + **kwargs + } + + + +default_cfgs = generate_default_cfgs({ + +'davit_tiny.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"), +'davit_small.msft_in1k': _cfg( + url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"), +'davit_base.msft_in1k': _cfg( + 
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"), +}) + + + +@register_model +def davit_tiny(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs) + +@register_model +def davit_small(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_small', pretrained=pretrained, **model_kwargs) + +@register_model +def davit_base(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), + num_heads=(4, 8, 16, 32), **kwargs) + return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) + +''' models without weights +# TODO contact authors to get larger pretrained models +@register_model +def davit_large(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), + num_heads=(6, 12, 24, 48), **kwargs) + return _create_davit('davit_large', pretrained=pretrained, **model_kwargs) + +@register_model +def davit_huge(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), + num_heads=(8, 16, 32, 64), **kwargs) + return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs) + +@register_model +def davit_giant(pretrained=False, **kwargs): + model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), + num_heads=(12, 24, 48, 96), **kwargs) + return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs) +''' \ No newline at end of file