Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent e41e6ced6f
commit b355135a0a

@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
 from .features import FeatureInfo
 from .fx_features import register_notrace_function, register_notrace_module
 from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features
 from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
 from .pretrained import generate_default_cfgs
 from .registry import register_model
@@ -35,17 +35,9 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 __all__ = ['DaViT']


 class SequentialWithSize(nn.Sequential):
-    def __init__(self, *args, **kwargs):
-        super(SequentialWithSize, self).__init__(*args, **kwargs)
-
-    def forward(self, x: Tensor, size: Tuple[int, int]):
-        for module in self.__iter__():
+    def forward(self, x: Tensor, size: Tuple[int, int]):
+        for module in self._modules.values():
             x, size = module(x, size)
-            '''
-            output = module(x, size)
-            x: Tensor = output[0]
-            size: Tuple[int, int] = output[1]
-            '''
         return x, size
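
The rewritten container drops the pass-through __init__ and the commented-out unpacking, and iterates self._modules.values() rather than self.__iter__(), presumably to keep the two-argument forward TorchScript-friendly. A minimal, self-contained sketch of the (x, size) contract this container expects, with a hypothetical AddOne block standing in for a real DaViT layer:

import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple

class SequentialWithSize(nn.Sequential):
    # as in the diff above: every child maps (x, size) -> (x, size)
    def forward(self, x: Tensor, size: Tuple[int, int]):
        for module in self._modules.values():
            x, size = module(x, size)
        return x, size

class AddOne(nn.Module):
    # hypothetical toy block following the same contract
    def forward(self, x: Tensor, size: Tuple[int, int]):
        return x + 1, size

seq = SequentialWithSize(AddOne(), AddOne())
x, size = seq(torch.zeros(2, 4), (2, 2))  # x is all twos; size passes through unchanged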
@@ -419,17 +411,19 @@ class DaViTStage(nn.Module):
                 window_size=window_size,
             ))

-        stage_blocks.append(SequentialWithSize(*dual_attention_block))
+        stage_blocks.append(nn.ModuleList(dual_attention_block))

-        self.blocks = SequentialWithSize(*stage_blocks)
+        self.blocks = nn.ModuleList(stage_blocks)

     def forward(self, x: Tensor, size: Tuple[int, int]):
         x, size = self.patch_embed(x, size)
-        if self.grad_checkpointing and not torch.jit.is_scripting():
-            x, size = checkpoint_seq(self.blocks, x, size)
-        else:
-            x, size = self.blocks(x, size)
+        for block in self.blocks:
+            for layer in block:
+                if self.grad_checkpointing and not torch.jit.is_scripting():
+                    x, size = checkpoint.checkpoint(layer, x, size)
+                else:
+                    x, size = layer(x, size)
         return x, size


 class DaViT(nn.Module):
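
With blocks stored in nested ModuleLists, gradient checkpointing is now applied per layer through torch.utils.checkpoint instead of checkpoint_seq over a fused container, so (x, size) can be threaded through each call. A hedged sketch of that pattern with a hypothetical DummyLayer; use_reentrant=False is my assumption for recent PyTorch and is not part of the diff:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class DummyLayer(nn.Module):
    # hypothetical layer following the (x, size) -> (x, size) contract
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, size):
        return self.proj(x), size

blocks = nn.ModuleList([nn.ModuleList([DummyLayer(8), DummyLayer(8)])])
x = torch.randn(2, 16, 8, requires_grad=True)
size = (4, 4)
for block in blocks:
    for layer in block:
        # layer activations are recomputed in backward instead of stored
        x, size = checkpoint.checkpoint(layer, x, size, use_reentrant=False)
x.sum().backward()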
@@ -514,7 +508,7 @@ class DaViT(nn.Module):
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]

-        self.stages = SequentialWithSize(*stages)
+        self.stages = nn.ModuleList(stages)

         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
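
Holding the stages in an nn.ModuleList rather than a fused sequential container is what lets forward_network (below) collect the output of every stage for the feature pyramid; a single sequential forward would expose only the final output. A minimal sketch of that collection pattern:

import torch
import torch.nn as nn

stages = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8)
features = []
for stage in stages:
    x = stage(x)
    features.append(x)  # one intermediate output per stage, as forward_network does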
@@ -545,22 +539,61 @@ class DaViT(nn.Module):
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    def forward_features(self, x):
+    def forward_network(self, x):
         size: Tuple[int, int] = (x.size(2), x.size(3))
-        x, size = self.stages(x, size)
-        x = self.norms(x)
-        H, W = size
+        features = [x]
+        sizes = [size]
+
+        for stage in self.stages:
+            features[-1], sizes[-1] = stage(features[-1], sizes[-1])
+            # don't append the outputs of the last stage, since they are already there
+            if len(features) < self.num_stages:
+                features.append(features[-1])
+                sizes.append(sizes[-1])
+
+        # non-normalized pyramid features + corresponding sizes
+        return features, sizes
+
+    def forward_pyramid_features(self, x) -> List[Tensor]:
+        x, sizes = self.forward_network(x)
+        outs = []
+        for i, out in enumerate(x):
+            H, W = sizes[i]
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+        return outs
+
+    def forward_features(self, x):
+        x, sizes = self.forward_network(x)
+        # take the final feature and norm it
+        x = self.norms(x[-1])
+        H, W = sizes[-1]
         x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
         return x

     def forward_head(self, x, pre_logits: bool = False):
         return self.head(x, pre_logits=pre_logits)

-    def forward(self, x):
+    def forward_classifier(self, x):
         x = self.forward_features(x)
         x = self.forward_head(x)
         return x

+    def forward(self, x):
+        return self.forward_classifier(x)
+
+
+class DaViTFeatures(DaViT):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
+
+    def forward(self, x) -> List[Tensor]:
+        return self.forward_pyramid_features(x)
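
forward_pyramid_features turns each stage's flattened tokens of shape (B, H*W, C) back into an NCHW feature map with the same view + permute used in forward_features. A standalone sketch of that reshape, with illustrative sizes:

import torch

B, H, W, C = 2, 7, 7, 96                   # illustrative batch, spatial size, channels
tokens = torch.randn(B, H * W, C)          # token layout produced by the stages
fmap = tokens.view(-1, H, W, C).permute(0, 3, 1, 2).contiguous()
assert fmap.shape == (B, C, H, W)          # NCHW map, ready for the pyramid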
 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
@@ -580,15 +613,25 @@ def checkpoint_filter_fn(state_dict, model):

 def _create_davit(variant, pretrained=False, **kwargs):
+    model_cls = DaViT
+    features_only = False
+    kwargs_filter = None
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
+    if kwargs.pop('features_only', False):
+        model_cls = DaViTFeatures
+        kwargs_filter = ('num_classes', 'global_pool')
+        features_only = True

     model = build_model_with_cfg(
-        DaViT,
+        model_cls,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
         feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
         **kwargs)
+
+    if features_only:
+        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
+        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
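
With this, build_model_with_cfg instantiates DaViTFeatures whenever features_only=True is requested, and the pretrained config is remapped through pretrained_cfg_for_features. A hedged usage sketch, assuming a variant name such as 'davit_tiny' is registered by this module:

import torch
import timm

# 'davit_tiny' is an assumed registered variant name
model = timm.create_model('davit_tiny', features_only=True, out_indices=(0, 1, 2, 3))
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one NCHW tensor per selected stage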
