Davit std (#3)

Davit with all features working
pull/1583/head
Fredo Guan 1 year ago committed by GitHub
parent 434a03937d
commit edea013dd1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -27,7 +27,9 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
]
NUM_NON_STD = len(NON_STD_FILTERS)

@ -7,47 +7,34 @@ attention in each block. The attention mechanisms used are linear in complexity.
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from typing import Tuple
from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
from collections import OrderedDict
from torch import Tensor
import torch.utils.checkpoint as checkpoint
from .features import FeatureInfo
from .fx_features import register_notrace_function, register_notrace_module
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
from .pretrained import generate_default_cfgs
from .registry import register_model
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
__all__ = ['DaViT']
class MySequential(nn.Sequential):
def forward(self, *inputs):
for module in self._modules.values():
if type(inputs) == tuple:
inputs = module(*inputs)
else:
inputs = module(inputs)
return inputs
class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False):
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim,
dim,
@ -56,16 +43,16 @@ class ConvPosEnc(nn.Module):
to_2tuple(k // 2),
groups=dim)
self.normtype = normtype
self.norm = nn.Identity()
if self.normtype == 'batch':
self.norm = nn.BatchNorm2d(dim)
elif self.normtype == 'layer':
self.norm = nn.LayerNorm(dim)
self.activation = nn.GELU() if act else nn.Identity()
def forward(self, x, size: Tuple[int, int]):
def forward(self, x : Tensor, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
assert N == H * W
feat = x.transpose(1, 2).view(B, C, H, W)
feat = self.proj(feat)
@ -77,8 +64,11 @@ class ConvPosEnc(nn.Module):
feat = feat.flatten(2).transpose(1, 2)
x = x + self.activation(feat)
return x
# reason: dim in control sequence
# FIXME reimplement to allow tracing
@register_notrace_module
class PatchEmbed(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
@ -113,9 +103,10 @@ class PatchEmbed(nn.Module):
padding=to_2tuple(pad))
self.norm = nn.LayerNorm(in_chans)
def forward(self, x, size):
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
dim = len(x.shape)
dim = x.dim()
if dim == 3:
B, HW, C = x.shape
x = self.norm(x)
@ -149,7 +140,7 @@ class ChannelAttention(nn.Module):
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x):
def forward(self, x : Tensor):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@ -186,7 +177,8 @@ class ChannelBlock(nn.Module):
hidden_features=mlp_hidden_dim,
act_layer=act_layer)
def forward(self, x, size):
def forward(self, x : Tensor, size: Tuple[int, int]):
x = self.cpe[0](x, size)
cur = self.norm1(x)
cur = self.attn(cur)
@ -198,7 +190,7 @@ class ChannelBlock(nn.Module):
return x, size
def window_partition(x, window_size: int):
def window_partition(x : Tensor, window_size: int):
"""
Args:
x: (B, H, W, C)
@ -211,8 +203,8 @@ def window_partition(x, window_size: int):
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
return windows
def window_reverse(windows, window_size: int, H: int, W: int):
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
@ -222,6 +214,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
@ -252,7 +245,7 @@ class WindowAttention(nn.Module):
self.softmax = nn.Softmax(dim=-1)
def forward(self, x):
def forward(self, x : Tensor):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@ -310,10 +303,11 @@ class SpatialBlock(nn.Module):
hidden_features=mlp_hidden_dim,
act_layer=act_layer)
def forward(self, x, size):
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
shortcut = self.cpe[0](x, size)
x = self.norm1(shortcut)
@ -338,8 +332,8 @@ class SpatialBlock(nn.Module):
C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
#if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path(x)
@ -352,12 +346,17 @@ class SpatialBlock(nn.Module):
class DaViT(nn.Module):
r""" Dual Attention Transformer
r""" DaViT
A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
Supports arbitrary input sizes and pyramid feature extraction
Args:
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
patch_size (int | tuple(int)): Patch size. Default: 4
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
@ -383,7 +382,6 @@ class DaViT(nn.Module):
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
img_size=224,
num_classes=1000,
global_pool='avg'
):
@ -401,7 +399,7 @@ class DaViT(nn.Module):
self.num_features = embed_dims[-1]
self.drop_rate=drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.patch_embeds = nn.ModuleList([
PatchEmbed(patch_size=patch_size if i == 0 else 2,
@ -410,12 +408,12 @@ class DaViT(nn.Module):
overlapped=overlapped_patch)
for i in range(self.num_stages)])
main_blocks = []
for block_id, block_param in enumerate(self.architecture):
layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
self.stages = nn.ModuleList()
for stage_id, stage_param in enumerate(self.architecture):
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
block = nn.ModuleList([
MySequential(*[
stage = nn.ModuleList([
nn.ModuleList([
ChannelBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],
@ -438,27 +436,17 @@ class DaViT(nn.Module):
window_size=window_size,
) if attention_type == 'spatial' else None
for attention_id, attention_type in enumerate(attention_types)]
) for layer_id, item in enumerate(block_param)
) for layer_id, item in enumerate(stage_param)
])
main_blocks.append(block)
self.main_blocks = nn.ModuleList(main_blocks)
'''
# layer norms for pyramid feature extraction
#
# TODO implement pyramid feature extraction
#
# davit should be a good transformer candidate, since the only official implementation
# is for segmentation and detection
for i_layer in range(self.num_stages):
layer = norm_layer(self.embed_dims[i_layer])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
'''
self.stages.add_module(f'stage_{stage_id}', stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
@ -467,9 +455,7 @@ class DaViT(nn.Module):
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@ -485,55 +471,67 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features_full(self, x):
x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
def forward_network(self, x):
size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x]
sizes = [size]
branches = [0]
for block_index, block_param in enumerate(self.architecture):
branch_ids = sorted(set(block_param))
for branch_id in branch_ids:
if branch_id not in branches:
x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
features.append(x)
sizes.append(size)
branches.append(branch_id)
for layer_index, branch_id in enumerate(block_param):
if self.grad_checkpointing and not torch.jit.is_scripting():
features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
else:
features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
'''
# pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
for patch_layer, stage in zip(self.patch_embeds, self.stages):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
for _, block in enumerate(stage):
for _, layer in enumerate(block):
if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
else:
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
# don't append outputs of last stage, since they are already there
if(len(features) < self.num_stages):
features.append(features[-1])
sizes.append(sizes[-1])
# non-normalized pyramid features + corresponding sizes
return features, sizes
def forward_pyramid_features(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x)
outs = []
for i in range(self.num_stages):
norm_layer = getattr(self, f'norm{i}')
x_out = norm_layer(features[i])
for i, out in enumerate(x):
H, W = sizes[i]
out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out)
'''
# non-normalized pyramid features + corresponding sizes
return tuple(features), tuple(sizes)
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs
def forward_features(self, x):
x, sizes = self.forward_features_full(x)
x, sizes = self.forward_network(x)
# take final feature and norm
x = self.norms(x[-1])
H, W = sizes[-1]
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
#print(x.shape)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
def forward_classifier(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def forward(self, x):
return self.forward_classifier(x)
class DaViTFeatures(DaViT):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(x)
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
@ -542,11 +540,10 @@ def checkpoint_filter_fn(state_dict, model):
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('main_blocks.', 'stages.stage_')
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict
@ -554,8 +551,25 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(DaViT, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
model_cls = DaViT
features_only = False
kwargs_filter = None
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
if kwargs.pop('features_only', False):
model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
model = build_model_with_cfg(
model_cls,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
if features_only:
model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
model.default_cfg = model.pretrained_cfg # backwards compat
return model
@ -573,13 +587,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
default_cfgs = generate_default_cfgs({
'davit_tiny.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
'davit_small.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
'davit_base.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
'davit_small.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
'davit_base.msft_in1k': _cfg(
url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
})

Loading…
Cancel
Save