From d00ac9033a8fc56cdcbce525c43f161fa927b983 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 05:21:00 -0800
Subject: [PATCH] Update davit.py

clean up
---
 timm/models/davit.py | 84 +++-----------------------------------------
 1 file changed, 5 insertions(+), 79 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 8d288bd0..74d51f73 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -36,23 +36,13 @@ __all__ = ['DaViT']

 # modified nn.Sequential that includes a size tuple in the forward function
 # FIXME doesn't work with torchscript/JIT
-
+# Module 'SequentialWithSize' has no attribute '_modules'
 class SequentialWithSize(nn.Sequential):
     def forward(self, x : Tensor, size: Tuple[int, int]):
         for module in self._modules.values():
             x, size = module(x, size)
         return x, size

-'''
-class SequentialWithSize(nn.Sequential):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def forward(self, x : Tensor, size: Tuple[int, int]):
-        for module in self._modules.values():
-            x, size = module(x, size)
-        return x, size
-'''

 class ConvPosEnc(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
@@ -555,37 +545,7 @@ class DaViT(nn.Module):

             global_pool = self.head.global_pool.pool_type
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

-    '''
-    def forward_network(self, x : Tensor):
-        size: Tuple[int, int] = (x.size(2), x.size(3))
-        features = [x]
-        sizes = [size]
-
-        for stage in self.stages:
-            features[-1], sizes[-1] = stage(features[-1], sizes[-1])
-
-            # don't append outputs of last stage, since they are already there
-            if(len(features) < self.num_stages):
-                features.append(features[-1])
-                sizes.append(sizes[-1])
-
-        # non-normalized pyramid features + corresponding sizes
-        return features, sizes
-
-    def forward_pyramid_features(self, x) -> List[Tensor]:
-        x, sizes = self.forward_network(x)
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-
-        return outs
-    '''
-
-
-
-
     def forward_features(self, x):
         #x, sizes = self.forward_network(x)
         size: Tuple[int, int] = (x.size(2), x.size(3))
@@ -608,18 +568,6 @@ class DaViT(nn.Module):

     def forward(self, x):
         return self.forward_classifier(x)
-
-'''
-class DaViTFeatures(DaViT):
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3)))
-
-    def forward(self, x) -> List[Tensor]:
-        return self.forward_pyramid_features(x)
-
-'''


 def checkpoint_filter_fn(state_dict, model):
@@ -640,31 +588,7 @@ def checkpoint_filter_fn(state_dict, model):
             out_dict[k] = v
     return out_dict

-
-'''
-def _create_davit(variant, pretrained=False, **kwargs):
-    model_cls = DaViT
-    features_only = False
-    kwargs_filter = None
-    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
-    out_indices = kwargs.pop('out_indices', default_out_indices)
-    if kwargs.pop('features_only', False):
-        model_cls = DaViTFeatures
-        kwargs_filter = ('num_classes', 'global_pool')
-        features_only = True
-    model = build_model_with_cfg(
-        model_cls,
-        variant,
-        pretrained,
-        pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
-        **kwargs)
-    if features_only:
-        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
-        model.default_cfg = model.pretrained_cfg  # backwards compat
-    return model
-'''

 def _create_davit(variant, pretrained=False, **kwargs):
     default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
     out_indices = kwargs.pop('out_indices', default_out_indices)
@@ -699,6 +623,9 @@ default_cfgs = generate_default_cfgs({
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_small_d1ecf281.pth.tar"),
     'davit_base.msft_in1k': _cfg(
         url="https://github.com/fffffgggg54/pytorch-image-models/releases/download/checkpoint/davit_base_67d9ac26.pth.tar"),
+    'davit_large': _cfg(),
+    'davit_huge': _cfg(),
+    'davit_giant': _cfg(),
 })


@@ -721,7 +648,7 @@ def davit_base(pretrained=False, **kwargs):
         num_heads=(4, 8, 16, 32), **kwargs)
     return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)

-''' models without weights
+
 # TODO contact authors to get larger pretrained models
 @register_model
 def davit_large(pretrained=False, **kwargs):
@@ -740,4 +667,3 @@ def davit_giant(pretrained=False, **kwargs):
     model_kwargs = dict(depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072),
         num_heads=(12, 24, 48, 96), **kwargs)
     return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)
-'''
\ No newline at end of file
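
Note on the FIXME kept by this patch: torch.jit.script() rejects the SequentialWithSize
override because `_modules` is not an attribute the JIT can see on an nn.Sequential
subclass. A minimal sketch of one possible workaround, not part of this patch: hold the
stages in an nn.ModuleList and iterate that, which torchscript supports by unrolling the
loop at compile time. The names `SequentialWithSizeJit` and `_AddOne` are hypothetical,
for illustration only.

    from typing import Tuple

    import torch
    import torch.nn as nn
    from torch import Tensor


    class SequentialWithSizeJit(nn.Module):
        # Hypothetical torchscript-friendly container: iterating an
        # nn.ModuleList inside forward() is unrolled by the JIT, whereas
        # self._modules.values() on an nn.Sequential subclass is not scriptable.
        def __init__(self, *modules: nn.Module):
            super().__init__()
            self.blocks = nn.ModuleList(modules)

        def forward(self, x: Tensor, size: Tuple[int, int]) -> Tuple[Tensor, Tuple[int, int]]:
            for block in self.blocks:
                x, size = block(x, size)
            return x, size


    class _AddOne(nn.Module):
        # toy stage with the (x, size) -> (x, size) interface the container expects
        def forward(self, x: Tensor, size: Tuple[int, int]) -> Tuple[Tensor, Tuple[int, int]]:
            return x + 1, size


    seq = torch.jit.script(SequentialWithSizeJit(_AddOne(), _AddOne()))
    x, size = seq(torch.zeros(1, 3, 8, 8), (8, 8))

The trade-off is that nn.ModuleList drops nn.Sequential's single-call convenience, but it
keeps the (x, size) threading while remaining scriptable.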
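With DaViTFeatures and the featurized _create_davit deleted, feature extraction is left to
build_model_with_cfg's feature_cfg (flatten_sequential=True plus per-stage out_indices),
i.e. timm's generic features_only machinery. A hedged usage sketch, assuming this branch is
installed as `timm` and the generic feature hooks resolve for this model:

    import timm
    import torch

    # features_only now routes through build_model_with_cfg's feature_cfg
    # rather than a dedicated DaViTFeatures wrapper class.
    model = timm.create_model('davit_base', pretrained=False, features_only=True)
    outs = model(torch.randn(1, 3, 224, 224))
    for o in outs:
        print(o.shape)  # expected: one feature map per stage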