@@ -184,19 +184,19 @@ def adapt_input_conv(in_chans, conv_weight):
 
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
-                    progress=False, hf_checkpoint=None, hf_revision=None):
+                    progress=False, hf_model_id=None, hf_revision=None):
     cfg = cfg or getattr(model, 'default_cfg')
-    if hf_checkpoint is None:
-        hf_checkpoint = cfg.get('hf_checkpoint')
+    if hf_model_id is None:
+        hf_model_id = cfg.get('hf_model_id')
     if hf_revision is None:
         hf_revision = cfg.get('hf_revision')
-    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_model_id is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
-    if hf_checkpoint is not None:
+    if hf_model_id is not None:
         # TODO: progress is ignored.
-        url = hf_hub_url(hf_checkpoint, "pytorch_model.bin", revision=hf_revision)
+        url = hf_hub_url(hf_model_id, "pytorch_model.bin", revision=hf_revision)
         cached_file = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
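
For orientation outside the diff: the hunk above resolves a hub model id to a direct file URL with hf_hub_url and pulls it through huggingface_hub's cache with cached_download. A minimal standalone sketch of that flow, assuming the legacy hf_hub_url/cached_download API this branch imports; the repo id below is hypothetical, and cache_dir=None uses the hub's default cache rather than timm's get_cache_dir():

    import torch
    from huggingface_hub import hf_hub_url, cached_download

    # Direct URL to the weights file inside a (hypothetical) hub repo.
    url = hf_hub_url("someuser/resnet50d-example", "pytorch_model.bin", revision=None)
    # Download into the hub cache (reused on later calls), then load the weights.
    cached_file = cached_download(url, library_name="timm", cache_dir=None)
    state_dict = torch.load(cached_file, map_location="cpu")
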
@@ -358,7 +358,7 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
-    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_model_id = kwargs.pop('hf_model_id', None)
     hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}
@@ -377,7 +377,7 @@ def build_model_with_cfg(
 
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained or hf_checkpoint is not None:
+    if pretrained or hf_model_id is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
@@ -385,7 +385,7 @@ def build_model_with_cfg(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                 filter_fn=pretrained_filter_fn, strict=pretrained_strict,
-                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
+                hf_model_id=hf_model_id, hf_revision=hf_revision)
 
     if features:
         feature_cls = FeatureListNet
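
End to end, the rename means callers pass a hub model id rather than a "checkpoint". A hedged usage sketch, assuming timm.create_model forwards extra kwargs through the model entrypoint into build_model_with_cfg as this branch intends; the repo id below is a placeholder, not a published model:

    import timm

    # hf_model_id travels create_model -> model entrypoint -> build_model_with_cfg,
    # which pops it from kwargs and hands it to load_pretrained.
    model = timm.create_model('resnet50', hf_model_id='someuser/resnet50-weights')
    model.eval()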