A few minor fixes and a bit more cleanup on the Hugging Face hub integration.

pull/501/head
Ross Wightman 4 years ago
parent ead80d33c5
commit 45c048ba13

@@ -1,7 +1,7 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
-from .hub import load_config_from_hf
+from .hub import load_model_config_from_hf


 def split_model_name(model_name):
@@ -67,11 +67,9 @@ def create_model(
     kwargs = {k: v for k, v in kwargs.items() if v is not None}

     if source_name == 'hf_hub':
-        # Load model weights + default_cfg from Hugging Face hub.
-        # For model names specified in the form `hf_hub:path/architecture_name#revision`
-        hf_default_cfg = load_config_from_hf(model_name)
-        hf_default_cfg['hf_hub'] = model_name  # insert hf_hub id for pretrained weight load during creation
-        model_name = hf_default_cfg.get('architecture')
+        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+        # load model weights + default_cfg from Hugging Face hub.
+        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
         kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday

     if is_model(model_name):
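In effect, a Hub-hosted model can now be created directly from a `hf_hub:`-prefixed name, with the architecture resolved from the repo's config.json rather than from the name itself. A minimal usage sketch (the repo id is purely illustrative):

    import timm

    # create_model routes 'hf_hub:' names through load_model_config_from_hf, which
    # returns both the default_cfg (with the hub id embedded) and the architecture name.
    model = timm.create_model('hf_hub:some-user/some-model', pretrained=True)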

@@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
     return default_cfg


-def overlay_external_default_cfg(kwargs, default_cfg):
-    """ Overlay 'default_cfg' in kwargs on top of default_cfg arg.
+def overlay_external_default_cfg(default_cfg, kwargs):
+    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
     """
-    default_cfg = default_cfg or {}
     external_default_cfg = kwargs.pop('external_default_cfg', None)
     if external_default_cfg:
-        default_cfg = deepcopy(default_cfg)
         default_cfg.pop('url', None)  # url should come from external cfg
         default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
         default_cfg.update(external_default_cfg)
-    return default_cfg


 def set_default_kwargs(kwargs, names, default_cfg):
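Note the changed contract above: overlay_external_default_cfg now mutates the default_cfg it is given instead of copying and returning it, so callers take their own copy first. A small sketch of the in-place behavior (all values invented for illustration):

    default_cfg = {'url': 'https://example.com/w.pth', 'num_classes': 1000}
    kwargs = {'external_default_cfg': {'hf_hub': 'some-user/some-model', 'num_classes': 10}}
    overlay_external_default_cfg(default_cfg, kwargs)
    # default_cfg is now {'num_classes': 10, 'hf_hub': 'some-user/some-model'}:
    # 'url' was dropped in favor of the external cfg, and 'external_default_cfg'
    # was popped out of kwargs.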
@@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
                 assert len(input_size) == 3
-                kwargs.setdefault(n, input_size[:-2])
+                kwargs.setdefault(n, input_size[-2:])
         elif n == 'in_chans':
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
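The one-character slice change above is a genuine bug fix: with the default_cfg convention input_size = (chans, height, width), the old slice returned the channel dim rather than the spatial dims:

    input_size = (3, 224, 224)   # (chans, height, width)
    input_size[:-2]              # (3,)        -- old: everything but the last two, i.e. channels
    input_size[-2:]              # (224, 224)  -- fixed: the (H, W) pair img_size should default to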
@@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
         kwargs.pop(n, None)


+def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+    """ Update the default_cfg and kwargs before passing to model
+
+    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+    could/should be replaced by an improved configuration mechanism
+
+    Args:
+        default_cfg: input default_cfg (updated in-place)
+        kwargs: keyword args passed to model build fn (updated in-place)
+        kwargs_filter: keyword arg keys that must be removed before model __init__
+    """
+    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+    overlay_external_default_cfg(default_cfg, kwargs)
+    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
+    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+    filter_kwargs(kwargs, names=kwargs_filter)
+
+
 def build_model_with_cfg(
         model_cls: Callable,
         variant: str,
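The new helper simply bundles the three pre-existing steps (overlay, set defaults, filter) so build_model_with_cfg below can apply them as a single in-place update; both dicts are mutated and nothing is returned. A rough sketch (values invented):

    default_cfg = {'num_classes': 1000, 'input_size': (3, 224, 224)}
    kwargs = {}
    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter=None)
    # kwargs now holds num_classes=1000 and in_chans=3 derived from default_cfg;
    # a kwargs_filter of None leaves kwargs unfiltered.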
@@ -399,29 +415,20 @@ def build_model_with_cfg(
     pruned = kwargs.pop('pruned', False)
     features = False
     feature_cfg = feature_cfg or {}
+    default_cfg = deepcopy(default_cfg) if default_cfg else {}
+    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+    default_cfg.setdefault('architecture', variant)

-    # Setup for featyre extraction wrapper done at end of this fn
+    # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
         features = True
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')

-    # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
-    # could/should be replaced by an improved configuration mechanism
-    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
-    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
-    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
-    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
-    filter_kwargs(kwargs, names=kwargs_filter)
-
     # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
-    model.default_cfg = deepcopy(default_cfg)
+    model.default_cfg = default_cfg

     if pruned:
         model = adapt_model_from_file(model, variant)
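Also worth noting: the deepcopy has moved from the final assignment to the top of the function, so default_cfg is copied exactly once per build, the caller's dict is never mutated, and model.default_cfg can take the already-private copy directly. A hedged sketch of the resulting invariants (MyModel and the cfg values are hypothetical):

    my_cfg = {'num_classes': 1000, 'input_size': (3, 224, 224)}
    model = build_model_with_cfg(MyModel, 'my_variant', pretrained=False, default_cfg=my_cfg)
    assert model.default_cfg is not my_cfg                      # copied once up front
    assert model.default_cfg['architecture'] == 'my_variant'    # recorded via setdefault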

@@ -23,7 +23,7 @@ except ImportError:
 _logger = logging.getLogger(__name__)


-def get_cache_dir(child=''):
+def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
     """
@@ -32,8 +32,8 @@ def get_cache_dir(child=''):
         _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
     hub_dir = get_dir()
-    children = () if not child else child,
-    model_dir = os.path.join(hub_dir, 'checkpoints', *children)
+    child_dir = () if not child_dir else (child_dir,)
+    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
     os.makedirs(model_dir, exist_ok=True)
     return model_dir
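The fix above is for an operator-precedence bug: the trailing comma in the old line tuple-wrapped the whole conditional expression, not just the child value, so the empty-string default broke os.path.join:

    child = ''
    children = () if not child else child,             # old: parses as (() if not child else child),
    assert children == ((),)                           # so *children hands a bare () to os.path.join
    child_dir = ''
    child_dir = () if not child_dir else (child_dir,)  # fixed: explicit tuple on the else branch
    assert child_dir == ()                             # empty case now expands to no extra path parts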
@@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
     return cached_download(url, cache_dir=get_cache_dir('hf'))


-def load_config_from_hf(model_id: str):
+def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    return load_cfg_from_json(cached_file)
+    default_cfg = load_cfg_from_json(cached_file)
+    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+    model_name = default_cfg.get('architecture')
+    return default_cfg, model_name


 def load_state_dict_from_hf(model_id: str):
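load_model_config_from_hf now returns both pieces create_model needs, with the hub id stitched into the cfg for the later weight fetch. A sketch of the new contract (repo id illustrative):

    default_cfg, model_name = load_model_config_from_hf('some-user/some-model')
    # default_cfg['hf_hub'] == 'some-user/some-model'  -- used to download weights later
    # model_name == default_cfg.get('architecture')    -- selects the registered entrypoint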

@@ -21,6 +21,7 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from copy import deepcopy

 import torch
 import torch.nn as nn
@@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):

 def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
+    default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-1]
+    default_img_size = default_cfg['input_size'][-2:]
     num_classes = kwargs.pop('num_classes', default_num_classes)
     img_size = kwargs.pop('img_size', default_img_size)
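The deepcopy here matters because default_cfgs[variant] is a module-level dict shared by every call, and overlay_external_default_cfg now mutates its argument in place; without a private copy, one hf_hub build would taint all later ones. A standalone illustration of the hazard being avoided (cfg values invented):

    from copy import deepcopy

    shared = {'input_size': (3, 224, 224), 'num_classes': 1000}  # stands in for default_cfgs[variant]
    cfg = deepcopy(shared)
    cfg.update({'num_classes': 10, 'hf_hub': 'some-user/some-model'})  # per-call overlay
    assert shared['num_classes'] == 1000   # the shared table stays pristine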
