Merge pull request #1581 from rwightman/refactor-imports
Major module / path restructurepull/1222/merge
commit
e98c93264c
@ -1,4 +1,3 @@
|
||||
dependencies = ['torch']
|
||||
from timm.models import registry
|
||||
|
||||
globals().update(registry._model_entrypoints)
|
||||
import timm
|
||||
globals().update(timm.models._registry._model_entrypoints)
|
||||
|
@ -1,4 +1,4 @@
|
||||
from .version import __version__
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
|
||||
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
|
||||
is_scriptable, is_exportable, set_scriptable, set_exportable, \
|
||||
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
|
||||
|
@ -0,0 +1,44 @@
|
||||
from .activations import *
|
||||
from .adaptive_avgmax_pool import \
|
||||
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
|
||||
from .blur_pool import BlurPool2d
|
||||
from .classifier import ClassifierHead, create_classifier
|
||||
from .cond_conv2d import CondConv2d, get_condconv_initializer
|
||||
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
|
||||
set_layer_config
|
||||
from .conv2d_same import Conv2dSame, conv2d_same
|
||||
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
|
||||
from .create_act import create_act_layer, get_act_layer, get_act_fn
|
||||
from .create_attn import get_attn, create_attn
|
||||
from .create_conv2d import create_conv2d
|
||||
from .create_norm import get_norm_layer, create_norm_layer
|
||||
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
|
||||
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
|
||||
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
|
||||
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
|
||||
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
|
||||
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
|
||||
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
|
||||
from .gather_excite import GatherExcite
|
||||
from .global_context import GlobalContext
|
||||
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
|
||||
from .inplace_abn import InplaceAbn
|
||||
from .linear import Linear
|
||||
from .mixed_conv2d import MixedConv2d
|
||||
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
|
||||
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
|
||||
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
|
||||
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
|
||||
from .padding import get_padding, get_same_padding, pad_same
|
||||
from .patch_embed import PatchEmbed
|
||||
from .pool2d_same import AvgPool2dSame, create_pool2d
|
||||
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
|
||||
from .selective_kernel import SelectiveKernel
|
||||
from .separable_conv import SeparableConv2d, SeparableConvNormAct
|
||||
from .space_to_depth import SpaceToDepthModule
|
||||
from .split_attn import SplitAttn
|
||||
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
|
||||
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
|
||||
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
|
||||
from .trace_utils import _assert, _float_to_int
|
||||
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
|
@ -0,0 +1,399 @@
|
||||
import dataclasses
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from typing import Optional, Dict, Callable, Any, Tuple
|
||||
|
||||
from torch import nn as nn
|
||||
from torch.hub import load_state_dict_from_url
|
||||
|
||||
from timm.models._features import FeatureListNet, FeatureHookNet
|
||||
from timm.models._features_fx import FeatureGraphNet
|
||||
from timm.models._helpers import load_state_dict
|
||||
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
|
||||
from timm.models._manipulate import adapt_input_conv
|
||||
from timm.models._pretrained import PretrainedCfg
|
||||
from timm.models._prune import adapt_model_from_file
|
||||
from timm.models._registry import get_pretrained_cfg
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
# Global variables for rarely used pretrained checkpoint download progress and hash check.
|
||||
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
|
||||
_DOWNLOAD_PROGRESS = False
|
||||
_CHECK_HASH = False
|
||||
|
||||
|
||||
__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
|
||||
'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']
|
||||
|
||||
|
||||
def _resolve_pretrained_source(pretrained_cfg):
|
||||
cfg_source = pretrained_cfg.get('source', '')
|
||||
pretrained_url = pretrained_cfg.get('url', None)
|
||||
pretrained_file = pretrained_cfg.get('file', None)
|
||||
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
|
||||
# resolve where to load pretrained weights from
|
||||
load_from = ''
|
||||
pretrained_loc = ''
|
||||
if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
|
||||
# hf-hub specified as source via model identifier
|
||||
load_from = 'hf-hub'
|
||||
assert hf_hub_id
|
||||
pretrained_loc = hf_hub_id
|
||||
else:
|
||||
# default source == timm or unspecified
|
||||
if pretrained_file:
|
||||
load_from = 'file'
|
||||
pretrained_loc = pretrained_file
|
||||
elif pretrained_url:
|
||||
load_from = 'url'
|
||||
pretrained_loc = pretrained_url
|
||||
elif hf_hub_id and has_hf_hub(necessary=True):
|
||||
# hf-hub available as alternate weight source in default_cfg
|
||||
load_from = 'hf-hub'
|
||||
pretrained_loc = hf_hub_id
|
||||
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
|
||||
# if a filename override is set, return tuple for location w/ (hub_id, filename)
|
||||
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
|
||||
return load_from, pretrained_loc
|
||||
|
||||
|
||||
def set_pretrained_download_progress(enable=True):
|
||||
""" Set download progress for pretrained weights on/off (globally). """
|
||||
global _DOWNLOAD_PROGRESS
|
||||
_DOWNLOAD_PROGRESS = enable
|
||||
|
||||
|
||||
def set_pretrained_check_hash(enable=True):
|
||||
""" Set hash checking for pretrained weights on/off (globally). """
|
||||
global _CHECK_HASH
|
||||
_CHECK_HASH = enable
|
||||
|
||||
|
||||
def load_custom_pretrained(
|
||||
model: nn.Module,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
load_fn: Optional[Callable] = None,
|
||||
):
|
||||
r"""Loads a custom (read non .pth) weight file
|
||||
|
||||
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
|
||||
a passed in custom load fun, or the `load_pretrained` model member fn.
|
||||
|
||||
If the object is already present in `model_dir`, it's deserialized and returned.
|
||||
The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
|
||||
`hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.
|
||||
|
||||
Args:
|
||||
model: The instantiated model to load weights into
|
||||
pretrained_cfg (dict): Default pretrained model cfg
|
||||
load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
|
||||
'laod_pretrained' on the model will be called if it exists
|
||||
"""
|
||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||
if not pretrained_cfg:
|
||||
_logger.warning("Invalid pretrained config, cannot load weights.")
|
||||
return
|
||||
|
||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||
if not load_from:
|
||||
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
|
||||
return
|
||||
if load_from == 'hf-hub': # FIXME
|
||||
_logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
|
||||
elif load_from == 'url':
|
||||
pretrained_loc = download_cached_file(
|
||||
pretrained_loc,
|
||||
check_hash=_CHECK_HASH,
|
||||
progress=_DOWNLOAD_PROGRESS
|
||||
)
|
||||
|
||||
if load_fn is not None:
|
||||
load_fn(model, pretrained_loc)
|
||||
elif hasattr(model, 'load_pretrained'):
|
||||
model.load_pretrained(pretrained_loc)
|
||||
else:
|
||||
_logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
|
||||
|
||||
|
||||
def load_pretrained(
|
||||
model: nn.Module,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
num_classes: int = 1000,
|
||||
in_chans: int = 3,
|
||||
filter_fn: Optional[Callable] = None,
|
||||
strict: bool = True,
|
||||
):
|
||||
""" Load pretrained checkpoint
|
||||
|
||||
Args:
|
||||
model (nn.Module) : PyTorch model module
|
||||
pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
|
||||
num_classes (int): num_classes for target model
|
||||
in_chans (int): in_chans for target model
|
||||
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
|
||||
strict (bool): strict load of checkpoint
|
||||
|
||||
"""
|
||||
pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
|
||||
if not pretrained_cfg:
|
||||
_logger.warning("Invalid pretrained config, cannot load weights.")
|
||||
return
|
||||
|
||||
load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
|
||||
if load_from == 'file':
|
||||
_logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
|
||||
state_dict = load_state_dict(pretrained_loc)
|
||||
elif load_from == 'url':
|
||||
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
|
||||
state_dict = load_state_dict_from_url(
|
||||
pretrained_loc,
|
||||
map_location='cpu',
|
||||
progress=_DOWNLOAD_PROGRESS,
|
||||
check_hash=_CHECK_HASH,
|
||||
)
|
||||
elif load_from == 'hf-hub':
|
||||
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
|
||||
if isinstance(pretrained_loc, (list, tuple)):
|
||||
state_dict = load_state_dict_from_hf(*pretrained_loc)
|
||||
else:
|
||||
state_dict = load_state_dict_from_hf(pretrained_loc)
|
||||
else:
|
||||
_logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
|
||||
return
|
||||
|
||||
if filter_fn is not None:
|
||||
# for backwards compat with filter fn that take one arg, try one first, the two
|
||||
try:
|
||||
state_dict = filter_fn(state_dict)
|
||||
except TypeError:
|
||||
state_dict = filter_fn(state_dict, model)
|
||||
|
||||
input_convs = pretrained_cfg.get('first_conv', None)
|
||||
if input_convs is not None and in_chans != 3:
|
||||
if isinstance(input_convs, str):
|
||||
input_convs = (input_convs,)
|
||||
for input_conv_name in input_convs:
|
||||
weight_name = input_conv_name + '.weight'
|
||||
try:
|
||||
state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
|
||||
_logger.info(
|
||||
f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
|
||||
except NotImplementedError as e:
|
||||
del state_dict[weight_name]
|
||||
strict = False
|
||||
_logger.warning(
|
||||
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
|
||||
|
||||
classifiers = pretrained_cfg.get('classifier', None)
|
||||
label_offset = pretrained_cfg.get('label_offset', 0)
|
||||
if classifiers is not None:
|
||||
if isinstance(classifiers, str):
|
||||
classifiers = (classifiers,)
|
||||
if num_classes != pretrained_cfg['num_classes']:
|
||||
for classifier_name in classifiers:
|
||||
# completely discard fully connected if model num_classes doesn't match pretrained weights
|
||||
state_dict.pop(classifier_name + '.weight', None)
|
||||
state_dict.pop(classifier_name + '.bias', None)
|
||||
strict = False
|
||||
elif label_offset > 0:
|
||||
for classifier_name in classifiers:
|
||||
# special case for pretrained weights with an extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
|
||||
def pretrained_cfg_for_features(pretrained_cfg):
|
||||
pretrained_cfg = deepcopy(pretrained_cfg)
|
||||
# remove default pretrained cfg fields that don't have much relevance for feature backbone
|
||||
to_remove = ('num_classes', 'classifier', 'global_pool') # add default final pool size?
|
||||
for tr in to_remove:
|
||||
pretrained_cfg.pop(tr, None)
|
||||
return pretrained_cfg
|
||||
|
||||
|
||||
def _filter_kwargs(kwargs, names):
|
||||
if not kwargs or not names:
|
||||
return
|
||||
for n in names:
|
||||
kwargs.pop(n, None)
|
||||
|
||||
|
||||
def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
|
||||
""" Update the default_cfg and kwargs before passing to model
|
||||
|
||||
Args:
|
||||
pretrained_cfg: input pretrained cfg (updated in-place)
|
||||
kwargs: keyword args passed to model build fn (updated in-place)
|
||||
kwargs_filter: keyword arg keys that must be removed before model __init__
|
||||
"""
|
||||
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
|
||||
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
|
||||
if pretrained_cfg.get('fixed_input_size', False):
|
||||
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
|
||||
default_kwarg_names += ('img_size',)
|
||||
|
||||
for n in default_kwarg_names:
|
||||
# for legacy reasons, model __init__args uses img_size + in_chans as separate args while
|
||||
# pretrained_cfg has one input_size=(C, H ,W) entry
|
||||
if n == 'img_size':
|
||||
input_size = pretrained_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[-2:])
|
||||
elif n == 'in_chans':
|
||||
input_size = pretrained_cfg.get('input_size', None)
|
||||
if input_size is not None:
|
||||
assert len(input_size) == 3
|
||||
kwargs.setdefault(n, input_size[0])
|
||||
else:
|
||||
default_val = pretrained_cfg.get(n, None)
|
||||
if default_val is not None:
|
||||
kwargs.setdefault(n, pretrained_cfg[n])
|
||||
|
||||
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
|
||||
_filter_kwargs(kwargs, names=kwargs_filter)
|
||||
|
||||
|
||||
def resolve_pretrained_cfg(
|
||||
variant: str,
|
||||
pretrained_cfg=None,
|
||||
pretrained_cfg_overlay=None,
|
||||
) -> PretrainedCfg:
|
||||
model_with_tag = variant
|
||||
pretrained_tag = None
|
||||
if pretrained_cfg:
|
||||
if isinstance(pretrained_cfg, dict):
|
||||
# pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
|
||||
pretrained_cfg = PretrainedCfg(**pretrained_cfg)
|
||||
elif isinstance(pretrained_cfg, str):
|
||||
pretrained_tag = pretrained_cfg
|
||||
pretrained_cfg = None
|
||||
|
||||
# fallback to looking up pretrained cfg in model registry by variant identifier
|
||||
if not pretrained_cfg:
|
||||
if pretrained_tag:
|
||||
model_with_tag = '.'.join([variant, pretrained_tag])
|
||||
pretrained_cfg = get_pretrained_cfg(model_with_tag)
|
||||
|
||||
if not pretrained_cfg:
|
||||
_logger.warning(
|
||||
f"No pretrained configuration specified for {model_with_tag} model. Using a default."
|
||||
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
|
||||
pretrained_cfg = PretrainedCfg() # instance with defaults
|
||||
|
||||
pretrained_cfg_overlay = pretrained_cfg_overlay or {}
|
||||
if not pretrained_cfg.architecture:
|
||||
pretrained_cfg_overlay.setdefault('architecture', variant)
|
||||
pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)
|
||||
|
||||
return pretrained_cfg
|
||||
|
||||
|
||||
def build_model_with_cfg(
|
||||
model_cls: Callable,
|
||||
variant: str,
|
||||
pretrained: bool,
|
||||
pretrained_cfg: Optional[Dict] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict] = None,
|
||||
model_cfg: Optional[Any] = None,
|
||||
feature_cfg: Optional[Dict] = None,
|
||||
pretrained_strict: bool = True,
|
||||
pretrained_filter_fn: Optional[Callable] = None,
|
||||
kwargs_filter: Optional[Tuple[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
""" Build model with specified default_cfg and optional model_cfg
|
||||
|
||||
This helper fn aids in the construction of a model including:
|
||||
* handling default_cfg and associated pretrained weight loading
|
||||
* passing through optional model_cfg for models with config based arch spec
|
||||
* features_only model adaptation
|
||||
* pruning config / model adaptation
|
||||
|
||||
Args:
|
||||
model_cls (nn.Module): model class
|
||||
variant (str): model variant name
|
||||
pretrained (bool): load pretrained weights
|
||||
pretrained_cfg (dict): model's pretrained weight/task config
|
||||
model_cfg (Optional[Dict]): model's architecture config
|
||||
feature_cfg (Optional[Dict]: feature extraction adapter config
|
||||
pretrained_strict (bool): load pretrained weights strictly
|
||||
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
|
||||
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
|
||||
**kwargs: model args passed through to model __init__
|
||||
"""
|
||||
pruned = kwargs.pop('pruned', False)
|
||||
features = False
|
||||
feature_cfg = feature_cfg or {}
|
||||
|
||||
# resolve and update model pretrained config and model kwargs
|
||||
pretrained_cfg = resolve_pretrained_cfg(
|
||||
variant,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay
|
||||
)
|
||||
|
||||
# FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
|
||||
pretrained_cfg = pretrained_cfg.to_dict()
|
||||
|
||||
_update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)
|
||||
|
||||
# Setup for feature extraction wrapper done at end of this fn
|
||||
if kwargs.pop('features_only', False):
|
||||
features = True
|
||||
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
|
||||
if 'out_indices' in kwargs:
|
||||
feature_cfg['out_indices'] = kwargs.pop('out_indices')
|
||||
|
||||
# Instantiate the model
|
||||
if model_cfg is None:
|
||||
model = model_cls(**kwargs)
|
||||
else:
|
||||
model = model_cls(cfg=model_cfg, **kwargs)
|
||||
model.pretrained_cfg = pretrained_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
|
||||
if pruned:
|
||||
model = adapt_model_from_file(model, variant)
|
||||
|
||||
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
|
||||
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
|
||||
if pretrained:
|
||||
if pretrained_cfg.get('custom_load', False):
|
||||
load_custom_pretrained(
|
||||
model,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
)
|
||||
else:
|
||||
load_pretrained(
|
||||
model,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
num_classes=num_classes_pretrained,
|
||||
in_chans=kwargs.get('in_chans', 3),
|
||||
filter_fn=pretrained_filter_fn,
|
||||
strict=pretrained_strict,
|
||||
)
|
||||
|
||||
# Wrap the model in a feature extraction module if enabled
|
||||
if features:
|
||||
feature_cls = FeatureListNet
|
||||
if 'feature_cls' in feature_cfg:
|
||||
feature_cls = feature_cfg.pop('feature_cls')
|
||||
if isinstance(feature_cls, str):
|
||||
feature_cls = feature_cls.lower()
|
||||
if 'hook' in feature_cls:
|
||||
feature_cls = FeatureHookNet
|
||||
elif feature_cls == 'fx':
|
||||
feature_cls = FeatureGraphNet
|
||||
else:
|
||||
assert False, f'Unknown feature class {feature_cls}'
|
||||
model = feature_cls(model, **feature_cfg)
|
||||
model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg) # add back default_cfg
|
||||
model.default_cfg = model.pretrained_cfg # alias for backwards compat
|
||||
|
||||
return model
|
@ -0,0 +1,103 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import urlsplit
|
||||
|
||||
from timm.layers import set_layer_config
|
||||
from ._pretrained import PretrainedCfg, split_model_name_tag
|
||||
from ._helpers import load_checkpoint
|
||||
from ._hub import load_model_config_from_hf
|
||||
from ._registry import is_model, model_entrypoint
|
||||
|
||||
|
||||
__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
|
||||
|
||||
|
||||
def parse_model_name(model_name):
|
||||
if model_name.startswith('hf_hub'):
|
||||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||
parsed = urlsplit(model_name)
|
||||
assert parsed.scheme in ('', 'timm', 'hf-hub')
|
||||
if parsed.scheme == 'hf-hub':
|
||||
# FIXME may use fragment as revision, currently `@` in URI path
|
||||
return parsed.scheme, parsed.path
|
||||
else:
|
||||
model_name = os.path.split(parsed.path)[-1]
|
||||
return 'timm', model_name
|
||||
|
||||
|
||||
def safe_model_name(model_name, remove_source=True):
|
||||
# return a filename / path safe model name
|
||||
def make_safe(name):
|
||||
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
||||
if remove_source:
|
||||
model_name = parse_model_name(model_name)[-1]
|
||||
return make_safe(model_name)
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: bool = False,
|
||||
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||
checkpoint_path: str = '',
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a model
|
||||
|
||||
Lookup model's entrypoint function and pass relevant args to create a new model.
|
||||
|
||||
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
|
||||
and then the model class __init__(). kwargs values set to None are pruned before passing.
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
|
||||
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
|
||||
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are consumed by builder or model __init__()
|
||||
"""
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == 'hf-hub':
|
||||
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
if not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
||||
if not is_model(model_name):
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
create_fn = model_entrypoint(model_name)
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
model = create_fn(
|
||||
pretrained=pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
@ -0,0 +1,287 @@
|
||||
""" PyTorch Feature Extraction Helpers
|
||||
|
||||
A collection of classes, functions, modules to help extract features from models
|
||||
and provide a common interface for describing them.
|
||||
|
||||
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
|
||||
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
||||
|
||||
|
||||
class FeatureInfo:
|
||||
|
||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||
prev_reduction = 1
|
||||
for fi in feature_info:
|
||||
# sanity check the mandatory fields, there may be additional fields depending on the model
|
||||
assert 'num_chs' in fi and fi['num_chs'] > 0
|
||||
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
||||
prev_reduction = fi['reduction']
|
||||
assert 'module' in fi
|
||||
self.out_indices = out_indices
|
||||
self.info = feature_info
|
||||
|
||||
def from_other(self, out_indices: Tuple[int]):
|
||||
return FeatureInfo(deepcopy(self.info), out_indices)
|
||||
|
||||
def get(self, key, idx=None):
|
||||
""" Get value by key at specified index (indices)
|
||||
if idx == None, returns value for key at each output index
|
||||
if idx is an integer, return value for that feature module index (ignoring output indices)
|
||||
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
||||
"""
|
||||
if idx is None:
|
||||
return [self.info[i][key] for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i][key] for i in idx]
|
||||
else:
|
||||
return self.info[idx][key]
|
||||
|
||||
def get_dicts(self, keys=None, idx=None):
|
||||
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
||||
"""
|
||||
if idx is None:
|
||||
if keys is None:
|
||||
return [self.info[i] for i in self.out_indices]
|
||||
else:
|
||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
||||
else:
|
||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||
|
||||
def channels(self, idx=None):
|
||||
""" feature channels accessor
|
||||
"""
|
||||
return self.get('num_chs', idx)
|
||||
|
||||
def reduction(self, idx=None):
|
||||
""" feature reduction (output stride) accessor
|
||||
"""
|
||||
return self.get('reduction', idx)
|
||||
|
||||
def module_name(self, idx=None):
|
||||
""" feature module name accessor
|
||||
"""
|
||||
return self.get('module', idx)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.info[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.info)
|
||||
|
||||
|
||||
class FeatureHooks:
|
||||
""" Feature Hook Helper
|
||||
|
||||
This module helps with the setup and extraction of hooks for extracting features from
|
||||
internal nodes in a model by node name. This works quite well in eager Python but needs
|
||||
redesign for torchscript.
|
||||
"""
|
||||
|
||||
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
||||
# setup feature hooks
|
||||
modules = {k: v for k, v in named_modules}
|
||||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
m = modules[hook_name]
|
||||
hook_id = out_map[i] if out_map else hook_name
|
||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||
hook_type = h.get('hook_type', default_hook_type)
|
||||
if hook_type == 'forward_pre':
|
||||
m.register_forward_pre_hook(hook_fn)
|
||||
elif hook_type == 'forward':
|
||||
m.register_forward_hook(hook_fn)
|
||||
else:
|
||||
assert False, "Unsupported hook type"
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
|
||||
def _collect_output_hook(self, hook_id, *args):
|
||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||
if isinstance(x, tuple):
|
||||
x = x[0] # unwrap input tuple
|
||||
self._feature_outputs[x.device][hook_id] = x
|
||||
|
||||
def get_output(self, device) -> Dict[str, torch.tensor]:
|
||||
output = self._feature_outputs[device]
|
||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||
return output
|
||||
|
||||
|
||||
def _module_list(module, flatten_sequential=False):
|
||||
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
||||
ml = []
|
||||
for name, module in module.named_children():
|
||||
if flatten_sequential and isinstance(module, nn.Sequential):
|
||||
# first level of Sequential containers is flattened into containing model
|
||||
for child_name, child_module in module.named_children():
|
||||
combined = [name, child_name]
|
||||
ml.append(('_'.join(combined), '.'.join(combined), child_module))
|
||||
else:
|
||||
ml.append((name, name, module))
|
||||
return ml
|
||||
|
||||
|
||||
def _get_feature_info(net, out_indices):
|
||||
feature_info = getattr(net, 'feature_info')
|
||||
if isinstance(feature_info, FeatureInfo):
|
||||
return feature_info.from_other(out_indices)
|
||||
elif isinstance(feature_info, (list, tuple)):
|
||||
return FeatureInfo(net.feature_info, out_indices)
|
||||
else:
|
||||
assert False, "Provided feature_info is not valid"
|
||||
|
||||
|
||||
def _get_return_layers(feature_info, out_map):
|
||||
module_names = feature_info.module_name()
|
||||
return_layers = {}
|
||||
for i, name in enumerate(module_names):
|
||||
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||
return return_layers
|
||||
|
||||
|
||||
class FeatureDictNet(nn.ModuleDict):
|
||||
""" Feature extractor with OrderedDict return
|
||||
|
||||
Wrap a model and extract features as specified by the out indices, the network is
|
||||
partially re-built from contained modules.
|
||||
|
||||
There is a strong assumption that the modules have been registered into the model in the same
|
||||
order as they are used. There should be no reuse of the same nn.Module more than once, including
|
||||
trivial modules like `self.relu = nn.ReLU`.
|
||||
|
||||
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
|
||||
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model from which we will extract the features
|
||||
out_indices (tuple[int]): model output indices to extract features for
|
||||
out_map (sequence): list or tuple specifying desired return id for each out index,
|
||||
otherwise str(index) is used
|
||||
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureDictNet, self).__init__()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.concat = feature_concat
|
||||
self.return_layers = {}
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
layers = OrderedDict()
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
if old_name in remaining:
|
||||
# return id has to be consistently str type for torchscript
|
||||
self.return_layers[new_name] = str(return_layers[old_name])
|
||||
remaining.remove(old_name)
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining and len(self.return_layers) == len(return_layers), \
|
||||
f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
|
||||
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
||||
out = OrderedDict()
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
if name in self.return_layers:
|
||||
out_id = self.return_layers[name]
|
||||
if isinstance(x, (tuple, list)):
|
||||
# If model tap is a tuple or list, concat or select first element
|
||||
# FIXME this may need to be more generic / flexible for some nets
|
||||
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
||||
else:
|
||||
out[out_id] = x
|
||||
return out
|
||||
|
||||
def forward(self, x) -> Dict[str, torch.Tensor]:
|
||||
return self._collect(x)
|
||||
|
||||
|
||||
class FeatureListNet(FeatureDictNet):
|
||||
""" Feature extractor with list return
|
||||
|
||||
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
||||
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureListNet, self).__init__(
|
||||
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential)
|
||||
|
||||
def forward(self, x) -> (List[torch.Tensor]):
|
||||
return list(self._collect(x).values())
|
||||
|
||||
|
||||
class FeatureHookNet(nn.ModuleDict):
|
||||
""" FeatureHookNet
|
||||
|
||||
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
|
||||
|
||||
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
|
||||
network in any way.
|
||||
|
||||
If `no_rewrite` is False, the model will be re-written as in the
|
||||
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
|
||||
|
||||
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
||||
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
||||
super(FeatureHookNet, self).__init__()
|
||||
assert not torch.jit.is_scripting()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.out_as_dict = out_as_dict
|
||||
layers = OrderedDict()
|
||||
hooks = []
|
||||
if no_rewrite:
|
||||
assert not flatten_sequential
|
||||
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
||||
model.reset_classifier(0)
|
||||
layers['body'] = model
|
||||
hooks.extend(self.feature_info.get_dicts())
|
||||
else:
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info.get_dicts()}
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
for fn, fm in module.named_modules(prefix=old_name):
|
||||
if fn in remaining:
|
||||
hooks.append(dict(module=fn, hook_type=remaining[fn]))
|
||||
del remaining[fn]
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
|
||||
def forward(self, x):
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
out = self.hooks.get_output(x.device)
|
||||
return out if self.out_as_dict else list(out.values())
|
@ -0,0 +1,110 @@
|
||||
""" PyTorch FX Based Feature Extraction Helpers
|
||||
Using https://pytorch.org/vision/stable/feature_extraction.html
|
||||
"""
|
||||
from typing import Callable, List, Dict, Union, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from ._features import _get_feature_info
|
||||
|
||||
try:
|
||||
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
|
||||
has_fx_feature_extraction = True
|
||||
except ImportError:
|
||||
has_fx_feature_extraction = False
|
||||
|
||||
# Layers we went to treat as leaf modules
|
||||
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
|
||||
from timm.layers.non_local_attn import BilinearAttnTransform
|
||||
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
|
||||
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
|
||||
# BUT modules from timm.models should use the registration mechanism below
|
||||
_leaf_modules = {
|
||||
BilinearAttnTransform, # reason: flow control t <= 1
|
||||
# Reason: get_same_padding has a max which raises a control flow error
|
||||
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
|
||||
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
|
||||
}
|
||||
|
||||
try:
|
||||
from timm.layers import InplaceAbn
|
||||
_leaf_modules.add(InplaceAbn)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
|
||||
'FeatureGraphNet', 'GraphExtractNet']
|
||||
|
||||
|
||||
def register_notrace_module(module: Type[nn.Module]):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
"""
|
||||
_leaf_modules.add(module)
|
||||
return module
|
||||
|
||||
|
||||
# Functions we want to autowrap (treat them as leaves)
|
||||
_autowrap_functions = set()
|
||||
|
||||
|
||||
def register_notrace_function(func: Callable):
|
||||
"""
|
||||
Decorator for functions which ought not to be traced through
|
||||
"""
|
||||
_autowrap_functions.add(func)
|
||||
return func
|
||||
|
||||
|
||||
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
return _create_feature_extractor(
|
||||
model, return_nodes,
|
||||
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
|
||||
)
|
||||
|
||||
|
||||
class FeatureGraphNet(nn.Module):
|
||||
""" A FX Graph based feature extractor that works with the model feature_info metadata
|
||||
"""
|
||||
def __init__(self, model, out_indices, out_map=None):
|
||||
super().__init__()
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
if out_map is not None:
|
||||
assert len(out_map) == len(out_indices)
|
||||
return_nodes = {
|
||||
info['module']: out_map[i] if out_map is not None else info['module']
|
||||
for i, info in enumerate(self.feature_info) if i in out_indices}
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x):
|
||||
return list(self.graph_module(x).values())
|
||||
|
||||
|
||||
class GraphExtractNet(nn.Module):
|
||||
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
|
||||
NOTE:
|
||||
* one can use feature_extractor directly if dictionary output is desired
|
||||
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
|
||||
metadata for builtin feature extraction mode
|
||||
* create_feature_extractor can be used directly if dictionary output is acceptable
|
||||
|
||||
Args:
|
||||
model: model to extract features from
|
||||
return_nodes: node names to return features from (dict or list)
|
||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||
"""
|
||||
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
|
||||
super().__init__()
|
||||
self.squeeze_out = squeeze_out
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
out = list(self.graph_module(x).values())
|
||||
if self.squeeze_out and len(out) == 1:
|
||||
return out[0]
|
||||
return out
|
@ -0,0 +1,115 @@
|
||||
""" Model creation / weight loading / state_dict helpers
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
|
||||
import timm.models._builder
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']
|
||||
|
||||
|
||||
def clean_state_dict(state_dict):
|
||||
# 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
|
||||
cleaned_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] if k.startswith('module.') else k
|
||||
cleaned_state_dict[name] = v
|
||||
return cleaned_state_dict
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_path, use_ema=True):
|
||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
state_dict_key = ''
|
||||
if isinstance(checkpoint, dict):
|
||||
if use_ema and checkpoint.get('state_dict_ema', None) is not None:
|
||||
state_dict_key = 'state_dict_ema'
|
||||
elif use_ema and checkpoint.get('model_ema', None) is not None:
|
||||
state_dict_key = 'model_ema'
|
||||
elif 'state_dict' in checkpoint:
|
||||
state_dict_key = 'state_dict'
|
||||
elif 'model' in checkpoint:
|
||||
state_dict_key = 'model'
|
||||
state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
|
||||
_logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
||||
return state_dict
|
||||
else:
|
||||
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
|
||||
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
|
||||
# numpy checkpoint, try to load via model specific load_pretrained fn
|
||||
if hasattr(model, 'load_pretrained'):
|
||||
timm.models._model_builder.load_pretrained(checkpoint_path)
|
||||
else:
|
||||
raise NotImplementedError('Model cannot load numpy checkpoint')
|
||||
return
|
||||
state_dict = load_state_dict(checkpoint_path, use_ema)
|
||||
if remap:
|
||||
state_dict = remap_checkpoint(model, state_dict)
|
||||
incompatible_keys = model.load_state_dict(state_dict, strict=strict)
|
||||
return incompatible_keys
|
||||
|
||||
|
||||
def remap_checkpoint(model, state_dict, allow_reshape=True):
|
||||
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
||||
This assumes models (and originating state dict) were created with params registered in same order.
|
||||
"""
|
||||
out_dict = {}
|
||||
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
|
||||
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
if va.shape != vb.shape:
|
||||
if allow_reshape:
|
||||
vb = vb.reshape(va.shape)
|
||||
else:
|
||||
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
||||
out_dict[ka] = vb
|
||||
return out_dict
|
||||
|
||||
|
||||
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
|
||||
resume_epoch = None
|
||||
if os.path.isfile(checkpoint_path):
|
||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring model state from checkpoint...')
|
||||
state_dict = clean_state_dict(checkpoint['state_dict'])
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
if optimizer is not None and 'optimizer' in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring optimizer state from checkpoint...')
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
|
||||
if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
|
||||
if log_info:
|
||||
_logger.info('Restoring AMP loss scaler state from checkpoint...')
|
||||
loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
|
||||
|
||||
if 'epoch' in checkpoint:
|
||||
resume_epoch = checkpoint['epoch']
|
||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
||||
|
||||
if log_info:
|
||||
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
if log_info:
|
||||
_logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
return resume_epoch
|
||||
else:
|
||||
_logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
@ -0,0 +1,220 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
||||
|
||||
try:
|
||||
from torch.hub import get_dir
|
||||
except ImportError:
|
||||
from torch.hub import _get_torch_home as get_dir
|
||||
|
||||
from timm import __version__
|
||||
from timm.models._pretrained import filter_pretrained_cfg
|
||||
|
||||
try:
|
||||
from huggingface_hub import (
|
||||
create_repo, get_hf_file_metadata,
|
||||
hf_hub_download, hf_hub_url,
|
||||
repo_type_and_id_from_hf_id, upload_folder)
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
||||
_has_hf_hub = True
|
||||
except ImportError:
|
||||
hf_hub_download = None
|
||||
_has_hf_hub = False
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
|
||||
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
|
||||
|
||||
|
||||
def get_cache_dir(child_dir=''):
|
||||
"""
|
||||
Returns the location of the directory where models are cached (and creates it if necessary).
|
||||
"""
|
||||
# Issue warning to move data if old env is set
|
||||
if os.getenv('TORCH_MODEL_ZOO'):
|
||||
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
|
||||
|
||||
hub_dir = get_dir()
|
||||
child_dir = () if not child_dir else (child_dir,)
|
||||
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
return model_dir
|
||||
|
||||
|
||||
def download_cached_file(url, check_hash=True, progress=False):
|
||||
if isinstance(url, (list, tuple)):
|
||||
url, filename = url
|
||||
else:
|
||||
parts = urlparse(url)
|
||||
filename = os.path.basename(parts.path)
|
||||
cached_file = os.path.join(get_cache_dir(), filename)
|
||||
if not os.path.exists(cached_file):
|
||||
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
|
||||
hash_prefix = None
|
||||
if check_hash:
|
||||
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
|
||||
hash_prefix = r.group(1) if r else None
|
||||
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
|
||||
return cached_file
|
||||
|
||||
|
||||
def has_hf_hub(necessary=False):
|
||||
if not _has_hf_hub and necessary:
|
||||
# if no HF Hub module installed, and it is necessary to continue, raise error
|
||||
raise RuntimeError(
|
||||
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
|
||||
return _has_hf_hub
|
||||
|
||||
|
||||
def hf_split(hf_id):
|
||||
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
|
||||
rev_split = hf_id.split('@')
|
||||
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
|
||||
hf_model_id = rev_split[0]
|
||||
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
|
||||
return hf_model_id, hf_revision
|
||||
|
||||
|
||||
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
|
||||
|
||||
def _download_from_hf(model_id: str, filename: str):
|
||||
hf_model_id, hf_revision = hf_split(model_id)
|
||||
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
|
||||
|
||||
|
||||
def load_model_config_from_hf(model_id: str):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, 'config.json')
|
||||
|
||||
hf_config = load_cfg_from_json(cached_file)
|
||||
if 'pretrained_cfg' not in hf_config:
|
||||
# old form, pull pretrain_cfg out of the base dict
|
||||
pretrained_cfg = hf_config
|
||||
hf_config = {}
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
|
||||
if 'labels' in pretrained_cfg:
|
||||
hf_config['label_name'] = pretrained_cfg.pop('labels')
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
|
||||
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
|
||||
pretrained_cfg = hf_config['pretrained_cfg']
|
||||
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
|
||||
pretrained_cfg['source'] = 'hf-hub'
|
||||
if 'num_classes' in hf_config:
|
||||
# model should be created with parent num_classes if they exist
|
||||
pretrained_cfg['num_classes'] = hf_config['num_classes']
|
||||
model_name = hf_config['architecture']
|
||||
|
||||
return pretrained_cfg, model_name
|
||||
|
||||
|
||||
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
|
||||
assert has_hf_hub(True)
|
||||
cached_file = _download_from_hf(model_id, filename)
|
||||
state_dict = torch.load(cached_file, map_location='cpu')
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_for_hf(model, save_directory, model_config=None):
|
||||
assert has_hf_hub(True)
|
||||
model_config = model_config or {}
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
weights_path = save_directory / 'pytorch_model.bin'
|
||||
torch.save(model.state_dict(), weights_path)
|
||||
|
||||
config_path = save_directory / 'config.json'
|
||||
hf_config = {}
|
||||
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
||||
# set some values at root config level
|
||||
hf_config['architecture'] = pretrained_cfg.pop('architecture')
|
||||
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
|
||||
hf_config['num_features'] = model_config.get('num_features', model.num_features)
|
||||
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
|
||||
|
||||
if 'label' in model_config:
|
||||
_logger.warning(
|
||||
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
|
||||
"Using provided 'label' field as 'label_name'.")
|
||||
model_config['label_name'] = model_config.pop('label')
|
||||
|
||||
label_name = model_config.pop('label_name', None)
|
||||
if label_name:
|
||||
assert isinstance(label_name, (dict, list, tuple))
|
||||
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
|
||||
# can be a dict id: name if there are id gaps, or tuple/list if no gaps.
|
||||
hf_config['label_name'] = model_config['label_name']
|
||||
|
||||
display_name = model_config.pop('display_name', None)
|
||||
if display_name:
|
||||
assert isinstance(display_name, dict)
|
||||
# map label_name -> user interface display name
|
||||
hf_config['display_name'] = model_config['display_name']
|
||||
|
||||
hf_config['pretrained_cfg'] = pretrained_cfg
|
||||
hf_config.update(model_config)
|
||||
|
||||
with config_path.open('w') as f:
|
||||
json.dump(hf_config, f, indent=2)
|
||||
|
||||
|
||||
def push_to_hf_hub(
|
||||
model,
|
||||
repo_id: str,
|
||||
commit_message: str = 'Add model',
|
||||
token: Optional[str] = None,
|
||||
revision: Optional[str] = None,
|
||||
private: bool = False,
|
||||
create_pr: bool = False,
|
||||
model_config: Optional[dict] = None,
|
||||
):
|
||||
# Create repo if it doesn't exist yet
|
||||
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
||||
|
||||
# Infer complete repo_id from repo_url
|
||||
# Can be different from the input `repo_id` if repo_owner was implicit
|
||||
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
||||
repo_id = f"{repo_owner}/{repo_name}"
|
||||
|
||||
# Check if README file already exist in repo
|
||||
try:
|
||||
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
||||
has_readme = True
|
||||
except EntryNotFoundError:
|
||||
has_readme = False
|
||||
|
||||
# Dump model and push to Hub
|
||||
with TemporaryDirectory() as tmpdir:
|
||||
# Save model weights and config.
|
||||
save_for_hf(model, tmpdir, model_config=model_config)
|
||||
|
||||
# Add readme if it does not exist
|
||||
if not has_readme:
|
||||
model_name = repo_id.split('/')[-1]
|
||||
readme_path = Path(tmpdir) / "README.md"
|
||||
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
|
||||
readme_path.write_text(readme_text)
|
||||
|
||||
# Upload model and return
|
||||
return upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=tmpdir,
|
||||
revision=revision,
|
||||
create_pr=create_pr,
|
||||
commit_message=commit_message,
|
||||
)
|
@ -0,0 +1,258 @@
|
||||
import collections.abc
|
||||
import math
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from itertools import chain
|
||||
from typing import Callable, Union, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn as nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
|
||||
'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']
|
||||
|
||||
|
||||
def model_parameters(model, exclude_head=False):
|
||||
if exclude_head:
|
||||
# FIXME this a bit of a quick and dirty hack to skip classifier head params based on ordering
|
||||
return [p for p in model.parameters()][:-2]
|
||||
else:
|
||||
return model.parameters()
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
return module
|
||||
|
||||
|
||||
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
if not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
yield from named_modules(
|
||||
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
yield name, module
|
||||
|
||||
|
||||
def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
|
||||
if module._parameters and not depth_first and include_root:
|
||||
yield name, module
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
yield from named_modules_with_params(
|
||||
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if module._parameters and depth_first and include_root:
|
||||
yield name, module
|
||||
|
||||
|
||||
MATCH_PREV_GROUP = (99999,)
|
||||
|
||||
|
||||
def group_with_matcher(
|
||||
named_objects,
|
||||
group_matcher: Union[Dict, Callable],
|
||||
output_values: bool = False,
|
||||
reverse: bool = False
|
||||
):
|
||||
if isinstance(group_matcher, dict):
|
||||
# dictionary matcher contains a dict of raw-string regex expr that must be compiled
|
||||
compiled = []
|
||||
for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
|
||||
if mspec is None:
|
||||
continue
|
||||
# map all matching specifications into 3-tuple (compiled re, prefix, suffix)
|
||||
if isinstance(mspec, (tuple, list)):
|
||||
# multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
|
||||
for sspec in mspec:
|
||||
compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
|
||||
else:
|
||||
compiled += [(re.compile(mspec), (group_ordinal,), None)]
|
||||
group_matcher = compiled
|
||||
|
||||
def _get_grouping(name):
|
||||
if isinstance(group_matcher, (list, tuple)):
|
||||
for match_fn, prefix, suffix in group_matcher:
|
||||
r = match_fn.match(name)
|
||||
if r:
|
||||
parts = (prefix, r.groups(), suffix)
|
||||
# map all tuple elem to int for numeric sort, filter out None entries
|
||||
return tuple(map(float, chain.from_iterable(filter(None, parts))))
|
||||
return float('inf'), # un-matched layers (neck, head) mapped to largest ordinal
|
||||
else:
|
||||
ord = group_matcher(name)
|
||||
if not isinstance(ord, collections.abc.Iterable):
|
||||
return ord,
|
||||
return tuple(ord)
|
||||
|
||||
# map layers into groups via ordinals (ints or tuples of ints) from matcher
|
||||
grouping = defaultdict(list)
|
||||
for k, v in named_objects:
|
||||
grouping[_get_grouping(k)].append(v if output_values else k)
|
||||
|
||||
# remap to integers
|
||||
layer_id_to_param = defaultdict(list)
|
||||
lid = -1
|
||||
for k in sorted(filter(lambda x: x is not None, grouping.keys())):
|
||||
if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
|
||||
lid += 1
|
||||
layer_id_to_param[lid].extend(grouping[k])
|
||||
|
||||
if reverse:
|
||||
assert not output_values, "reverse mapping only sensible for name output"
|
||||
# output reverse mapping
|
||||
param_to_layer_id = {}
|
||||
for lid, lm in layer_id_to_param.items():
|
||||
for n in lm:
|
||||
param_to_layer_id[n] = lid
|
||||
return param_to_layer_id
|
||||
|
||||
return layer_id_to_param
|
||||
|
||||
|
||||
def group_parameters(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)
|
||||
|
||||
|
||||
def group_modules(
|
||||
module: nn.Module,
|
||||
group_matcher,
|
||||
output_values=False,
|
||||
reverse=False,
|
||||
):
|
||||
return group_with_matcher(
|
||||
named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)
|
||||
|
||||
|
||||
def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
|
||||
prefix_is_tuple = isinstance(prefix, tuple)
|
||||
if isinstance(module_types, str):
|
||||
if module_types == 'container':
|
||||
module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
|
||||
else:
|
||||
module_types = (nn.Sequential,)
|
||||
for name, module in named_modules:
|
||||
if depth and isinstance(module, module_types):
|
||||
yield from flatten_modules(
|
||||
module.named_children(),
|
||||
depth - 1,
|
||||
prefix=(name,) if prefix_is_tuple else name,
|
||||
module_types=module_types,
|
||||
)
|
||||
else:
|
||||
if prefix_is_tuple:
|
||||
name = prefix + (name,)
|
||||
yield name, module
|
||||
else:
|
||||
if prefix:
|
||||
name = '.'.join([prefix, name])
|
||||
yield name, module
|
||||
|
||||
|
||||
def checkpoint_seq(
|
||||
functions,
|
||||
x,
|
||||
every=1,
|
||||
flatten=False,
|
||||
skip_last=False,
|
||||
preserve_rng_state=True
|
||||
):
|
||||
r"""A helper function for checkpointing sequential models.
|
||||
|
||||
Sequential models execute a list of modules/functions in order
|
||||
(sequentially). Therefore, we can divide such a sequence into segments
|
||||
and checkpoint each segment. All segments except run in :func:`torch.no_grad`
|
||||
manner, i.e., not storing the intermediate activations. The inputs of each
|
||||
checkpointed segment will be saved for re-running the segment in the backward pass.
|
||||
|
||||
See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.
|
||||
|
||||
.. warning::
|
||||
Checkpointing currently only supports :func:`torch.autograd.backward`
|
||||
and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
|
||||
is not supported.
|
||||
|
||||
.. warning:
|
||||
At least one of the inputs needs to have :code:`requires_grad=True` if
|
||||
grads are needed for model inputs, otherwise the checkpointed part of the
|
||||
model won't have gradients.
|
||||
|
||||
Args:
|
||||
functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
|
||||
x: A Tensor that is input to :attr:`functions`
|
||||
every: checkpoint every-n functions (default: 1)
|
||||
flatten (bool): flatten nn.Sequential of nn.Sequentials
|
||||
skip_last (bool): skip checkpointing the last function in the sequence if True
|
||||
preserve_rng_state (bool, optional, default=True): Omit stashing and restoring
|
||||
the RNG state during each checkpoint.
|
||||
|
||||
Returns:
|
||||
Output of running :attr:`functions` sequentially on :attr:`*inputs`
|
||||
|
||||
Example:
|
||||
>>> model = nn.Sequential(...)
|
||||
>>> input_var = checkpoint_seq(model, input_var, every=2)
|
||||
"""
|
||||
def run_function(start, end, functions):
|
||||
def forward(_x):
|
||||
for j in range(start, end + 1):
|
||||
_x = functions[j](_x)
|
||||
return _x
|
||||
return forward
|
||||
|
||||
if isinstance(functions, torch.nn.Sequential):
|
||||
functions = functions.children()
|
||||
if flatten:
|
||||
functions = chain.from_iterable(functions)
|
||||
if not isinstance(functions, (tuple, list)):
|
||||
functions = tuple(functions)
|
||||
|
||||
num_checkpointed = len(functions)
|
||||
if skip_last:
|
||||
num_checkpointed -= 1
|
||||
end = -1
|
||||
for start in range(0, num_checkpointed, every):
|
||||
end = min(start + every - 1, num_checkpointed - 1)
|
||||
x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
|
||||
if skip_last:
|
||||
return run_function(end + 1, len(functions) - 1, functions)(x)
|
||||
return x
|
||||
|
||||
|
||||
def adapt_input_conv(in_chans, conv_weight):
|
||||
conv_type = conv_weight.dtype
|
||||
conv_weight = conv_weight.float() # Some weights are in torch.half, ensure it's float for sum on CPU
|
||||
O, I, J, K = conv_weight.shape
|
||||
if in_chans == 1:
|
||||
if I > 3:
|
||||
assert conv_weight.shape[1] % 3 == 0
|
||||
# For models with space2depth stems
|
||||
conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
|
||||
conv_weight = conv_weight.sum(dim=2, keepdim=False)
|
||||
else:
|
||||
conv_weight = conv_weight.sum(dim=1, keepdim=True)
|
||||
elif in_chans != 3:
|
||||
if I != 3:
|
||||
raise NotImplementedError('Weight format not supported by conversion.')
|
||||
else:
|
||||
# NOTE this strategy should be better than random init, but there could be other combinations of
|
||||
# the original RGB input layer weights that'd work better for specific cases.
|
||||
repeat = int(math.ceil(in_chans / 3))
|
||||
conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
|
||||
conv_weight *= (3 / float(in_chans))
|
||||
conv_weight = conv_weight.to(conv_type)
|
||||
return conv_weight
|
@ -0,0 +1,113 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.layers import Conv2dSame, BatchNormAct2d, Linear
|
||||
|
||||
__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
|
||||
|
||||
|
||||
def extract_layer(model, layer):
|
||||
layer = layer.split('.')
|
||||
module = model
|
||||
if hasattr(model, 'module') and layer[0] != 'module':
|
||||
module = model.module
|
||||
if not hasattr(model, 'module') and layer[0] == 'module':
|
||||
layer = layer[1:]
|
||||
for l in layer:
|
||||
if hasattr(module, l):
|
||||
if not l.isdigit():
|
||||
module = getattr(module, l)
|
||||
else:
|
||||
module = module[int(l)]
|
||||
else:
|
||||
return module
|
||||
return module
|
||||
|
||||
|
||||
def set_layer(model, layer, val):
|
||||
layer = layer.split('.')
|
||||
module = model
|
||||
if hasattr(model, 'module') and layer[0] != 'module':
|
||||
module = model.module
|
||||
lst_index = 0
|
||||
module2 = module
|
||||
for l in layer:
|
||||
if hasattr(module2, l):
|
||||
if not l.isdigit():
|
||||
module2 = getattr(module2, l)
|
||||
else:
|
||||
module2 = module2[int(l)]
|
||||
lst_index += 1
|
||||
lst_index -= 1
|
||||
for l in layer[:lst_index]:
|
||||
if not l.isdigit():
|
||||
module = getattr(module, l)
|
||||
else:
|
||||
module = module[int(l)]
|
||||
l = layer[lst_index]
|
||||
setattr(module, l, val)
|
||||
|
||||
|
||||
def adapt_model_from_string(parent_module, model_string):
|
||||
separator = '***'
|
||||
state_dict = {}
|
||||
lst_shape = model_string.split(separator)
|
||||
for k in lst_shape:
|
||||
k = k.split(':')
|
||||
key = k[0]
|
||||
shape = k[1][1:-1].split(',')
|
||||
if shape[0] != '':
|
||||
state_dict[key] = [int(i) for i in shape]
|
||||
|
||||
new_module = deepcopy(parent_module)
|
||||
for n, m in parent_module.named_modules():
|
||||
old_module = extract_layer(parent_module, n)
|
||||
if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
|
||||
if isinstance(old_module, Conv2dSame):
|
||||
conv = Conv2dSame
|
||||
else:
|
||||
conv = nn.Conv2d
|
||||
s = state_dict[n + '.weight']
|
||||
in_channels = s[1]
|
||||
out_channels = s[0]
|
||||
g = 1
|
||||
if old_module.groups > 1:
|
||||
in_channels = out_channels
|
||||
g = in_channels
|
||||
new_conv = conv(
|
||||
in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
|
||||
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
|
||||
groups=g, stride=old_module.stride)
|
||||
set_layer(new_module, n, new_conv)
|
||||
elif isinstance(old_module, BatchNormAct2d):
|
||||
new_bn = BatchNormAct2d(
|
||||
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||
affine=old_module.affine, track_running_stats=True)
|
||||
new_bn.drop = old_module.drop
|
||||
new_bn.act = old_module.act
|
||||
set_layer(new_module, n, new_bn)
|
||||
elif isinstance(old_module, nn.BatchNorm2d):
|
||||
new_bn = nn.BatchNorm2d(
|
||||
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
|
||||
affine=old_module.affine, track_running_stats=True)
|
||||
set_layer(new_module, n, new_bn)
|
||||
elif isinstance(old_module, nn.Linear):
|
||||
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
|
||||
num_features = state_dict[n + '.weight'][1]
|
||||
new_fc = Linear(
|
||||
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
|
||||
set_layer(new_module, n, new_fc)
|
||||
if hasattr(new_module, 'num_features'):
|
||||
new_module.num_features = num_features
|
||||
new_module.eval()
|
||||
parent_module.eval()
|
||||
|
||||
return new_module
|
||||
|
||||
|
||||
def adapt_model_from_file(parent_module, model_variant):
|
||||
adapt_file = os.path.join(os.path.dirname(__file__), '_pruned', model_variant + '.txt')
|
||||
with open(adapt_file, 'r') as f:
|
||||
return adapt_model_from_string(parent_module, f.read().strip())
|
@ -0,0 +1,212 @@
|
||||
""" Model Registry
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from typing import List, Optional, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
__all__ = [
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module = {} # mapping of model names to module names
|
||||
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
|
||||
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def register_model(fn):
|
||||
# lookup containing module
|
||||
mod = sys.modules[fn.__module__]
|
||||
module_name_split = fn.__module__.split('.')
|
||||
module_name = module_name_split[-1] if len(module_name_split) else ''
|
||||
|
||||
# add model to __all__ in module
|
||||
model_name = fn.__name__
|
||||
if hasattr(mod, '__all__'):
|
||||
mod.__all__.append(model_name)
|
||||
else:
|
||||
mod.__all__ = [model_name]
|
||||
|
||||
# add entries to registry dict/sets
|
||||
_model_entrypoints[model_name] = fn
|
||||
_model_to_module[model_name] = module_name
|
||||
_module_to_models[module_name].add(model_name)
|
||||
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
|
||||
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
|
||||
# entrypoints or non-matching combos
|
||||
cfg = mod.default_cfgs[model_name]
|
||||
if not isinstance(cfg, DefaultCfg):
|
||||
# new style default cfg dataclass w/ multiple entries per model-arch
|
||||
assert isinstance(cfg, dict)
|
||||
# old style cfg dict per model-arch
|
||||
cfg = PretrainedCfg(**cfg)
|
||||
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
|
||||
|
||||
for tag_idx, tag in enumerate(cfg.tags):
|
||||
is_default = tag_idx == 0
|
||||
pretrained_cfg = cfg.cfgs[tag]
|
||||
if is_default:
|
||||
_model_pretrained_cfgs[model_name] = pretrained_cfg
|
||||
if pretrained_cfg.has_weights:
|
||||
# add tagless entry if it's default and has weights
|
||||
_model_has_pretrained.add(model_name)
|
||||
if tag:
|
||||
model_name_tag = '.'.join([model_name, tag])
|
||||
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
|
||||
if pretrained_cfg.has_weights:
|
||||
# add model w/ tag if tag is valid
|
||||
_model_has_pretrained.add(model_name_tag)
|
||||
_model_with_tags[model_name].append(model_name_tag)
|
||||
else:
|
||||
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
|
||||
|
||||
_model_default_cfgs[model_name] = cfg
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
def _natural_key(string_):
|
||||
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
|
||||
|
||||
|
||||
def list_models(
|
||||
filter: Union[str, List[str]] = '',
|
||||
module: str = '',
|
||||
pretrained=False,
|
||||
exclude_filters: str = '',
|
||||
name_matches_cfg: bool = False,
|
||||
include_tags: Optional[bool] = None,
|
||||
):
|
||||
""" Return list of available model names, sorted alphabetically
|
||||
|
||||
Args:
|
||||
filter (str) - Wildcard filter string that works with fnmatch
|
||||
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
|
||||
pretrained (bool) - Include only models with valid pretrained weights if True
|
||||
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
|
||||
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
|
||||
include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
|
||||
set to True when pretrained=True else False (default: None)
|
||||
Example:
|
||||
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
|
||||
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
|
||||
"""
|
||||
if include_tags is None:
|
||||
# FIXME should this be default behaviour? or default to include_tags=True?
|
||||
include_tags = pretrained
|
||||
|
||||
if module:
|
||||
all_models = list(_module_to_models[module])
|
||||
else:
|
||||
all_models = _model_entrypoints.keys()
|
||||
|
||||
if include_tags:
|
||||
# expand model names to include names w/ pretrained tags
|
||||
models_with_tags = []
|
||||
for m in all_models:
|
||||
models_with_tags.extend(_model_with_tags[m])
|
||||
all_models = models_with_tags
|
||||
|
||||
if filter:
|
||||
models = []
|
||||
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
|
||||
for f in include_filters:
|
||||
include_models = fnmatch.filter(all_models, f) # include these models
|
||||
if len(include_models):
|
||||
models = set(models).union(include_models)
|
||||
else:
|
||||
models = all_models
|
||||
|
||||
if exclude_filters:
|
||||
if not isinstance(exclude_filters, (tuple, list)):
|
||||
exclude_filters = [exclude_filters]
|
||||
for xf in exclude_filters:
|
||||
exclude_models = fnmatch.filter(models, xf) # exclude these models
|
||||
if len(exclude_models):
|
||||
models = set(models).difference(exclude_models)
|
||||
|
||||
if pretrained:
|
||||
models = _model_has_pretrained.intersection(models)
|
||||
|
||||
if name_matches_cfg:
|
||||
models = set(_model_pretrained_cfgs).intersection(models)
|
||||
|
||||
return list(sorted(models, key=_natural_key))
|
||||
|
||||
|
||||
def list_pretrained(
|
||||
filter: Union[str, List[str]] = '',
|
||||
exclude_filters: str = '',
|
||||
):
|
||||
return list_models(
|
||||
filter=filter,
|
||||
pretrained=True,
|
||||
exclude_filters=exclude_filters,
|
||||
include_tags=True,
|
||||
)
|
||||
|
||||
|
||||
def is_model(model_name):
|
||||
""" Check if a model name exists
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
return arch_name in _model_entrypoints
|
||||
|
||||
|
||||
def model_entrypoint(model_name, module_filter: Optional[str] = None):
|
||||
"""Fetch a model entrypoint for specified model name
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
|
||||
raise RuntimeError(f'Model ({model_name} not found in module {module_filter}.')
|
||||
return _model_entrypoints[arch_name]
|
||||
|
||||
|
||||
def list_modules():
|
||||
""" Return list of module names that contain models / model entrypoints
|
||||
"""
|
||||
modules = _module_to_models.keys()
|
||||
return list(sorted(modules))
|
||||
|
||||
|
||||
def is_model_in_modules(model_name, module_names):
|
||||
"""Check if a model exists within a subset of modules
|
||||
Args:
|
||||
model_name (str) - name of model to check
|
||||
module_names (tuple, list, set) - names of modules to search in
|
||||
"""
|
||||
arch_name = get_arch_name(model_name)
|
||||
assert isinstance(module_names, (tuple, list, set))
|
||||
return any(arch_name in _module_to_models[n] for n in module_names)
|
||||
|
||||
|
||||
def is_model_pretrained(model_name):
|
||||
return model_name in _model_has_pretrained
|
||||
|
||||
|
||||
def get_pretrained_cfg(model_name):
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return deepcopy(_model_pretrained_cfgs[model_name])
|
||||
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
|
||||
|
||||
|
||||
def get_pretrained_cfg_value(model_name, cfg_key):
|
||||
""" Get a specific model default_cfg value by key. None if key doesn't exist.
|
||||
"""
|
||||
if model_name in _model_pretrained_cfgs:
|
||||
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
|
||||
raise RuntimeError(f'No pretrained config exist for model {model_name}.')
|
@ -1,100 +1,4 @@
|
||||
import os
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from urllib.parse import urlsplit
|
||||
from ._factory import *
|
||||
|
||||
from .pretrained import PretrainedCfg, split_model_name_tag
|
||||
from .helpers import load_checkpoint
|
||||
from .hub import load_model_config_from_hf
|
||||
from .layers import set_layer_config
|
||||
from .registry import is_model, model_entrypoint
|
||||
|
||||
|
||||
def parse_model_name(model_name):
|
||||
if model_name.startswith('hf_hub'):
|
||||
# NOTE for backwards compat, deprecate hf_hub use
|
||||
model_name = model_name.replace('hf_hub', 'hf-hub')
|
||||
parsed = urlsplit(model_name)
|
||||
assert parsed.scheme in ('', 'timm', 'hf-hub')
|
||||
if parsed.scheme == 'hf-hub':
|
||||
# FIXME may use fragment as revision, currently `@` in URI path
|
||||
return parsed.scheme, parsed.path
|
||||
else:
|
||||
model_name = os.path.split(parsed.path)[-1]
|
||||
return 'timm', model_name
|
||||
|
||||
|
||||
def safe_model_name(model_name, remove_source=True):
|
||||
# return a filename / path safe model name
|
||||
def make_safe(name):
|
||||
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
|
||||
if remove_source:
|
||||
model_name = parse_model_name(model_name)[-1]
|
||||
return make_safe(model_name)
|
||||
|
||||
|
||||
def create_model(
|
||||
model_name: str,
|
||||
pretrained: bool = False,
|
||||
pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
|
||||
pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
|
||||
checkpoint_path: str = '',
|
||||
scriptable: Optional[bool] = None,
|
||||
exportable: Optional[bool] = None,
|
||||
no_jit: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a model
|
||||
|
||||
Lookup model's entrypoint function and pass relevant args to create a new model.
|
||||
|
||||
**kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
|
||||
and then the model class __init__(). kwargs values set to None are pruned before passing.
|
||||
|
||||
Args:
|
||||
model_name (str): name of model to instantiate
|
||||
pretrained (bool): load pretrained ImageNet-1k weights if true
|
||||
pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
|
||||
pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
|
||||
checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
|
||||
scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
|
||||
exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
|
||||
no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)
|
||||
|
||||
Keyword Args:
|
||||
drop_rate (float): dropout rate for training (default: 0.0)
|
||||
global_pool (str): global pool type (default: 'avg')
|
||||
**: other kwargs are consumed by builder or model __init__()
|
||||
"""
|
||||
# Parameters that aren't supported by all models or are intended to only override model defaults if set
|
||||
# should default to None in command line args/cfg. Remove them if they are present and not set so that
|
||||
# non-supporting models don't break and default args remain in effect.
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
model_source, model_name = parse_model_name(model_name)
|
||||
if model_source == 'hf-hub':
|
||||
assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
|
||||
# For model names specified in the form `hf-hub:path/architecture_name@revision`,
|
||||
# load model weights + pretrained_cfg from Hugging Face hub.
|
||||
pretrained_cfg, model_name = load_model_config_from_hf(model_name)
|
||||
else:
|
||||
model_name, pretrained_tag = split_model_name_tag(model_name)
|
||||
if not pretrained_cfg:
|
||||
# a valid pretrained_cfg argument takes priority over tag in model name
|
||||
pretrained_cfg = pretrained_tag
|
||||
|
||||
if not is_model(model_name):
|
||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||
|
||||
create_fn = model_entrypoint(model_name)
|
||||
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
|
||||
model = create_fn(
|
||||
pretrained=pretrained,
|
||||
pretrained_cfg=pretrained_cfg,
|
||||
pretrained_cfg_overlay=pretrained_cfg_overlay,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if checkpoint_path:
|
||||
load_checkpoint(model, checkpoint_path)
|
||||
|
||||
return model
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||
|
@ -1,284 +1,4 @@
|
||||
""" PyTorch Feature Extraction Helpers
|
||||
from ._features import *
|
||||
|
||||
A collection of classes, functions, modules to help extract features from models
|
||||
and provide a common interface for describing them.
|
||||
|
||||
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
|
||||
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
||||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
from collections import OrderedDict, defaultdict
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class FeatureInfo:
|
||||
|
||||
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
||||
prev_reduction = 1
|
||||
for fi in feature_info:
|
||||
# sanity check the mandatory fields, there may be additional fields depending on the model
|
||||
assert 'num_chs' in fi and fi['num_chs'] > 0
|
||||
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
||||
prev_reduction = fi['reduction']
|
||||
assert 'module' in fi
|
||||
self.out_indices = out_indices
|
||||
self.info = feature_info
|
||||
|
||||
def from_other(self, out_indices: Tuple[int]):
|
||||
return FeatureInfo(deepcopy(self.info), out_indices)
|
||||
|
||||
def get(self, key, idx=None):
|
||||
""" Get value by key at specified index (indices)
|
||||
if idx == None, returns value for key at each output index
|
||||
if idx is an integer, return value for that feature module index (ignoring output indices)
|
||||
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
||||
"""
|
||||
if idx is None:
|
||||
return [self.info[i][key] for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i][key] for i in idx]
|
||||
else:
|
||||
return self.info[idx][key]
|
||||
|
||||
def get_dicts(self, keys=None, idx=None):
|
||||
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
||||
"""
|
||||
if idx is None:
|
||||
if keys is None:
|
||||
return [self.info[i] for i in self.out_indices]
|
||||
else:
|
||||
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
||||
if isinstance(idx, (tuple, list)):
|
||||
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
||||
else:
|
||||
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
||||
|
||||
def channels(self, idx=None):
|
||||
""" feature channels accessor
|
||||
"""
|
||||
return self.get('num_chs', idx)
|
||||
|
||||
def reduction(self, idx=None):
|
||||
""" feature reduction (output stride) accessor
|
||||
"""
|
||||
return self.get('reduction', idx)
|
||||
|
||||
def module_name(self, idx=None):
|
||||
""" feature module name accessor
|
||||
"""
|
||||
return self.get('module', idx)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.info[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.info)
|
||||
|
||||
|
||||
class FeatureHooks:
|
||||
""" Feature Hook Helper
|
||||
|
||||
This module helps with the setup and extraction of hooks for extracting features from
|
||||
internal nodes in a model by node name. This works quite well in eager Python but needs
|
||||
redesign for torchscript.
|
||||
"""
|
||||
|
||||
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'):
|
||||
# setup feature hooks
|
||||
modules = {k: v for k, v in named_modules}
|
||||
for i, h in enumerate(hooks):
|
||||
hook_name = h['module']
|
||||
m = modules[hook_name]
|
||||
hook_id = out_map[i] if out_map else hook_name
|
||||
hook_fn = partial(self._collect_output_hook, hook_id)
|
||||
hook_type = h.get('hook_type', default_hook_type)
|
||||
if hook_type == 'forward_pre':
|
||||
m.register_forward_pre_hook(hook_fn)
|
||||
elif hook_type == 'forward':
|
||||
m.register_forward_hook(hook_fn)
|
||||
else:
|
||||
assert False, "Unsupported hook type"
|
||||
self._feature_outputs = defaultdict(OrderedDict)
|
||||
|
||||
def _collect_output_hook(self, hook_id, *args):
|
||||
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
||||
if isinstance(x, tuple):
|
||||
x = x[0] # unwrap input tuple
|
||||
self._feature_outputs[x.device][hook_id] = x
|
||||
|
||||
def get_output(self, device) -> Dict[str, torch.tensor]:
|
||||
output = self._feature_outputs[device]
|
||||
self._feature_outputs[device] = OrderedDict() # clear after reading
|
||||
return output
|
||||
|
||||
|
||||
def _module_list(module, flatten_sequential=False):
|
||||
# a yield/iter would be better for this but wouldn't be compatible with torchscript
|
||||
ml = []
|
||||
for name, module in module.named_children():
|
||||
if flatten_sequential and isinstance(module, nn.Sequential):
|
||||
# first level of Sequential containers is flattened into containing model
|
||||
for child_name, child_module in module.named_children():
|
||||
combined = [name, child_name]
|
||||
ml.append(('_'.join(combined), '.'.join(combined), child_module))
|
||||
else:
|
||||
ml.append((name, name, module))
|
||||
return ml
|
||||
|
||||
|
||||
def _get_feature_info(net, out_indices):
|
||||
feature_info = getattr(net, 'feature_info')
|
||||
if isinstance(feature_info, FeatureInfo):
|
||||
return feature_info.from_other(out_indices)
|
||||
elif isinstance(feature_info, (list, tuple)):
|
||||
return FeatureInfo(net.feature_info, out_indices)
|
||||
else:
|
||||
assert False, "Provided feature_info is not valid"
|
||||
|
||||
|
||||
def _get_return_layers(feature_info, out_map):
|
||||
module_names = feature_info.module_name()
|
||||
return_layers = {}
|
||||
for i, name in enumerate(module_names):
|
||||
return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
|
||||
return return_layers
|
||||
|
||||
|
||||
class FeatureDictNet(nn.ModuleDict):
|
||||
""" Feature extractor with OrderedDict return
|
||||
|
||||
Wrap a model and extract features as specified by the out indices, the network is
|
||||
partially re-built from contained modules.
|
||||
|
||||
There is a strong assumption that the modules have been registered into the model in the same
|
||||
order as they are used. There should be no reuse of the same nn.Module more than once, including
|
||||
trivial modules like `self.relu = nn.ReLU`.
|
||||
|
||||
Only submodules that are directly assigned to the model class (`model.feature1`) or at most
|
||||
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
|
||||
All Sequential containers that are directly assigned to the original model will have their
|
||||
modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
|
||||
|
||||
Arguments:
|
||||
model (nn.Module): model from which we will extract the features
|
||||
out_indices (tuple[int]): model output indices to extract features for
|
||||
out_map (sequence): list or tuple specifying desired return id for each out index,
|
||||
otherwise str(index) is used
|
||||
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
|
||||
vs select element [0]
|
||||
flatten_sequential (bool): whether to flatten sequential modules assigned to model
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureDictNet, self).__init__()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.concat = feature_concat
|
||||
self.return_layers = {}
|
||||
return_layers = _get_return_layers(self.feature_info, out_map)
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = set(return_layers.keys())
|
||||
layers = OrderedDict()
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
if old_name in remaining:
|
||||
# return id has to be consistently str type for torchscript
|
||||
self.return_layers[new_name] = str(return_layers[old_name])
|
||||
remaining.remove(old_name)
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining and len(self.return_layers) == len(return_layers), \
|
||||
f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
|
||||
def _collect(self, x) -> (Dict[str, torch.Tensor]):
|
||||
out = OrderedDict()
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
if name in self.return_layers:
|
||||
out_id = self.return_layers[name]
|
||||
if isinstance(x, (tuple, list)):
|
||||
# If model tap is a tuple or list, concat or select first element
|
||||
# FIXME this may need to be more generic / flexible for some nets
|
||||
out[out_id] = torch.cat(x, 1) if self.concat else x[0]
|
||||
else:
|
||||
out[out_id] = x
|
||||
return out
|
||||
|
||||
def forward(self, x) -> Dict[str, torch.Tensor]:
|
||||
return self._collect(x)
|
||||
|
||||
|
||||
class FeatureListNet(FeatureDictNet):
|
||||
""" Feature extractor with list return
|
||||
|
||||
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints.
|
||||
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False):
|
||||
super(FeatureListNet, self).__init__(
|
||||
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat,
|
||||
flatten_sequential=flatten_sequential)
|
||||
|
||||
def forward(self, x) -> (List[torch.Tensor]):
|
||||
return list(self._collect(x).values())
|
||||
|
||||
|
||||
class FeatureHookNet(nn.ModuleDict):
|
||||
""" FeatureHookNet
|
||||
|
||||
Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
|
||||
|
||||
If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
|
||||
network in any way.
|
||||
|
||||
If `no_rewrite` is False, the model will be re-written as in the
|
||||
FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
|
||||
|
||||
FIXME this does not currently work with Torchscript, see FeatureHooks class
|
||||
"""
|
||||
def __init__(
|
||||
self, model,
|
||||
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False,
|
||||
feature_concat=False, flatten_sequential=False, default_hook_type='forward'):
|
||||
super(FeatureHookNet, self).__init__()
|
||||
assert not torch.jit.is_scripting()
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
self.out_as_dict = out_as_dict
|
||||
layers = OrderedDict()
|
||||
hooks = []
|
||||
if no_rewrite:
|
||||
assert not flatten_sequential
|
||||
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
|
||||
model.reset_classifier(0)
|
||||
layers['body'] = model
|
||||
hooks.extend(self.feature_info.get_dicts())
|
||||
else:
|
||||
modules = _module_list(model, flatten_sequential=flatten_sequential)
|
||||
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
|
||||
for f in self.feature_info.get_dicts()}
|
||||
for new_name, old_name, module in modules:
|
||||
layers[new_name] = module
|
||||
for fn, fm in module.named_modules(prefix=old_name):
|
||||
if fn in remaining:
|
||||
hooks.append(dict(module=fn, hook_type=remaining[fn]))
|
||||
del remaining[fn]
|
||||
if not remaining:
|
||||
break
|
||||
assert not remaining, f'Return layers ({remaining}) are not present in model'
|
||||
self.update(layers)
|
||||
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
|
||||
|
||||
def forward(self, x):
|
||||
for name, module in self.items():
|
||||
x = module(x)
|
||||
out = self.hooks.get_output(x.device)
|
||||
return out if self.out_as_dict else list(out.values())
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||
|
@ -1,106 +1,4 @@
|
||||
""" PyTorch FX Based Feature Extraction Helpers
|
||||
Using https://pytorch.org/vision/stable/feature_extraction.html
|
||||
"""
|
||||
from typing import Callable, List, Dict, Union, Type
|
||||
from ._features_fx import *
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .features import _get_feature_info
|
||||
|
||||
try:
|
||||
from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
|
||||
has_fx_feature_extraction = True
|
||||
except ImportError:
|
||||
has_fx_feature_extraction = False
|
||||
|
||||
# Layers we went to treat as leaf modules
|
||||
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
|
||||
from .layers.non_local_attn import BilinearAttnTransform
|
||||
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
|
||||
|
||||
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
|
||||
# BUT modules from timm.models should use the registration mechanism below
|
||||
_leaf_modules = {
|
||||
BilinearAttnTransform, # reason: flow control t <= 1
|
||||
# Reason: get_same_padding has a max which raises a control flow error
|
||||
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
|
||||
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
|
||||
}
|
||||
|
||||
try:
|
||||
from .layers import InplaceAbn
|
||||
_leaf_modules.add(InplaceAbn)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def register_notrace_module(module: Type[nn.Module]):
|
||||
"""
|
||||
Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
|
||||
"""
|
||||
_leaf_modules.add(module)
|
||||
return module
|
||||
|
||||
|
||||
# Functions we want to autowrap (treat them as leaves)
|
||||
_autowrap_functions = set()
|
||||
|
||||
|
||||
def register_notrace_function(func: Callable):
|
||||
"""
|
||||
Decorator for functions which ought not to be traced through
|
||||
"""
|
||||
_autowrap_functions.add(func)
|
||||
return func
|
||||
|
||||
|
||||
def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
return _create_feature_extractor(
|
||||
model, return_nodes,
|
||||
tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
|
||||
)
|
||||
|
||||
|
||||
class FeatureGraphNet(nn.Module):
|
||||
""" A FX Graph based feature extractor that works with the model feature_info metadata
|
||||
"""
|
||||
def __init__(self, model, out_indices, out_map=None):
|
||||
super().__init__()
|
||||
assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
|
||||
self.feature_info = _get_feature_info(model, out_indices)
|
||||
if out_map is not None:
|
||||
assert len(out_map) == len(out_indices)
|
||||
return_nodes = {
|
||||
info['module']: out_map[i] if out_map is not None else info['module']
|
||||
for i, info in enumerate(self.feature_info) if i in out_indices}
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x):
|
||||
return list(self.graph_module(x).values())
|
||||
|
||||
|
||||
class GraphExtractNet(nn.Module):
|
||||
""" A standalone feature extraction wrapper that maps dict -> list or single tensor
|
||||
NOTE:
|
||||
* one can use feature_extractor directly if dictionary output is desired
|
||||
* unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
|
||||
metadata for builtin feature extraction mode
|
||||
* create_feature_extractor can be used directly if dictionary output is acceptable
|
||||
|
||||
Args:
|
||||
model: model to extract features from
|
||||
return_nodes: node names to return features from (dict or list)
|
||||
squeeze_out: if only one output, and output in list format, flatten to single tensor
|
||||
"""
|
||||
def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
|
||||
super().__init__()
|
||||
self.squeeze_out = squeeze_out
|
||||
self.graph_module = create_feature_extractor(model, return_nodes)
|
||||
|
||||
def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
|
||||
out = list(self.graph_module(x).values())
|
||||
if self.squeeze_out and len(out) == 1:
|
||||
return out[0]
|
||||
return out
|
||||
import warnings
|
||||
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue