Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent b83430350f
commit 1b76714397

@ -514,7 +514,7 @@ class DaViT(nn.Module):
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.stages = nn.ModuleList(stages) self.stages = SequentialWithSize(*stages)
self.norms = norm_layer(self.num_features) self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -544,7 +544,7 @@ class DaViT(nn.Module):
global_pool = self.head.global_pool.pool_type global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
'''
def forward_network(self, x : Tensor): def forward_network(self, x : Tensor):
size: Tuple[int, int] = (x.size(2), x.size(3)) size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x] features = [x]
@ -570,12 +570,19 @@ class DaViT(nn.Module):
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
return outs return outs
'''
def forward_features(self, x): def forward_features(self, x):
x, sizes = self.forward_network(x) #x, sizes = self.forward_network(x)
size: Tuple[int, int] = (x.size(2), x.size(3))
x, size = stages(x, size)
# take final feature and norm # take final feature and norm
x = self.norms(x[-1]) x = self.norms(x)
H, W = sizes[-1] H, W = sizes
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
return x return x
@ -590,7 +597,8 @@ class DaViT(nn.Module):
def forward(self, x): def forward(self, x):
return self.forward_classifier(x) return self.forward_classifier(x)
'''
class DaViTFeatures(DaViT): class DaViTFeatures(DaViT):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
@ -600,6 +608,8 @@ class DaViTFeatures(DaViT):
def forward(self, x) -> List[Tensor]: def forward(self, x) -> List[Tensor]:
return self.forward_pyramid_features(x) return self.forward_pyramid_features(x)
'''
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """ """ Remap MSFT checkpoints -> timm """
@ -620,7 +630,7 @@ def checkpoint_filter_fn(state_dict, model):
return out_dict return out_dict
'''
def _create_davit(variant, pretrained=False, **kwargs): def _create_davit(variant, pretrained=False, **kwargs):
model_cls = DaViT model_cls = DaViT
features_only = False features_only = False
@ -643,6 +653,19 @@ def _create_davit(variant, pretrained=False, **kwargs):
model.default_cfg = model.pretrained_cfg # backwards compat model.default_cfg = model.pretrained_cfg # backwards compat
return model return model
'''
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs): # not sure how this should be set up def _cfg(url='', **kwargs): # not sure how this should be set up

Loading…
Cancel
Save