|
|
|
@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
|
|
|
|
|
return default_cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def overlay_external_default_cfg(kwargs, default_cfg):
|
|
|
|
|
""" Overlay 'default_cfg' in kwargs on top of default_cfg arg.
|
|
|
|
|
def overlay_external_default_cfg(default_cfg, kwargs):
|
|
|
|
|
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
|
|
|
|
|
"""
|
|
|
|
|
default_cfg = default_cfg or {}
|
|
|
|
|
external_default_cfg = kwargs.pop('external_default_cfg', None)
|
|
|
|
|
if external_default_cfg:
|
|
|
|
|
default_cfg = deepcopy(default_cfg)
|
|
|
|
|
default_cfg.pop('url', None) # url should come from external cfg
|
|
|
|
|
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
|
|
|
|
|
default_cfg.update(external_default_cfg)
|
|
|
|
|
return default_cfg
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set_default_kwargs(kwargs, names, default_cfg):
|
|
|
|
@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
|
|
|
|
|
input_size = default_cfg.get('input_size', None)
|
|
|
|
|
if input_size is not None:
|
|
|
|
|
assert len(input_size) == 3
|
|
|
|
|
kwargs.setdefault(n, input_size[:-2])
|
|
|
|
|
kwargs.setdefault(n, input_size[-2:])
|
|
|
|
|
elif n == 'in_chans':
|
|
|
|
|
input_size = default_cfg.get('input_size', None)
|
|
|
|
|
if input_size is not None:
|
|
|
|
@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
|
|
|
|
|
kwargs.pop(n, None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
|
|
|
|
|
""" Update the default_cfg and kwargs before passing to model
|
|
|
|
|
|
|
|
|
|
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
|
|
|
|
|
could/should be replaced by an improved configuration mechanism
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
default_cfg: input default_cfg (updated in-place)
|
|
|
|
|
kwargs: keyword args passed to model build fn (updated in-place)
|
|
|
|
|
kwargs_filter: keyword arg keys that must be removed before model __init__
|
|
|
|
|
"""
|
|
|
|
|
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
|
|
|
|
overlay_external_default_cfg(default_cfg, kwargs)
|
|
|
|
|
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
|
|
|
|
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
|
|
|
|
|
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
|
|
|
|
filter_kwargs(kwargs, names=kwargs_filter)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_model_with_cfg(
|
|
|
|
|
model_cls: Callable,
|
|
|
|
|
variant: str,
|
|
|
|
@ -399,29 +415,20 @@ def build_model_with_cfg(
|
|
|
|
|
pruned = kwargs.pop('pruned', False)
|
|
|
|
|
features = False
|
|
|
|
|
feature_cfg = feature_cfg or {}
|
|
|
|
|
default_cfg = deepcopy(default_cfg) if default_cfg else {}
|
|
|
|
|
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
|
|
|
|
|
default_cfg.setdefault('architecture', variant)
|
|
|
|
|
|
|
|
|
|
# Setup for featyre extraction wrapper done at end of this fn
|
|
|
|
|
# Setup for feature extraction wrapper done at end of this fn
|
|
|
|
|
if kwargs.pop('features_only', False):
|
|
|
|
|
features = True
|
|
|
|
|
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
|
|
|
|
if 'out_indices' in kwargs:
|
|
|
|
|
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
|
|
|
|
|
|
|
|
|
# FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
|
|
|
|
|
# could/should be replaced by an improved configuration mechanism
|
|
|
|
|
|
|
|
|
|
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
|
|
|
|
|
default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
|
|
|
|
|
|
|
|
|
|
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
|
|
|
|
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
|
|
|
|
|
|
|
|
|
|
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
|
|
|
|
filter_kwargs(kwargs, names=kwargs_filter)
|
|
|
|
|
|
|
|
|
|
# Build the model
|
|
|
|
|
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
|
|
|
|
|
model.default_cfg = deepcopy(default_cfg)
|
|
|
|
|
model.default_cfg = default_cfg
|
|
|
|
|
|
|
|
|
|
if pruned:
|
|
|
|
|
model = adapt_model_from_file(model, variant)
|
|
|
|
|