Be more opt-in

pull/440/head
Sylvain Gugger 5 years ago
parent ebe69dd4d3
commit 9857e12c0c

@ -17,7 +17,7 @@ try:
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
from huggingface_hub import cached_download from huggingface_hub import cached_download, hf_hub_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear from .layers import Conv2dSame, Linear
@ -138,10 +138,17 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
return return
url = cfg['url'] url = cfg['url']
# TODO, progress and check_hash are ignored. model_dir = get_cache_dir()
cached_filed = cached_download( parts = urlparse(url)
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir() filename = os.path.basename(parts.path)
) cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
if load_fn is not None: if load_fn is not None:
load_fn(model, cached_file) load_fn(model, cached_file)
@ -179,15 +186,19 @@ def adapt_input_conv(in_chans, conv_weight):
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
if cfg is None: if cfg is None:
cfg = getattr(model, 'default_cfg') cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']: if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
_logger.warning("No pretrained weights exist for this model. Using random initialization.") _logger.warning("No pretrained weights exist for this model. Using random initialization.")
return return
if cfg.get('hf_checkpoint') is not None:
# TODO, progress is ignored. # TODO, progress is ignored.
url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
cached_filed = cached_download( cached_filed = cached_download(
cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir() url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
) )
state_dict = torch.load(cached_filed, map_location='cpu') state_dict = torch.load(cached_filed, map_location='cpu')
else:
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
if filter_fn is not None: if filter_fn is not None:
state_dict = filter_fn(state_dict) state_dict = filter_fn(state_dict)

@ -37,6 +37,7 @@ def register_model(fn):
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos # entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
if has_pretrained: if has_pretrained:
_model_has_pretrained.add(model_name) _model_has_pretrained.add(model_name)
return fn return fn

@ -52,7 +52,8 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
interpolation='bicubic'), interpolation='bicubic'),
'resnet50d': _cfg( 'resnet50d': _cfg(
url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
hf_revision="master", hf_model_id="sgugger/resnet50d",
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
'resnet101': _cfg(url='', interpolation='bicubic'), 'resnet101': _cfg(url='', interpolation='bicubic'),
'resnet101d': _cfg( 'resnet101d': _cfg(

Loading…
Cancel
Save