Update helpers.py

load pretrained from local file
pull/967/head
jim4399266 4 years ago committed by GitHub
parent a457b6d14d
commit caf28b6bc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -180,10 +180,14 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
pretrained_url = default_cfg.get('url', None)
hf_hub_id = default_cfg.get('hf_hub', None)
pth_file_path = default_cfg.get('pth_path', None)
if not pretrained_url and not hf_hub_id:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
if pth_file_path is not None:
# load pretrained from local file
state_dict = torch.load(pth_file_path, map_location='cpu')
elif hf_hub_id and has_hf_hub(necessary=not pretrained_url):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
else:

Loading…
Cancel
Save