From a9f91483a69b2c0b69b89457c8793958bd722fbb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Jan 2022 15:08:32 -0800 Subject: [PATCH] Fix #1078, DarkNet has 6 feature maps. Make vgg and darknet out_indices handling/comments equivalent --- timm/models/cspnet.py | 6 +++++- timm/models/vgg.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 39d16200..2408a996 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -408,10 +408,14 @@ class CspNet(nn.Module): def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] + if 'darknet' in variant: + # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] + kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5)) return build_model_with_cfg( CspNet, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], + model_cfg=model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), **kwargs) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 11f6d0ea..42835253 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -179,13 +179,13 @@ def _filter_fn(state_dict): def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: cfg = variant.split('_')[0] - # NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models - out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5)) + # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] + kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5)) model = build_model_with_cfg( VGG, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfgs[cfg], - feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_fn, **kwargs) return model