Merge pull request #1 from sgugger/add_hf_hub

PoC for using HF model hub
pull/440/head
Sylvain Gugger 4 years ago committed by GitHub
commit 8c270ed7bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
from .registry import is_model, is_model_in_modules, model_entrypoint from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint from .helpers import load_checkpoint, load_hf_checkpoint_config
from .layers import set_layer_config from .layers import set_layer_config
@ -52,7 +52,16 @@ def create_model(
create_fn = model_entrypoint(model_name) create_fn = model_entrypoint(model_name)
model = create_fn(**model_args, **kwargs) model = create_fn(**model_args, **kwargs)
else: else:
raise RuntimeError('Unknown model (%s)' % model_name) try:
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
create_fn = model_entrypoint(model_cfg.pop("architecture"))
model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
# Probably need some extra stuff, but this is a PoC of how the config in the model hub
# could overwrite the default config values.
# model.default_cfg.update(model_cfg)
except Exception as e:
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)
if checkpoint_path: if checkpoint_path:
load_checkpoint(model, checkpoint_path) load_checkpoint(model, checkpoint_path)

@ -3,11 +3,12 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
import logging import logging
import json
import os import os
import math import math
from collections import OrderedDict from collections import OrderedDict
from copy import deepcopy from copy import deepcopy
from typing import Callable from typing import Callable, Optional, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -17,13 +18,31 @@ try:
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
from huggingface_hub import cached_download, hf_hub_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear from .layers import Conv2dSame, Linear
from ..version import __version__
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
def get_cache_dir():
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
return model_dir
def load_state_dict(checkpoint_path, use_ema=False): def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')
@ -120,15 +139,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
return return
url = cfg['url'] url = cfg['url']
# Issue warning to move data if old env is set model_dir = get_cache_dir()
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url) parts = urlparse(url)
filename = os.path.basename(parts.path) filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename) cached_file = os.path.join(model_dir, filename)
@ -173,14 +184,27 @@ def adapt_input_conv(in_chans, conv_weight):
return conv_weight return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
progress=False, hf_checkpoint=None, hf_revision=None):
if cfg is None: if cfg is None:
cfg = getattr(model, 'default_cfg') cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']: if hf_checkpoint is None:
hg_checkpoint = cfg.get('hf_checkpoint')
if hf_revision is None:
hg_revision = cfg.get('hf_revision')
if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.") _logger.warning("No pretrained weights exist for this model. Using random initialization.")
return return
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') if hf_checkpoint is not None:
# TODO, progress is ignored.
url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
cached_filed = cached_download(
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
state_dict = torch.load(cached_filed, map_location='cpu')
else:
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
if filter_fn is not None: if filter_fn is not None:
state_dict = filter_fn(state_dict) state_dict = filter_fn(state_dict)
@ -336,6 +360,8 @@ def build_model_with_cfg(
pretrained_custom_load: bool = False, pretrained_custom_load: bool = False,
**kwargs): **kwargs):
pruned = kwargs.pop('pruned', False) pruned = kwargs.pop('pruned', False)
hf_checkpoint = kwargs.pop('hf_checkpoint', None)
hf_revision = kwargs.pop('hf_revision', None)
features = False features = False
feature_cfg = feature_cfg or {} feature_cfg = feature_cfg or {}
@ -353,14 +379,15 @@ def build_model_with_cfg(
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained: if pretrained or hf_checkpoint is not None:
if pretrained_custom_load: if pretrained_custom_load:
load_custom_pretrained(model) load_custom_pretrained(model)
else: else:
load_pretrained( load_pretrained(
model, model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict) filter_fn=pretrained_filter_fn, strict=pretrained_strict,
hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
if features: if features:
feature_cls = FeatureListNet feature_cls = FeatureListNet
@ -376,3 +403,17 @@ def build_model_with_cfg(
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model return model
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
url = hf_hub_url(checkpoint, "config.json", revision=revision)
cached_filed = cached_download(
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
return load_cfg_from_json(cached_filed)

@ -37,6 +37,7 @@ def register_model(fn):
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos # entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
if has_pretrained: if has_pretrained:
_model_has_pretrained.add(model_name) _model_has_pretrained.add(model_name)
return fn return fn

@ -53,6 +53,7 @@ default_cfgs = {
interpolation='bicubic'), interpolation='bicubic'),
'resnet50d': _cfg( 'resnet50d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
hf_revision="master", hf_model_id="sgugger/resnet50d",
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
'resnet101': _cfg(url='', interpolation='bicubic'), 'resnet101': _cfg(url='', interpolation='bicubic'),
'resnet101d': _cfg( 'resnet101d': _cfg(

Loading…
Cancel
Save