|
|
|
@ -514,7 +514,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = nn.ModuleList(stages)
|
|
|
|
|
self.stages = SequentialWithSize(*stages)
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
@ -544,7 +544,7 @@ class DaViT(nn.Module):
|
|
|
|
|
global_pool = self.head.global_pool.pool_type
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def forward_network(self, x : Tensor):
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
features = [x]
|
|
|
|
@ -570,12 +570,19 @@ class DaViT(nn.Module):
|
|
|
|
|
outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous())
|
|
|
|
|
|
|
|
|
|
return outs
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
|
#x, sizes = self.forward_network(x)
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
x, size = stages(x, size)
|
|
|
|
|
|
|
|
|
|
# take final feature and norm
|
|
|
|
|
x = self.norms(x[-1])
|
|
|
|
|
H, W = sizes[-1]
|
|
|
|
|
x = self.norms(x)
|
|
|
|
|
H, W = sizes
|
|
|
|
|
x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
@ -591,6 +598,7 @@ class DaViT(nn.Module):
|
|
|
|
|
return self.forward_classifier(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
@ -600,6 +608,8 @@ class DaViTFeatures(DaViT):
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
return self.forward_pyramid_features(x)
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
""" Remap MSFT checkpoints -> timm """
|
|
|
|
@ -620,7 +630,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model_cls = DaViT
|
|
|
|
|
features_only = False
|
|
|
|
@ -643,6 +653,19 @@ def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model.default_cfg = model.pretrained_cfg # backwards compat
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
DaViT,
|
|
|
|
|
variant,
|
|
|
|
|
pretrained,
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _cfg(url='', **kwargs): # not sure how this should be set up
|
|
|
|
|