|
|
|
@ -180,10 +180,14 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
|
|
|
|
|
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
|
|
|
|
|
pretrained_url = default_cfg.get('url', None)
|
|
|
|
|
hf_hub_id = default_cfg.get('hf_hub', None)
|
|
|
|
|
pth_file_path = default_cfg.get('pth_path', None)
|
|
|
|
|
if not pretrained_url and not hf_hub_id:
|
|
|
|
|
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
|
|
|
|
return
|
|
|
|
|
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
|
|
|
|
|
if pth_file_path is not None:
|
|
|
|
|
# load pretrained from local file
|
|
|
|
|
state_dict = torch.load(pth_file_path, map_location='cpu')
|
|
|
|
|
elif hf_hub_id and has_hf_hub(necessary=not pretrained_url):
|
|
|
|
|
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
|
|
|
|
|
state_dict = load_state_dict_from_hf(hf_hub_id)
|
|
|
|
|
else:
|
|
|
|
|