@@ -349,12 +349,16 @@ class SpatialBlock(nn.Module):

class DaViT(nn.Module):
    r""" DaViT
        A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645

    Args:
        in_chans (int): Number of input image channels. Default: 3
        num_classes (int): Number of classes for classification head. Default: 1000
        depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
        patch_size (int | tuple(int)): Patch size. Default: 4
        embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
        num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
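# Usage sketch, assuming the DaViT class defined in this file and the documented
# defaults above (an illustration, not a definitive API):
import torch

model = DaViT(depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24))
logits = model(torch.randn(1, 3, 224, 224))  # classification logits, shape (1, 1000)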
@@ -380,11 +384,10 @@ class DaViT(nn.Module):
            cpe_act=False,
            drop_rate=0.,
            attn_drop_rate=0.,
            num_classes=1000,
            global_pool='avg',
            **kwargs):
        super().__init__()

        architecture = [[index] * item for index, item in enumerate(depths)]
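# Worked example of the `architecture` comprehension above, which repeats each
# stage index once per block in that stage:
depths = (1, 1, 3, 1)
architecture = [[index] * item for index, item in enumerate(depths)]
assert architecture == [[0], [1], [2, 2, 2], [3]]  # stage 2 contributes three blocks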
@@ -399,7 +402,6 @@ class DaViT(nn.Module):
        self.num_features = embed_dims[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.patch_embeds = nn.ModuleList([
@@ -409,12 +411,11 @@ class DaViT(nn.Module):
                overlapped=overlapped_patch)
            for i in range(self.num_stages)])

        self.main_blocks = nn.ModuleList()
        for stage_id, stage_param in enumerate(self.architecture):
            layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))

            stage = nn.ModuleList([
                nn.ModuleList([
                    ChannelBlock(
                        dim=self.embed_dims[item],
@@ -438,73 +439,17 @@ class DaViT(nn.Module):
                        window_size=window_size,
                    ) if attention_type == 'spatial' else None
                    for attention_id, attention_type in enumerate(attention_types)]
                ) for layer_id, item in enumerate(stage_param)
            ])

            self.main_blocks.add_module(f'stage_{stage_id}', stage)
            self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stage_{stage_id}')]

        self.norms = norm_layer(self.num_features)
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
        self.apply(self._init_weights)

    def _init_weights(self, m):
@@ -516,8 +461,6 @@ class DaViT(nn.Module):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
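# Sketch: the flag set by set_grad_checkpointing is read in forward_network
# below, which (assuming typical timm behaviour) wraps each dual-attention layer
# in torch.utils.checkpoint so activations are recomputed during backward,
# trading compute for memory:
model = DaViT()
model.set_grad_checkpointing(True)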
@@ -534,15 +477,11 @@ class DaViT(nn.Module):

    def forward_network(self, x):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        features = [x]
        sizes = [size]

        for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
            features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
            for block in stage:
                for layer in block:
@@ -551,62 +490,14 @@ class DaViT(nn.Module):
                    if self.grad_checkpointing and not torch.jit.is_scripting():
                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
                    else:
                        features[-1], sizes[-1] = layer(features[-1], sizes[-1])

            # don't append outputs of last stage, since they are already there
            if len(features) < self.num_stages:
                features.append(features[-1])
                sizes.append(sizes[-1])

        # non-normalized pyramid features + corresponding sizes
        return features, sizes

    def forward_pyramid_features(self, x):
        x, sizes = self.forward_network(x)
@@ -620,22 +511,19 @@ class DaViT(nn.Module):

    def forward_features(self, x):
        x, sizes = self.forward_network(x)
        # take final feature and norm
        x = self.norms(x[-1])
        H, W = sizes[-1]
        x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
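# Shape sketch, assuming the usual DaViT schedule of a 4x patch stem followed by
# 2x downsampling per later stage: forward_network keeps tokens flattened as
# (B, H*W, C) with sizes tracked separately, while forward_features reshapes only
# the final normed map to NCHW.
import torch

model = DaViT()
features, sizes = model.forward_network(torch.randn(1, 3, 224, 224))
# sizes == [(56, 56), (28, 28), (14, 14), (7, 7)] under the assumptions above
x = model.forward_features(torch.randn(1, 3, 224, 224))  # e.g. (1, 768, 7, 7) when embed_dims[-1] == 768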

def checkpoint_filter_fn(state_dict, model):
@@ -647,9 +535,8 @@ def checkpoint_filter_fn(state_dict, model):
        state_dict = state_dict['state_dict']

    out_dict = {}
    for k, v in state_dict.items():
        k = k.replace('main_blocks.', 'main_blocks.stage_')
        k = k.replace('head.', 'head.fc.')
        out_dict[k] = v
    return out_dict
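# Remap sketch with hypothetical keys: original DaViT checkpoints index stages
# numerically under `main_blocks.` and use a bare `head.` prefix, and the two
# replacements above rewrite them to this module's naming:
sd = {'main_blocks.0.0.0.attn.qkv.weight': 0, 'head.weight': 0}
remapped = checkpoint_filter_fn(sd, model=None)
assert 'main_blocks.stage_0.0.0.attn.qkv.weight' in remapped
assert 'head.fc.weight' in remapped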
@@ -657,10 +544,15 @@ def checkpoint_filter_fn(state_dict, model):

def _create_davit(variant, pretrained=False, **kwargs):
    # tuple(...) rather than a bare generator, so the default survives repeated use
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)
    model = build_model_with_cfg(
        DaViT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)
    return model
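# Factory sketch, assuming a variant registered elsewhere in this file (e.g. a
# hypothetical davit_base entrypoint): out_indices defaults to one entry per
# stage, so timm's features_only wrapper can return the feature pyramid directly.
import timm
import torch

backbone = timm.create_model('davit_base', features_only=True, out_indices=(1, 2, 3))
feature_maps = backbone(torch.randn(1, 3, 224, 224))  # three NCHW maps at increasing stride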