Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent fb074f89ba
commit 5d8ea5a21d

@@ -23,8 +23,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from .helpers import build_model_with_cfg
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
+from .features import FeatureInfo
 from collections import OrderedDict
 import torch.utils.checkpoint as checkpoint
 from .pretrained import generate_default_cfgs
@@ -381,8 +382,7 @@ class DaViT(nn.Module):
             drop_rate=0.,
             attn_drop_rate=0.,
             num_classes=1000,
-            global_pool='avg',
-            **kwargs
+            global_pool='avg'
     ):
         super().__init__()
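Note: dropping **kwargs from the constructor means unrecognized keyword arguments now fail loudly instead of being silently swallowed. A quick illustration (the misspelled argument is made up for the example):

    # With **kwargs removed, a typo such as depth= (instead of depths=)
    # raises at construction time rather than being silently ignored:
    try:
        model = DaViT(depth=(1, 1, 9, 1))
    except TypeError as e:
        print(e)  # __init__() got an unexpected keyword argument 'depth'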
@@ -439,12 +439,7 @@ class DaViT(nn.Module):
             ])
             self.stages.add_module(f'stage_{stage_id}', stage)
-            self.feature_info += [dict(num_chs=self.embed_dims[item], reduction=2, module=f'stages.stage_{stage_id}')]
-            self.feature_info += [dict(
-                num_chs=self.embed_dims[stage_id],
-                reduction = 2,
-                module=f'stages.stage_{stage_id}.{depths[stage_id] - 1}.{len(attention_types) - 1}.mlp')]
 
         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@@ -498,15 +493,7 @@ class DaViT(nn.Module):
         # non-normalized pyramid features + corresponding sizes
         return features, sizes
 
-    def forward_pyramid_features(self, x):
-        x, sizes = self.forward_network(x)
-        outs = []
-        for i, out in enumerate(x):
-            H, W = sizes[i]
-            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
-        return outs
-
     def forward_features(self, x):
         x, sizes = self.forward_network(x)
         # take final feature and norm
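For reference, forward_pyramid_features, removed here and re-added on the features-only subclass in the next hunk, turns each stage's flattened token sequence back into an NCHW feature map. A standalone sketch of that reshape with illustrative sizes (B=2, H=W=56, C=96):

    import torch

    tokens = torch.randn(2, 56 * 56, 96)            # (B, H*W, C) stage output
    fmap = tokens.view(-1, 56, 56, 96)              # back to (B, H, W, C)
    fmap = fmap.permute(0, 3, 1, 2).contiguous()    # (B, C, H, W) for conv consumers
    print(fmap.shape)                               # torch.Size([2, 96, 56, 56])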
@@ -523,6 +510,22 @@ class DaViT(nn.Module):
         x = self.forward_head(x)
         return x
 
+
+class DaViTFeatures(DaViT):
+    def __init__(self, *args, **kwargs):
+        # pop out_indices before calling DaViT.__init__, which no longer accepts **kwargs
+        default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
+        out_indices = kwargs.pop('out_indices', default_out_indices)
+        super().__init__(*args, **kwargs)
+        self.feature_info = FeatureInfo(self.feature_info, out_indices)
+
+    def forward_pyramid_features(self, x) -> List[Tensor]:
+        x, sizes = self.forward_network(x)
+        outs = []
+        for i, out in enumerate(x):
+            H, W = sizes[i]
+            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
+        return outs
+
+    def forward(self, x) -> List[Tensor]:
+        return self.forward_pyramid_features(x)
+
 
 def checkpoint_filter_fn(state_dict, model):
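The new DaViTFeatures subclass follows the same features-only pattern as timm's HRNet wrapper: per-stage metadata is wrapped in a FeatureInfo object that downstream code can query for channel counts and strides. A minimal sketch of how that metadata is consumed; the num_chs values roughly match davit_base and the reduction values assume the stride-4 stem with stride-2 downsampling per stage (all illustrative):

    from timm.models.features import FeatureInfo

    feature_info = [
        dict(num_chs=96,  reduction=4,  module='stages.stage_0'),
        dict(num_chs=192, reduction=8,  module='stages.stage_1'),
        dict(num_chs=384, reduction=16, module='stages.stage_2'),
        dict(num_chs=768, reduction=32, module='stages.stage_3'),
    ]
    info = FeatureInfo(feature_info, out_indices=(0, 1, 2, 3))
    print(info.channels())   # [96, 192, 384, 768]
    print(info.reduction())  # [4, 8, 16, 32]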
@@ -543,15 +546,23 @@ def checkpoint_filter_fn(state_dict, model):
 
 def _create_davit(variant, pretrained=False, **kwargs):
-    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
-    out_indices = kwargs.pop('out_indices', default_out_indices)
+    model_cls = DaViT
+    features_only = False
+    kwargs_filter = None
+    if kwargs.pop('features_only', False):
+        model_cls = DaViTFeatures
+        kwargs_filter = ('num_classes', 'global_pool')
+        features_only = True
     model = build_model_with_cfg(
-        DaViT,
+        model_cls,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
-        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        pretrained_strict=not features_only,
+        kwargs_filter=kwargs_filter,
         **kwargs)
+    if features_only:
+        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
+        model.default_cfg = model.pretrained_cfg  # backwards compat
     return model
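With the wiring above, features-only usage would follow the standard timm pattern; a sketch, assuming this PR's davit variants are registered (model name illustrative):

    import timm
    import torch

    model = timm.create_model('davit_base', pretrained=False, features_only=True)
    x = torch.randn(1, 3, 224, 224)
    for fmap in model(x):      # one NCHW feature map per stage
        print(fmap.shape)      # e.g. (1, 96, 56, 56), (1, 192, 28, 28), ...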
