A few minor fixes and bit more cleanup on the huggingface hub integration.

pull/501/head
Ross Wightman 3 years ago
parent ead80d33c5
commit 45c048ba13

@ -1,7 +1,7 @@
from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint
from .layers import set_layer_config
from .hub import load_config_from_hf
from .hub import load_model_config_from_hf
def split_model_name(model_name):
@ -67,11 +67,9 @@ def create_model(
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if source_name == 'hf_hub':
# Load model weights + default_cfg from Hugging Face hub.
# For model names specified in the form `hf_hub:path/architecture_name#revision`
hf_default_cfg = load_config_from_hf(model_name)
hf_default_cfg['hf_hub'] = model_name # insert hf_hub id for pretrained weight load during creation
model_name = hf_default_cfg.get('architecture')
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
# load model weights + default_cfg from Hugging Face hub.
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
if is_model(model_name):

@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
return default_cfg
def overlay_external_default_cfg(kwargs, default_cfg):
""" Overlay 'default_cfg' in kwargs on top of default_cfg arg.
def overlay_external_default_cfg(default_cfg, kwargs):
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
"""
default_cfg = default_cfg or {}
external_default_cfg = kwargs.pop('external_default_cfg', None)
if external_default_cfg:
default_cfg = deepcopy(default_cfg)
default_cfg.pop('url', None) # url should come from external cfg
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
default_cfg.update(external_default_cfg)
return default_cfg
def set_default_kwargs(kwargs, names, default_cfg):
@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
input_size = default_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[:-2])
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = default_cfg.get('input_size', None)
if input_size is not None:
@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
kwargs.pop(n, None)
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
could/should be replaced by an improved configuration mechanism
Args:
default_cfg: input default_cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__
"""
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)
def build_model_with_cfg(
model_cls: Callable,
variant: str,
@ -399,29 +415,20 @@ def build_model_with_cfg(
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
default_cfg = deepcopy(default_cfg) if default_cfg else {}
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
default_cfg.setdefault('architecture', variant)
# Setup for featyre extraction wrapper done at end of this fn
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
# FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
# could/should be replaced by an improved configuration mechanism
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)
# Build the model
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
model.default_cfg = deepcopy(default_cfg)
model.default_cfg = default_cfg
if pruned:
model = adapt_model_from_file(model, variant)

@ -23,7 +23,7 @@ except ImportError:
_logger = logging.getLogger(__name__)
def get_cache_dir(child=''):
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
@ -32,8 +32,8 @@ def get_cache_dir(child=''):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
children = () if not child else child,
model_dir = os.path.join(hub_dir, 'checkpoints', *children)
child_dir = () if not child_dir else (child_dir,)
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
os.makedirs(model_dir, exist_ok=True)
return model_dir
@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
return cached_download(url, cache_dir=get_cache_dir('hf'))
def load_config_from_hf(model_id: str):
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
return load_cfg_from_json(cached_file)
default_cfg = load_cfg_from_json(cached_file)
default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
model_name = default_cfg.get('architecture')
return default_cfg, model_name
def load_state_dict_from_hf(model_id: str):

@ -21,6 +21,7 @@ import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-1]
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)

Loading…
Cancel
Save