Update davit.py

clean up
pull/1630/head
Fredo Guan 3 years ago
parent 03c779f1cf
commit d00ac9033a

@@ -36,23 +36,13 @@ __all__ = ['DaViT']
 # modified nn.Sequential that includes a size tuple in the forward function
 # FIXME doesn't work with torchscript/JIT
+# Module 'SequentialWithSize' has no attribute '_modules'
 class SequentialWithSize(nn.Sequential):
     def forward(self, x : Tensor, size: Tuple[int, int]):
         for module in self._modules.values():
             x, size = module(x, size)
         return x, size
-'''
-class SequentialWithSize(nn.Sequential):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        for module in self._modules.values():
-            x, size = module(x, size)
-        return x, size
-'''

 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
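The FIXME above stems from TorchScript being unable to compile iteration over self._modules.values(). A minimal sketch of one possible workaround, not the change made in this commit: iterate the Sequential itself, which TorchScript unrolls at compile time. The class name is illustrative, and the sketch assumes every child module accepts and returns the (x, size) pair.

from typing import Tuple
from torch import nn, Tensor

class ScriptableSequentialWithSize(nn.Sequential):
    def forward(self, x: Tensor, size: Tuple[int, int]):
        # TorchScript can unroll iteration over the container itself,
        # unlike iteration over self._modules.values().
        for module in self:
            x, size = module(x, size)
        return x, size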
@@ -555,36 +545,6 @@ class DaViT(nn.Module):
         global_pool = self.head.global_pool.pool_type
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    '''
-    def forward_network(self, x : Tensor):
-        size: Tuple[int, int] = (x.size(2), x.size(3))
-        features = [x]
-        sizes = [size]
-        for stage in self.stages:
-            features[-1], sizes[-1] = stage(features[-1], sizes[-1])
-            # don't append outputs of last stage, since they are already there
-            if(len(features) < self.num_stages):
-                features.append(features[-1])
-                sizes.append(sizes[-1])
-        # non-normalized pyramid features + corresponding sizes
-        return features, sizes
-
-    def forward_pyramid_features(self, x) -> List[Tensor]:
-        x, sizes = self.forward_network(x)
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-        return outs
-    '''

     def forward_features(self, x):
         #x, sizes = self.forward_network(x)
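For reference, the heart of the deleted forward_pyramid_features is the reshape from flattened NHWC tokens back into an NCHW feature map. A standalone sketch of that transform, with illustrative shapes:

import torch

B, H, W, C = 2, 14, 14, 384            # illustrative batch, grid, and channel sizes
tokens = torch.randn(B, H * W, C)      # flattened (B, N, C) tokens from one stage
fmap = tokens.view(B, H, W, C).permute(0, 3, 1, 2).contiguous()
print(fmap.shape)                      # torch.Size([2, 384, 14, 14]), i.e. (B, C, H, W)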
@@ -609,18 +569,6 @@ class DaViT(nn.Module):
         return self.forward_classifier(x)

-'''
-class DaViTFeatures(DaViT):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
-
-    def forward(self, x) -> List[Tensor]:
-        return self.forward_pyramid_features(x)
-'''

 def checkpoint_filter_fn(state_dict, model):
     """ Remap MSFT checkpoints -> timm """
@@ -641,30 +589,6 @@ def checkpoint_filter_fn(state_dict, model):
     return out_dict

-'''
-def _create_davit(variant, pretrained=False, **kwargs):
-    model_cls = DaViT
-    features_only = False
-    kwargs_filter = None
-    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
-    out_indices = kwargs.pop('out_indices', default_out_indices)
-    if kwargs.pop('features_only', False):
-        model_cls = DaViTFeatures
-        kwargs_filter = ('num_classes', 'global_pool')
-        features_only = True
-    model = build_model_with_cfg(
-        model_cls,
-        variant,
-        pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
-        **kwargs)
-    if features_only:
-        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
-        model.default_cfg = model.pretrained_cfg  # backwards compat
-    return model
-'''

 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
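With feature extraction delegated to build_model_with_cfg via feature_cfg (flatten_sequential plus out_indices), the bespoke DaViTFeatures wrapper deleted above is no longer needed; pyramid features come through timm's standard entry point. A usage sketch, assuming a timm version that includes this model (output shapes depend on input resolution):

import timm
import torch

model = timm.create_model('davit_base', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # one (B, C, H, W) tensor per selected stage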
@@ -699,6 +623,9 @@ default_cfgs = generate_default_cfgs({
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
     'davit_base.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
+    'davit_large': _cfg(),
+    'davit_huge': _cfg(),
+    'davit_giant': _cfg(),
 })
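An empty _cfg() registers an architecture with no pretrained URL, so davit_large/huge/giant are constructible but ship no weights. One way to check which variants do have weights, using timm's model listing:

import timm

print(timm.list_models('davit*'))                   # every registered davit variant
print(timm.list_models('davit*', pretrained=True))  # only variants with weights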
@@ -721,7 +648,7 @@ def davit_base(pretrained=False, **kwargs):
         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)

-''' models without weights
 # TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
@@ -740,4 +667,3 @@ def davit_giant(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
         num_heads=(12, 24, 48, 96), **kwargs)
     return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
-'''