From ccfeb06936549f19c453b7f1f27e8e632cfbe1c2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Jan 2022 19:30:51 -0800 Subject: [PATCH] Fix out_indices handling breakage, should have left as per vgg approach. --- timm/models/cspnet.py | 7 +++---- timm/models/vgg.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 2408a996..8465e97d 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -408,14 +408,13 @@ class CspNet(nn.Module): def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] - if 'darknet' in variant: - # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] - kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5)) + # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] + out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4)) return build_model_with_cfg( CspNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[cfg_variant], - feature_cfg=dict(flatten_sequential=True), + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) diff --git a/timm/models/vgg.py b/timm/models/vgg.py index 42835253..b2fe07c8 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -180,12 +180,12 @@ def _filter_fn(state_dict): def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: cfg = variant.split('_')[0] # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] - kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5)) + out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5)) model = build_model_with_cfg( VGG, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfgs[cfg], - feature_cfg=dict(flatten_sequential=True), + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), pretrained_filter_fn=_filter_fn, **kwargs) return model