Davit std (#3)

Davit with all features working
pull/1583/head
Authored by Fredo Guan 2 years ago, committed by GitHub
parent 434a03937d
commit edea013dd1

@@ -27,7 +27,9 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
     'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
 ]
 NUM_NON_STD = len(NON_STD_FILTERS)

@@ -7,47 +7,34 @@ attention in each block. The attention mechanisms used are linear in complexity.
 DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
 """
 # Copyright (c) 2022 Mingyu Ding
 # All rights reserved.
 # This source code is licensed under the MIT license
 import itertools
-from typing import Tuple
+from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List
+from collections import OrderedDict
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .helpers import build_model_with_cfg
-from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
-from collections import OrderedDict
+from torch import Tensor
 import torch.utils.checkpoint as checkpoint
+
+from .features import FeatureInfo
+from .fx_features import register_notrace_function, register_notrace_module
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
+from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
 from .pretrained import generate_default_cfgs
 from .registry import register_model
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 
 __all__ = ['DaViT']
 
 
-class MySequential(nn.Sequential):
-    def forward(self, *inputs):
-        for module in self._modules.values():
-            if type(inputs) == tuple:
-                inputs = module(*inputs)
-            else:
-                inputs = module(inputs)
-        return inputs
-
-
 class ConvPosEnc(nn.Module):
-    def __init__(self, dim, k=3, act=False, normtype=False):
+    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
         super(ConvPosEnc, self).__init__()
         self.proj = nn.Conv2d(dim,
                               dim,
@@ -56,16 +43,16 @@ class ConvPosEnc(nn.Module):
                               to_2tuple(k // 2),
                               groups=dim)
         self.normtype = normtype
+        self.norm = nn.Identity()
         if self.normtype == 'batch':
             self.norm = nn.BatchNorm2d(dim)
         elif self.normtype == 'layer':
             self.norm = nn.LayerNorm(dim)
         self.activation = nn.GELU() if act else nn.Identity()
 
-    def forward(self, x, size: Tuple[int, int]):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         B, N, C = x.shape
         H, W = size
-        assert N == H * W
 
         feat = x.transpose(1, 2).view(B, C, H, W)
         feat = self.proj(feat)
@@ -77,8 +64,11 @@ class ConvPosEnc(nn.Module):
         feat = feat.flatten(2).transpose(1, 2)
         x = x + self.activation(feat)
         return x
 
 
+# reason: dim in control sequence
+# FIXME reimplement to allow tracing
+@register_notrace_module
 class PatchEmbed(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
@@ -113,9 +103,10 @@ class PatchEmbed(nn.Module):
             padding=to_2tuple(pad))
         self.norm = nn.LayerNorm(in_chans)
 
-    def forward(self, x, size):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
-        dim = len(x.shape)
+        dim = x.dim()
         if dim == 3:
             B, HW, C = x.shape
             x = self.norm(x)
@@ -149,7 +140,7 @@ class ChannelAttention(nn.Module):
         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
         self.proj = nn.Linear(dim, dim)
 
-    def forward(self, x):
+    def forward(self, x : Tensor):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -186,7 +177,8 @@ class ChannelBlock(nn.Module):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer)
 
-    def forward(self, x, size):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -198,7 +190,7 @@ class ChannelBlock(nn.Module):
         return x, size
 
 
-def window_partition(x, window_size: int):
+def window_partition(x : Tensor, window_size: int):
     """
     Args:
         x: (B, H, W, C)
@@ -211,8 +203,8 @@ def window_partition(x, window_size: int):
     windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
     return windows
 
 
-def window_reverse(windows, window_size: int, H: int, W: int):
+@register_notrace_function  # reason: int argument is a Proxy
+def window_reverse(windows : Tensor, window_size: int, H: int, W: int):
     """
     Args:
         windows: (num_windows*B, window_size, window_size, C)
@@ -222,6 +214,7 @@ def window_reverse(windows, window_size: int, H: int, W: int):
     Returns:
         x: (B, H, W, C)
     """
     B = int(windows.shape[0] / (H * W / window_size / window_size))
     x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
     x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
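For reference, a minimal round-trip sketch of the two window helpers above (a sketch only; it assumes window_partition and window_reverse are imported from this module and follows the shapes given in the docstrings):

    import torch
    from timm.models.davit import window_partition, window_reverse

    x = torch.randn(2, 14, 14, 96)                # (B, H, W, C), H and W divisible by window_size
    windows = window_partition(x, window_size=7)  # (num_windows * B, 7, 7, 96) -> (8, 7, 7, 96)
    y = window_reverse(windows, 7, 14, 14)        # back to (2, 14, 14, 96)
    assert torch.equal(x, y)                      # partition followed by reverse is lossless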
@@ -252,7 +245,7 @@ class WindowAttention(nn.Module):
         self.softmax = nn.Softmax(dim=-1)
 
-    def forward(self, x):
+    def forward(self, x : Tensor):
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
@@ -310,10 +303,11 @@ class SpatialBlock(nn.Module):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer)
 
-    def forward(self, x, size):
+    def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
+        assert L == H * W, "input feature has wrong size"
 
         shortcut = self.cpe[0](x, size)
         x = self.norm1(shortcut)
@@ -338,8 +332,8 @@ class SpatialBlock(nn.Module):
             C)
         x = window_reverse(attn_windows, self.window_size, Hp, Wp)
 
-        if pad_r > 0 or pad_b > 0:
-            x = x[:, :H, :W, :].contiguous()
+        #if pad_r > 0 or pad_b > 0:
+        x = x[:, :H, :W, :].contiguous()
         x = x.view(B, H * W, C)
         x = shortcut + self.drop_path(x)
@@ -352,12 +346,17 @@ class SpatialBlock(nn.Module):
 
 class DaViT(nn.Module):
-    r""" Dual Attention Transformer
+    r""" DaViT
+        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
+        Supports arbitrary input sizes and pyramid feature extraction
 
     Args:
-        patch_size (int | tuple(int)): Patch size. Default: 4
         in_chans (int): Number of input image channels. Default: 3
-        embed_dims (tuple(int)): Patch embedding dimension. Default: (64, 128, 192, 256)
-        num_heads (tuple(int)): Number of attention heads in different layers. Default: (4, 8, 12, 16)
+        num_classes (int): Number of classes for classification head. Default: 1000
+        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
+        patch_size (int | tuple(int)): Patch size. Default: 4
+        embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
+        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
         window_size (int): Window size. Default: 7
         mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
         qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
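As a quick sanity check of the documented defaults, a sketch only (it assumes the constructor defaults match the docstring above and that a 224x224 input is used):

    import torch
    from timm.models.davit import DaViT

    model = DaViT()                                # davit_tiny-like defaults per the docstring
    logits = model(torch.randn(1, 3, 224, 224))    # classification head output
    print(logits.shape)                            # torch.Size([1, 1000])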
@@ -383,7 +382,6 @@ class DaViT(nn.Module):
             cpe_act=False,
             drop_rate=0.,
             attn_drop_rate=0.,
-            img_size=224,
             num_classes=1000,
             global_pool='avg'
     ):
@@ -401,7 +399,7 @@ class DaViT(nn.Module):
         self.num_features = embed_dims[-1]
         self.drop_rate=drop_rate
         self.grad_checkpointing = False
+        self.feature_info = []
 
         self.patch_embeds = nn.ModuleList([
             PatchEmbed(patch_size=patch_size if i == 0 else 2,
@@ -410,12 +408,12 @@ class DaViT(nn.Module):
                        overlapped=overlapped_patch)
             for i in range(self.num_stages)])
 
-        main_blocks = []
-        for block_id, block_param in enumerate(self.architecture):
-            layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
+        self.stages = nn.ModuleList()
+        for stage_id, stage_param in enumerate(self.architecture):
+            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
 
-            block = nn.ModuleList([
-                MySequential(*[
+            stage = nn.ModuleList([
+                nn.ModuleList([
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
@@ -438,27 +436,17 @@ class DaViT(nn.Module):
                         window_size=window_size,
                     ) if attention_type == 'spatial' else None
                     for attention_id, attention_type in enumerate(attention_types)]
-                ) for layer_id, item in enumerate(block_param)
+                ) for layer_id, item in enumerate(stage_param)
             ])
-            main_blocks.append(block)
-        self.main_blocks = nn.ModuleList(main_blocks)
+            self.stages.add_module(f'stage_{stage_id}', stage)
+            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.stage_{stage_id}')]
 
-        '''
-        # layer norms for pyramid feature extraction
-        #
-        # TODO implement pyramid feature extraction
-        #
-        # davit should be a good transformer candidate, since the only official implementation
-        # is for segmentation and detection
-
-        for i_layer in range(self.num_stages):
-            layer = norm_layer(self.embed_dims[i_layer])
-            layer_name = f'norm{i_layer}'
-            self.add_module(layer_name, layer)
-        '''
 
         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         self.apply(self._init_weights)
 
     def _init_weights(self, m):
         if isinstance(m, nn.Linear):
             trunc_normal_(m.weight, std=.02)
@@ -467,9 +455,7 @@ class DaViT(nn.Module):
         elif isinstance(m, nn.LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
 
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
@@ -485,55 +471,67 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
-    def forward_features_full(self, x):
-        x, size = self.patch_embeds[0](x, (x.size(2), x.size(3)))
+    def forward_network(self, x):
+        size: Tuple[int, int] = (x.size(2), x.size(3))
         features = [x]
         sizes = [size]
-        branches = [0]
-
-        for block_index, block_param in enumerate(self.architecture):
-            branch_ids = sorted(set(block_param))
-            for branch_id in branch_ids:
-                if branch_id not in branches:
-                    x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
-                    features.append(x)
-                    sizes.append(size)
-                    branches.append(branch_id)
-            for layer_index, branch_id in enumerate(block_param):
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
-                else:
-                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
-
-        '''
-        # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
-        outs = []
-        for i in range(self.num_stages):
-            norm_layer = getattr(self, f'norm{i}')
-            x_out = norm_layer(features[i])
-            H, W = sizes[i]
-            out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
-            outs.append(out)
-        '''
-        # non-normalized pyramid features + corresponding sizes
-        return tuple(features), tuple(sizes)
+
+        for patch_layer, stage in zip(self.patch_embeds, self.stages):
+            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
+            for _, block in enumerate(stage):
+                for _, layer in enumerate(block):
+                    if self.grad_checkpointing and not torch.jit.is_scripting():
+                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
+                    else:
+                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])
+
+            # don't append outputs of last stage, since they are already there
+            if(len(features) < self.num_stages):
+                features.append(features[-1])
+                sizes.append(sizes[-1])
+
+        # non-normalized pyramid features + corresponding sizes
+        return features, sizes
+
+    def forward_pyramid_features(self, x) -> List[Tensor]:
+        x, sizes = self.forward_network(x)
+        outs = []
+        for i, out in enumerate(x):
+            H, W = sizes[i]
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+
+        return outs
 
     def forward_features(self, x):
-        x, sizes = self.forward_features_full(x)
+        x, sizes = self.forward_network(x)
         # take final feature and norm
         x = self.norms(x[-1])
         H, W = sizes[-1]
         x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
-        #print(x.shape)
         return x
 
     def forward_head(self, x, pre_logits: bool = False):
         return self.head(x, pre_logits=pre_logits)
 
-    def forward(self, x):
+    def forward_classifier(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x
 
+    def forward(self, x):
+        return self.forward_classifier(x)
+
+
+class DaViTFeatures(DaViT):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
+
+    def forward(self, x) -> List[Tensor]:
+        return self.forward_pyramid_features(x)
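The split into forward_network / forward_pyramid_features / forward_features lets the same backbone emit either the final normalized map or the per-stage pyramid. A rough sketch (shapes assume the default embed_dims of (96, 192, 384, 768) and a 224x224 input):

    import torch
    from timm.models.davit import DaViT

    m = DaViT()
    feats = m.forward_pyramid_features(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # expected: [(1, 96, 56, 56), (1, 192, 28, 28), (1, 384, 14, 14), (1, 768, 7, 7)]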
 
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
@@ -542,11 +540,10 @@ def checkpoint_filter_fn(state_dict, model):
     if 'state_dict' in state_dict:
         state_dict = state_dict['state_dict']
     out_dict = {}
-    import re
     for k, v in state_dict.items():
+        k = k.replace('main_blocks.', 'stages.stage_')
         k = k.replace('head.', 'head.fc.')
         out_dict[k] = v
     return out_dict
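A small illustration of what the remap does to original DaViT checkpoint keys (hypothetical keys; it assumes the body of checkpoint_filter_fn is only the replacements shown above and that the model argument is unused):

    import torch

    msft_keys = ['main_blocks.0.0.0.attn.qkv.weight', 'head.weight']
    sd = {k: torch.zeros(1) for k in msft_keys}
    remapped = checkpoint_filter_fn(sd, model=None)
    print(list(remapped))
    # ['stages.stage_0.0.0.attn.qkv.weight', 'head.fc.weight']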
@@ -554,8 +551,25 @@ def checkpoint_filter_fn(state_dict, model):
 
 def _create_davit(variant, pretrained=False, **kwargs):
-    model = build_model_with_cfg(DaViT, variant, pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn, **kwargs)
+    model_cls = DaViT
+    features_only = False
+    kwargs_filter = None
+    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+    if kwargs.pop('features_only', False):
+        model_cls = DaViTFeatures
+        kwargs_filter = ('num_classes', 'global_pool')
+        features_only = True
+
+    model = build_model_with_cfg(
+        model_cls,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs)
+
+    if features_only:
+        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
+        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
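A sketch of how the factory is typically exercised (it assumes davit_tiny is registered via register_model elsewhere in this file, which the diff does not show):

    import torch
    import timm

    # standard classifier: DaViT, logits out
    model = timm.create_model('davit_tiny', pretrained=False)
    out = model(torch.randn(1, 3, 224, 224))       # (1, 1000)

    # feature extraction: DaViTFeatures via features_only, list of NCHW stage maps
    backbone = timm.create_model('davit_tiny', pretrained=False, features_only=True)
    feats = backbone(torch.randn(1, 3, 224, 224))
    print(backbone.feature_info.channels())        # e.g. [96, 192, 384, 768]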
@@ -573,13 +587,13 @@ def _cfg(url='', **kwargs): # not sure how this should be set up
 
 default_cfgs = generate_default_cfgs({
+    # official microsoft weights from https://github.com/dingmyu/davit
     'davit_tiny.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_tiny_ed28dd55.pth.tar"),
     'davit_small.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
     'davit_base.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
 })
