Fix out_indices handling breakage, should have left as per vgg approach.

pull/1083/head
Ross Wightman 3 years ago
parent a9f91483a6
commit ccfeb06936

@ -408,14 +408,13 @@ class CspNet(nn.Module):
def _create_cspnet(variant, pretrained=False, **kwargs): def _create_cspnet(variant, pretrained=False, **kwargs):
cfg_variant = variant.split('_')[0] cfg_variant = variant.split('_')[0]
if 'darknet' in variant: # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
# NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5] out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5) if 'darknet' in variant else (0, 1, 2, 3, 4))
kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5))
return build_model_with_cfg( return build_model_with_cfg(
CspNet, variant, pretrained, CspNet, variant, pretrained,
default_cfg=default_cfgs[variant], default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[cfg_variant], model_cfg=model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True), feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs) **kwargs)

@ -180,12 +180,12 @@ def _filter_fn(state_dict):
def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
cfg = variant.split('_')[0] cfg = variant.split('_')[0]
# NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5] # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
kwargs.setdefault('out_indices', (0, 1, 2, 3, 4, 5)) out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
model = build_model_with_cfg( model = build_model_with_cfg(
VGG, variant, pretrained, VGG, variant, pretrained,
default_cfg=default_cfgs[variant], default_cfg=default_cfgs[variant],
model_cfg=cfgs[cfg], model_cfg=cfgs[cfg],
feature_cfg=dict(flatten_sequential=True), feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
pretrained_filter_fn=_filter_fn, pretrained_filter_fn=_filter_fn,
**kwargs) **kwargs)
return model return model

Loading…
Cancel
Save