Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 76b0d9066e
commit 729fd0d53b

@@ -788,7 +788,7 @@ class DaViT(nn.Module):
        if global_pool is None:
            global_pool = self.head.global_pool.pool_type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
    '''
    def forward_network(self, x : Tensor):
        size: Tuple[int, int] = (x.size(2), x.size(3))
        features = [x]
@@ -808,16 +808,16 @@ class DaViT(nn.Module):
    def forward_pyramid_features(self, x) -> List[Tensor]:
        x = self.forward_network(x)
        '''
        outs = []
        for i, out in enumerate(x):
            H, W = sizes[i]
            outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
        '''
        return x
    '''
    def forward_features(self, x):
        x = self.forward_network(x)
        x = self.stages(x)
        # take final feature and norm
        x = self.norms(x[-1].permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        #H, W = sizes[-1]
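
The forward_features hunk above leans on the permute-around-LayerNorm idiom: nn.LayerNorm normalizes the last dimension, so an NCHW feature map is moved to NHWC for the norm and back afterwards. A minimal sketch of that pattern, with an assumed stage width of 96 and example shapes that are not taken from this diff:

import torch
import torch.nn as nn

# Sketch of the permute-around-LayerNorm idiom from forward_features above.
# The width 96 and the tensor shapes are illustrative assumptions only.
norm = nn.LayerNorm(96)
x = torch.randn(2, 96, 56, 56)       # NCHW feature map
x = norm(x.permute(0, 2, 3, 1))      # NHWC: channels become the normalized dim
x = x.permute(0, 3, 1, 2)            # back to NCHW
print(x.shape)                       # torch.Size([2, 96, 56, 56])
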
@@ -834,7 +834,7 @@ class DaViT(nn.Module):
    def forward(self, x):
        return self.forward_classifier(x)
    '''

class DaViTFeatures(DaViT):
    def __init__(self, *args, **kwargs):
@@ -843,7 +843,7 @@ class DaViTFeatures(DaViT):
    def forward(self, x) -> List[Tensor]:
        return self.forward_pyramid_features(x)
'''

def checkpoint_filter_fn(state_dict, model):
    """ Remap MSFT checkpoints -> timm """
    if 'head.norm.weight' in state_dict:
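
checkpoint_filter_fn follows the usual timm pattern: walk the incoming state dict and rename keys into the layout the timm module expects, with the 'head.norm.weight' check above detecting whether a checkpoint still needs remapping. A hedged sketch of that general shape; the prefix rename below is hypothetical and only illustrates the pattern, not the actual MSFT -> timm mapping:

def example_filter_fn(state_dict, model):
    # Hypothetical 'main_blocks.' -> 'stages.' rename, for illustration only;
    # the real mapping lives in the full checkpoint_filter_fn, not this hunk.
    return {k.replace('main_blocks.', 'stages.'): v for k, v in state_dict.items()}
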
@@ -866,16 +866,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
    model_cls = DaViT
    features_only = False
    kwargs_filter = None
    default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
    out_indices = kwargs.pop('out_indices', default_out_indices)
    if kwargs.pop('features_only', False):
        model_cls = DaViTFeatures
        kwargs_filter = ('num_classes', 'global_pool')
        features_only = True
    model = build_model_with_cfg(
        model_cls,
        variant,
@@ -883,9 +878,7 @@ def _create_davit(variant, pretrained=False, **kwargs):
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)
    if features_only:
        model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg)
        model.default_cfg = model.pretrained_cfg  # backwards compat
    return model
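
With the manual features_only bookkeeping dropped from _create_davit, feature extraction is expected to flow through build_model_with_cfg and its feature_cfg. A minimal usage sketch, assuming the davit_* variants from this PR are registered with timm:

import torch
import timm

# features_only=True should route _create_davit to DaViTFeatures and yield a
# list of per-stage NCHW feature maps instead of classification logits.
model = timm.create_model('davit_tiny', pretrained=False, features_only=True)
feats = model(torch.randn(1, 3, 224, 224))
for f in feats:
    print(tuple(f.shape))
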
