Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent b83430350f
commit 1b76714397

@ -514,7 +514,7 @@ class DaViT(nn.Module):
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.stages = nn.ModuleList(stages)
self.stages = SequentialWithSize(*stages)
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -544,7 +544,7 @@ class DaViT(nn.Module):
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
'''
def forward_network(self, x : Tensor):
size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x]
@ -570,12 +570,19 @@ class DaViT(nn.Module):
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs
'''
def forward_features(self, x):
x, sizes = self.forward_network(x)
#x, sizes = self.forward_network(x)
size: Tuple[int, int] = (x.size(2), x.size(3))
x, size = stages(x, size)
# take final feature and norm
x = self.norms(x[-1])
H, W = sizes[-1]
x = self.norms(x)
H, W = sizes
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x
@ -590,7 +597,8 @@ class DaViT(nn.Module):
def forward(self, x):
return self.forward_classifier(x)
'''
class DaViTFeatures(DaViT):
def __init__(self, *args, **kwargs):
@ -600,6 +608,8 @@ class DaViTFeatures(DaViT):
def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(x)
'''
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
@ -620,7 +630,7 @@ def checkpoint_filter_fn(state_dict, model):
return out_dict
'''
def _create_davit(variant, pretrained=False, **kwargs):
model_cls = DaViT
features_only = False
@ -643,6 +653,19 @@ def _create_davit(variant, pretrained=False, **kwargs):
model.default_cfg = model.pretrained_cfg # backwards compat
return model
'''
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs): # not sure how this should be set up

Loading…
Cancel
Save