@@ -17,7 +17,7 @@ try:
 except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
-from huggingface_hub import cached_download
+from huggingface_hub import cached_download, hf_hub_url
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
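The widened import brings in hf_hub_url, which builds the download URL for a file in a Hugging Face Hub repo; load_pretrained below pairs it with cached_download. A minimal sketch with a hypothetical repo id (the exact URL shape may vary across huggingface_hub versions):

    from huggingface_hub import hf_hub_url

    # Hypothetical repo id, for illustration only.
    url = hf_hub_url('someuser/some-model', 'pytorch_model.pth', revision=None)
    # e.g. 'https://huggingface.co/someuser/some-model/resolve/main/pytorch_model.pth'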
@@ -138,10 +138,17 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         return
     url = cfg['url']
 
-    # TODO, progress and check_hash are ignored.
-    cached_filed = cached_download(
-        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
+    model_dir = get_cache_dir()
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
 
     if load_fn is not None:
         load_fn(model, cached_file)
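The restored block reimplements torch.hub-style caching: the checkpoint is stored in get_cache_dir() under its URL basename, and an optional hash prefix embedded in that filename is verified during download. A self-contained sketch of the filename/hash convention, using an invented URL (HASH_REGEX here mirrors the torch.hub pattern):

    import os
    import re
    from urllib.parse import urlparse

    HASH_REGEX = re.compile(r'-([a-f0-9]*)\.')  # same pattern torch.hub uses

    url = 'https://example.com/checkpoints/resnet50-19c8e357.pth'  # invented URL
    filename = os.path.basename(urlparse(url).path)  # 'resnet50-19c8e357.pth'
    match = HASH_REGEX.search(filename)
    hash_prefix = match.group(1) if match else None  # '19c8e357'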
@@ -179,15 +186,19 @@ def adapt_input_conv(in_chans, conv_weight):
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    if cfg is None or (('url' not in cfg or not cfg['url']) and ('hf_checkpoint' not in cfg or not cfg['hf_checkpoint'])):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
-    # TODO, progress is ignored.
-    cached_filed = cached_download(
-        cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
-    state_dict = torch.load(cached_filed, map_location='cpu')
+    if cfg.get('hf_checkpoint') is not None:
+        # TODO, progress is ignored.
+        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        cached_file = cached_download(
+            url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+        )
+        state_dict = torch.load(cached_file, map_location='cpu')
+    else:
+        state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
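With this hunk, a model's default_cfg can point at the Hub instead of a raw URL. A minimal sketch of the new path, assuming a hypothetical repo id; the cfg keys and calls mirror the ones load_pretrained makes above:

    import torch
    from huggingface_hub import cached_download, hf_hub_url

    cfg = {
        'hf_checkpoint': 'someuser/some-model',  # hypothetical Hub repo id
        'hf_revision': None,                     # optional branch, tag, or commit
    }
    url = hf_hub_url(cfg['hf_checkpoint'], 'pytorch_model.pth', revision=cfg.get('hf_revision'))
    cached_file = cached_download(url, library_name='timm')
    state_dict = torch.load(cached_file, map_location='cpu')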