You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/_builder.py

396 lines
16 KiB

import dataclasses
import logging
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
from torch import nn as nn
from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg
# Module-level logger used for all info/warning messages emitted while building / loading models.
_logger = logging.getLogger(__name__)

# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
# Both flags are forwarded to the URL download calls in load_custom_pretrained and load_pretrained.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False
def _resolve_pretrained_source(pretrained_cfg):
cfg_source = pretrained_cfg.get('source', '')
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
# hf-hub specified as source via model identifier
load_from = 'hf-hub'
assert hf_hub_id
pretrained_loc = hf_hub_id
else:
# default source == timm or unspecified
if pretrained_file:
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
return load_from, pretrained_loc
def set_pretrained_download_progress(enable=True):
    """Turn the download progress display for pretrained weights on or off, module-wide.

    Args:
        enable: value stored in the module-level ``_DOWNLOAD_PROGRESS`` flag (default True).
    """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable
def set_pretrained_check_hash(enable=True):
    """Turn hash verification of downloaded pretrained weights on or off, module-wide.

    Args:
        enable: value stored in the module-level ``_CHECK_HASH`` flag (default True).
    """
    global _CHECK_HASH
    _CHECK_HASH = enable
def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Load weights from a custom (read: non .pth) checkpoint file.

    The checkpoint location is resolved from the pretrained cfg. URL checkpoints are fetched
    through ``download_cached_file`` (honouring the module-level hash-check / progress flags).
    The resolved location is then handed to either the supplied ``load_fn`` or the model's own
    ``load_pretrained`` method, instead of being deserialized here.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Default pretrained model cfg; falls back to ``model.pretrained_cfg``
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
    """
    cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(cfg)
    if not load_from:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return

    if load_from == 'url':
        # swap the URL for the locally cached file path before handing it to the loader
        pretrained_loc = download_cached_file(
            pretrained_loc,
            check_hash=_CHECK_HASH,
            progress=_DOWNLOAD_PROGRESS
        )
    elif load_from == 'hf-hub':  # FIXME
        # NOTE(review): only a warning is issued; execution still falls through to the loader below
        # with the hub id (or (hub_id, filename) tuple) as the location.
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")

    if load_fn is not None:
        load_fn(model, pretrained_loc)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(pretrained_loc)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
def _adapt_input_convs(state_dict, input_convs, in_chans):
    """Convert 3-channel input conv weights in ``state_dict`` (in place) to ``in_chans`` channels.

    Args:
        state_dict: checkpoint state dict, modified in place
        input_convs: module name (str) or sequence of module names whose '.weight' is adapted
        in_chans: target number of input channels

    Returns:
        True if every listed conv was converted, False if at least one weight had to be dropped
        (the caller must then load non-strictly).
    """
    if isinstance(input_convs, str):
        input_convs = (input_convs,)
    all_converted = True
    for input_conv_name in input_convs:
        weight_name = input_conv_name + '.weight'
        try:
            state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
            _logger.info(
                f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
        except NotImplementedError:
            # presumably adapt_input_conv signals an unsupported conversion this way; drop the
            # weight so the layer keeps its random init
            del state_dict[weight_name]
            all_converted = False
            _logger.warning(
                f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
    return all_converted


def _adapt_classifiers(state_dict, classifiers, num_classes, pretrained_cfg):
    """Reconcile classifier weights in ``state_dict`` (in place) with the target ``num_classes``.

    Args:
        state_dict: checkpoint state dict, modified in place
        classifiers: module name (str) or sequence of module names of the classifier layer(s)
        num_classes: num_classes of the target model
        pretrained_cfg: pretrained cfg; 'num_classes' and optional 'label_offset' are read

    Returns:
        True if classifier weights were discarded (caller must load non-strictly), else False.
    """
    if isinstance(classifiers, str):
        classifiers = (classifiers,)
    if num_classes != pretrained_cfg['num_classes']:
        for classifier_name in classifiers:
            # completely discard fully connected if model num_classes doesn't match pretrained weights
            state_dict.pop(classifier_name + '.weight', None)
            state_dict.pop(classifier_name + '.bias', None)
        return True
    # treat an explicit None the same as a missing entry (no offset)
    label_offset = pretrained_cfg.get('label_offset', None) or 0
    if label_offset > 0:
        for classifier_name in classifiers:
            # special case for pretrained weights with an extra background class in pretrained weights
            for suffix in ('.weight', '.bias'):
                key = classifier_name + suffix
                state_dict[key] = state_dict[key][label_offset:]
    return False


def load_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
):
    """ Load pretrained checkpoint

    Resolves the weight source from the pretrained cfg (file, url or Hugging Face hub), applies the
    optional filter fn, adapts first-conv and classifier weights to the target model, then calls
    ``model.load_state_dict``. Logs a warning and returns without loading if no cfg or no weight
    source is available.

    Args:
        model (nn.Module) : PyTorch model module
        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset,
            falls back to ``model.pretrained_cfg``
        num_classes (int): num_classes for target model
        in_chans (int): in_chans for target model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint; forced off when input conv or classifier weights
            had to be dropped
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if load_from == 'file':
        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
        state_dict = load_state_dict(pretrained_loc)
    elif load_from == 'url':
        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
        state_dict = load_state_dict_from_url(
            pretrained_loc,
            map_location='cpu',
            progress=_DOWNLOAD_PROGRESS,
            check_hash=_CHECK_HASH,
        )
    elif load_from == 'hf-hub':
        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
        if isinstance(pretrained_loc, (list, tuple)):
            # (hub_id, filename) pair from a filename override
            state_dict = load_state_dict_from_hf(*pretrained_loc)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc)
    else:
        _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
        return

    if filter_fn is not None:
        # for backwards compat with filter fns that take one arg, try one first, then two
        try:
            state_dict = filter_fn(state_dict)
        except TypeError:
            state_dict = filter_fn(state_dict, model)

    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if not _adapt_input_convs(state_dict, input_convs, in_chans):
            strict = False

    classifiers = pretrained_cfg.get('classifier', None)
    if classifiers is not None:
        if _adapt_classifiers(state_dict, classifiers, num_classes, pretrained_cfg):
            strict = False

    model.load_state_dict(state_dict, strict=strict)
def pretrained_cfg_for_features(pretrained_cfg):
    """Return a deep copy of ``pretrained_cfg`` trimmed for use by a feature-extraction backbone.

    Classification-head fields ('num_classes', 'classifier', 'global_pool') are dropped from the
    copy because they carry little meaning once the head is removed. The input is not modified.
    """
    features_cfg = deepcopy(pretrained_cfg)
    # TODO: consider adding a default final pool size?
    for head_key in ('num_classes', 'classifier', 'global_pool'):
        if head_key in features_cfg:
            del features_cfg[head_key]
    return features_cfg
def _filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Fill model build kwargs with defaults taken from the pretrained cfg, then filter them.

    Only keys the caller did not already pass are filled in (``setdefault``).

    Args:
        pretrained_cfg: input pretrained cfg (read only here)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # keys copied straight across when the cfg has a non-None value for them
    for key in ('num_classes', 'global_pool'):
        cfg_val = pretrained_cfg.get(key, None)
        if cfg_val is not None:
            kwargs.setdefault(key, cfg_val)

    # For legacy reasons, model __init__ args use in_chans + img_size as separate args while
    # pretrained_cfg has one input_size=(C, H ,W) entry.
    input_size = pretrained_cfg.get('input_size', None)
    if input_size is not None:
        assert len(input_size) == 3
        kwargs.setdefault('in_chans', input_size[0])
        if pretrained_cfg.get('fixed_input_size', False):
            # model with a fixed input size takes an img_size arg
            kwargs.setdefault('img_size', input_size[-2:])

    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    _filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg=None,
        pretrained_cfg_overlay=None,
) -> PretrainedCfg:
    """Resolve the pretrained config for a model variant.

    Args:
        variant: model variant name, used for the registry lookup and as the default 'architecture'
        pretrained_cfg: a dict (validated by converting to PretrainedCfg), a pretrained tag string
            (looked up in the registry as '<variant>.<tag>'), None / empty (looked up by ``variant``),
            or any other truthy value (e.g. a PretrainedCfg instance), which is used as-is
        pretrained_cfg_overlay: optional dict of field overrides applied on top of the resolved cfg.
            The caller's dict is not modified.

    Returns:
        The resolved PretrainedCfg with the overlay applied; 'architecture' defaults to ``variant``
        when the cfg does not set one. Falls back to a default-constructed PretrainedCfg (with a
        warning) if nothing is found in the registry.
    """
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    # Copy so the default 'architecture' added below never leaks into the caller's dict
    # (a reused overlay would otherwise pin the first variant's architecture).
    overlay = dict(pretrained_cfg_overlay or {})
    if not pretrained_cfg.architecture:
        overlay.setdefault('architecture', variant)
    pretrained_cfg = dataclasses.replace(pretrained_cfg, **overlay)

    return pretrained_cfg
def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        pretrained_cfg_overlay: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        kwargs_filter: Optional[Tuple[str]] = None,
        **kwargs,
):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        pretrained_cfg_overlay (Optional[Dict]): field overrides applied on top of the resolved pretrained cfg
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config (not modified; a copy is used)
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__

    Returns:
        The instantiated model, wrapped in a feature extraction module when ``features_only=True``.
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    # Shallow copy: keys are set / popped below and the caller's (possibly shared) dict must not
    # change, otherwise e.g. a popped 'feature_cls' would be missing on the next build.
    feature_cfg = dict(feature_cfg or {})

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(
        variant,
        pretrained_cfg=pretrained_cfg,
        pretrained_cfg_overlay=pretrained_cfg_overlay
    )

    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
    pretrained_cfg = pretrained_cfg.to_dict()

    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Instantiate the model
    if model_cfg is None:
        model = model_cls(**kwargs)
    else:
        model = model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        if pretrained_cfg.get('custom_load', False):
            load_custom_pretrained(
                model,
                pretrained_cfg=pretrained_cfg,
            )
        else:
            load_pretrained(
                model,
                pretrained_cfg=pretrained_cfg,
                num_classes=num_classes_pretrained,
                in_chans=kwargs.get('in_chans', 3),
                filter_fn=pretrained_filter_fn,
                strict=pretrained_strict,
            )

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = feature_cfg.pop('feature_cls', FeatureListNet)
        if isinstance(feature_cls, str):
            feature_cls = feature_cls.lower()
            if 'hook' in feature_cls:
                feature_cls = FeatureHookNet
            elif feature_cls == 'fx':
                feature_cls = FeatureGraphNet
            else:
                # explicit raise (same type as the former `assert False`) so the check survives `python -O`
                raise AssertionError(f'Unknown feature class {feature_cls}')
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back default_cfg
        model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    return model