Merge pull request #1 from sgugger/add_hf_hub

PoC for using HF model hub
pull/440/head
Sylvain Gugger 4 years ago committed by GitHub
commit 8c270ed7bf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,5 +1,5 @@
from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint
from .helpers import load_checkpoint, load_hf_checkpoint_config
from .layers import set_layer_config
@ -52,7 +52,16 @@ def create_model(
create_fn = model_entrypoint(model_name)
model = create_fn(**model_args, **kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
try:
model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
create_fn = model_entrypoint(model_cfg.pop("architecture"))
model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
# Probably need some extra stuff, but this is a PoC of how the config in the model hub
# could overwrite the default config values.
# model.default_cfg.update(model_cfg)
except Exception as e:
raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)

@ -3,11 +3,12 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import json
import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
@ -17,13 +18,31 @@ try:
except ImportError:
from torch.hub import _get_torch_home as get_dir
from huggingface_hub import cached_download, hf_hub_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .layers import Conv2dSame, Linear
from ..version import __version__
_logger = logging.getLogger(__name__)
def get_cache_dir():
    """
    Return the path of the checkpoint cache directory, creating it on demand.

    The directory is the 'checkpoints' sub-folder of whatever `get_dir()`
    reports as the hub directory.
    """
    # The legacy env variable is only reported, it does not influence the path.
    if os.environ.get('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    checkpoints_dir = os.path.join(get_dir(), 'checkpoints')
    # exist_ok=True makes repeated calls harmless once the folder is there.
    os.makedirs(checkpoints_dir, exist_ok=True)
    return checkpoints_dir
def load_state_dict(checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
@ -120,15 +139,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
return
url = cfg['url']
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
model_dir = get_cache_dir()
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
@ -173,13 +184,26 @@ def adapt_input_conv(in_chans, conv_weight):
return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
progress=False, hf_checkpoint=None, hf_revision=None):
if cfg is None:
cfg = getattr(model, 'default_cfg')
if cfg is None or 'url' not in cfg or not cfg['url']:
if hf_checkpoint is None:
hg_checkpoint = cfg.get('hf_checkpoint')
if hf_revision is None:
hg_revision = cfg.get('hf_revision')
if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if hf_checkpoint is not None:
# TODO, progress is ignored.
url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
cached_filed = cached_download(
url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
)
state_dict = torch.load(cached_filed, map_location='cpu')
else:
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
if filter_fn is not None:
state_dict = filter_fn(state_dict)
@ -336,6 +360,8 @@ def build_model_with_cfg(
pretrained_custom_load: bool = False,
**kwargs):
pruned = kwargs.pop('pruned', False)
hf_checkpoint = kwargs.pop('hf_checkpoint', None)
hf_revision = kwargs.pop('hf_revision', None)
features = False
feature_cfg = feature_cfg or {}
@ -353,14 +379,15 @@ def build_model_with_cfg(
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained or hf_checkpoint is not None:
if pretrained_custom_load:
load_custom_pretrained(model)
else:
load_pretrained(
model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
filter_fn=pretrained_filter_fn, strict=pretrained_strict,
hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
if features:
feature_cls = FeatureListNet
@ -376,3 +403,17 @@ def build_model_with_cfg(
model.default_cfg = default_cfg_for_features(default_cfg) # add back default_cfg
return model
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    """
    Parse a UTF-8 encoded JSON file and return the decoded Python object.

    Args:
        json_file: path of the JSON file to read.

    Returns:
        Whatever the JSON document decodes to (a dict for a JSON object).
    """
    with open(json_file, "r", encoding="utf-8") as cfg_handle:
        # json.load reads the whole stream and decodes it in one step.
        return json.load(cfg_handle)
def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
    """
    Fetch the `config.json` of a Hugging Face hub checkpoint and return it parsed.

    Args:
        checkpoint: hub identifier passed to `hf_hub_url`.
        revision: optional revision forwarded to `hf_hub_url`.

    Returns:
        The decoded content of the downloaded `config.json`.
    """
    config_url = hf_hub_url(checkpoint, "config.json", revision=revision)
    # The file is stored in the same cache folder that get_cache_dir() manages;
    # library name/version are handed to the download helper as-is.
    download_options = dict(
        library_name="timm",
        library_version=__version__,
        cache_dir=get_cache_dir(),
    )
    config_path = cached_download(config_url, **download_options)
    return load_cfg_from_json(config_path)

@ -37,6 +37,7 @@ def register_model(fn):
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
if has_pretrained:
_model_has_pretrained.add(model_name)
return fn

@ -53,6 +53,7 @@ default_cfgs = {
interpolation='bicubic'),
'resnet50d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
hf_revision="master", hf_model_id="sgugger/resnet50d",
interpolation='bicubic', first_conv='conv1.0'),
'resnet101': _cfg(url='', interpolation='bicubic'),
'resnet101d': _cfg(

Loading…
Cancel
Save