From 45c048ba135f120e7a1928438e378b7b68ae16f3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 17 Mar 2021 13:18:52 -0700
Subject: [PATCH] A few minor fixes and a bit more cleanup on the huggingface
 hub integration.

---
 timm/models/factory.py            | 10 +++----
 timm/models/helpers.py            | 47 ++++++++++++++++++-------------
 timm/models/hub.py                | 13 +++++----
 timm/models/vision_transformer.py |  6 ++--
 4 files changed, 43 insertions(+), 33 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index d8cca264..d040a9ff 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -1,7 +1,7 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
 from .helpers import load_checkpoint
 from .layers import set_layer_config
-from .hub import load_config_from_hf
+from .hub import load_model_config_from_hf
 
 
 def split_model_name(model_name):
@@ -67,11 +67,9 @@ def create_model(
     kwargs = {k: v for k, v in kwargs.items() if v is not None}
 
     if source_name == 'hf_hub':
-        # Load model weights + default_cfg from Hugging Face hub.
-        # For model names specified in the form `hf_hub:path/architecture_name#revision`
-        hf_default_cfg = load_config_from_hf(model_name)
-        hf_default_cfg['hf_hub'] = model_name  # insert hf_hub id for pretrained weight load during creation
-        model_name = hf_default_cfg.get('architecture')
+        # For model names specified in the form `hf_hub:path/architecture_name#revision`,
+        # load model weights + default_cfg from Hugging Face hub.
+        hf_default_cfg, model_name = load_model_config_from_hf(model_name)
         kwargs['external_default_cfg'] = hf_default_cfg  # FIXME revamp default_cfg interface someday
 
     if is_model(model_name):
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 8e2809f3..2f6e098b 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -323,17 +323,14 @@ def default_cfg_for_features(default_cfg):
     return default_cfg
 
 
-def overlay_external_default_cfg(kwargs, default_cfg):
-    """ Overlay 'default_cfg' in kwargs on top of default_cfg arg.
+def overlay_external_default_cfg(default_cfg, kwargs):
+    """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
     """
-    default_cfg = default_cfg or {}
     external_default_cfg = kwargs.pop('external_default_cfg', None)
     if external_default_cfg:
-        default_cfg = deepcopy(default_cfg)
         default_cfg.pop('url', None)  # url should come from external cfg
         default_cfg.pop('hf_hub', None)  # hf hub id should come from external cfg
         default_cfg.update(external_default_cfg)
-    return default_cfg
 
 
 def set_default_kwargs(kwargs, names, default_cfg):
@@ -344,7 +341,7 @@ def set_default_kwargs(kwargs, names, default_cfg):
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
                 assert len(input_size) == 3
-                kwargs.setdefault(n, input_size[:-2])
+                kwargs.setdefault(n, input_size[-2:])
         elif n == 'in_chans':
             input_size = default_cfg.get('input_size', None)
             if input_size is not None:
@@ -363,6 +360,25 @@ def filter_kwargs(kwargs, names):
         kwargs.pop(n, None)
 
 
+def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
+    """ Update the default_cfg and kwargs before passing to model
+
+    FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
+    could/should be replaced by an improved configuration mechanism
+
+    Args:
+        default_cfg: input default_cfg (updated in-place)
+        kwargs: keyword args passed to model build fn (updated in-place)
+        kwargs_filter: keyword arg keys that must be removed before model __init__
+    """
+    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
+    overlay_external_default_cfg(default_cfg, kwargs)
+    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
+    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
+    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
+    filter_kwargs(kwargs, names=kwargs_filter)
+
+
 def build_model_with_cfg(
     model_cls: Callable,
     variant: str,
@@ -399,29 +415,20 @@ def build_model_with_cfg(
     pruned = kwargs.pop('pruned', False)
     features = False
    feature_cfg = feature_cfg or {}
+    default_cfg = deepcopy(default_cfg) if default_cfg else {}
+    update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
+    default_cfg.setdefault('architecture', variant)
 
-    # Setup for featyre extraction wrapper done at end of this fn
+    # Setup for feature extraction wrapper done at end of this fn
     if kwargs.pop('features_only', False):
         features = True
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')
 
-    # FIXME this next sequence of overlay default_cfg, set default kwargs, filter kwargs
-    # could/should be replaced by an improved configuration mechanism
-
-    # Overlay default cfg values from `external_default_cfg` if it exists in kwargs
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfg)
-
-    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
-    set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
-
-    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
-    filter_kwargs(kwargs, names=kwargs_filter)
-
     # Build the model
     model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
-    model.default_cfg = deepcopy(default_cfg)
+    model.default_cfg = default_cfg
 
     if pruned:
         model = adapt_model_from_file(model, variant)
diff --git a/timm/models/hub.py b/timm/models/hub.py
index e255e39b..69365a0c 100644
--- a/timm/models/hub.py
+++ b/timm/models/hub.py
@@ -23,7 +23,7 @@ except ImportError:
 _logger = logging.getLogger(__name__)
 
 
-def get_cache_dir(child=''):
+def get_cache_dir(child_dir=''):
     """
     Returns the location of the directory where models are cached (and creates it if necessary).
     """
@@ -32,8 +32,8 @@ def get_cache_dir(child_dir=''):
         _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
 
     hub_dir = get_dir()
-    children = () if not child else child,
-    model_dir = os.path.join(hub_dir, 'checkpoints', *children)
+    child_dir = () if not child_dir else (child_dir,)
+    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
     os.makedirs(model_dir, exist_ok=True)
     return model_dir
 
@@ -80,10 +80,13 @@ def _download_from_hf(model_id: str, filename: str):
     return cached_download(url, cache_dir=get_cache_dir('hf'))
 
 
-def load_config_from_hf(model_id: str):
+def load_model_config_from_hf(model_id: str):
     assert has_hf_hub(True)
     cached_file = _download_from_hf(model_id, 'config.json')
-    return load_cfg_from_json(cached_file)
+    default_cfg = load_cfg_from_json(cached_file)
+    default_cfg['hf_hub'] = model_id  # insert hf_hub id for pretrained weight load during model creation
+    model_name = default_cfg.get('architecture')
+    return default_cfg, model_name
 
 
 def load_state_dict_from_hf(model_id: str):
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b0bbdb09..82d4ee49 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -21,6 +21,7 @@ import math
 import logging
 from functools import partial
 from collections import OrderedDict
+from copy import deepcopy
 
 import torch
 import torch.nn as nn
@@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
-    default_cfg = overlay_external_default_cfg(kwargs, default_cfgs[variant])
+    default_cfg = deepcopy(default_cfgs[variant])
+    overlay_external_default_cfg(default_cfg, kwargs)
     default_num_classes = default_cfg['num_classes']
-    default_img_size = default_cfg['input_size'][-1]
+    default_img_size = default_cfg['input_size'][-2:]
 
     num_classes = kwargs.pop('num_classes', default_num_classes)
     img_size = kwargs.pop('img_size', default_img_size)
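To make the new `hf_hub` flow concrete, here is a minimal usage sketch. The hub id is hypothetical (any repo hosting a timm-style `config.json` plus weights should behave the same), and it assumes the optional `huggingface_hub` dependency is installed:

```python
import timm

# The 'hf_hub:' prefix routes create_model() through load_model_config_from_hf(),
# which returns (default_cfg, architecture_name) and records the hub id in the
# cfg so pretrained weights are fetched from the hub during creation.
# NOTE: 'someuser/vit-base-custom' is a placeholder id, not a published repo.
model = timm.create_model('hf_hub:someuser/vit-base-custom', pretrained=True)

# The external cfg has been overlaid onto the model's default_cfg in-place.
print(model.default_cfg['hf_hub'])  # the recorded hub id
```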
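The `input_size` slice change in `set_default_kwargs` is a genuine bug fix rather than cleanup: `default_cfg['input_size']` follows the `(C, H, W)` convention, so the old slice handed `img_size` the channel count instead of the spatial dims:

```python
input_size = (3, 224, 224)  # default_cfg convention: (C, H, W)
print(input_size[:-2])      # (3,)        old slice: channel dim only, wrong
print(input_size[-2:])      # (224, 224)  new slice: (H, W), a valid img_size
```

The same convention explains the vision_transformer change, where `default_img_size` becomes the `(H, W)` pair instead of a single int.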
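Likewise, the `get_cache_dir` change is more than a parameter rename. In the removed line the trailing comma bound the whole conditional expression into a 1-tuple, so the default empty argument produced `((),)` and `os.path.join` would have received a bare tuple:

```python
child = ''
old = () if not child else child,    # ((),) -- the comma wraps the conditional
new = () if not child else (child,)  # ()    -- unpacks to no extra path parts
# os.path.join('checkpoints', *old) raises TypeError; *new contributes nothing.
```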