Merge pull request #501 from rwightman/hf_hub_revisit

Support for huggingface hub via create_model and default_cfgs.
Ross Wightman 4 years ago committed by GitHub
commit 3eac7dc5a3
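As a quick orientation, a minimal usage sketch of the new hf_hub: model source added by this PR, assuming huggingface_hub is installed and the target repo hosts a timm-style config.json plus pytorch_model.bin (the repo id below is the one wired into the vision_transformer default_cfgs in this diff):

import timm

# The 'hf_hub:' prefix routes default_cfg and pretrained weight loading through the
# Hugging Face hub; an optional '#revision' suffix pins a specific repo revision.
model = timm.create_model('hf_hub:timm/vit_huge_patch14_224_in21k', pretrained=True)
model.eval()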

@ -509,7 +509,7 @@ for m in model_list:
model.eval()
with torch.no_grad():
# warmup
input = torch.randn((batch_size,) + data_config['input_size']).cuda()
input = torch.randn((batch_size,) + tuple(data_config['input_size'])).cuda()
model(input)
bar = tqdm(desc="Evaluation", mininterval=5, total=50000)

@ -72,8 +72,8 @@ class RandomResizedCropAndInterpolation:
def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
interpolation='bilinear'):
if isinstance(size, tuple):
self.size = size
if isinstance(size, (list, tuple)):
self.size = tuple(size)
else:
self.size = (size, size)
if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):

@ -78,7 +78,7 @@ def transforms_imagenet_train(
secondary_tfl = []
if auto_augment:
assert isinstance(auto_augment, str)
if isinstance(img_size, tuple):
if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size)
else:
img_size_min = img_size
@ -136,7 +136,7 @@ def transforms_imagenet_eval(
std=IMAGENET_DEFAULT_STD):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, tuple):
if isinstance(img_size, (tuple, list)):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# fall-back to older behaviour so Resize scales to shortest edge if target is square
@ -186,7 +186,7 @@ def create_transform(
tf_preprocessing=False,
separate=False):
if isinstance(input_size, tuple):
if isinstance(input_size, (tuple, list)):
img_size = input_size[-2:]
else:
img_size = input_size

@ -31,7 +31,7 @@ from .xception import *
from .xception_aligned import *
from .hardcorenas import *
from .factory import create_model
from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model

@ -409,8 +409,10 @@ class CspNet(nn.Module):
def _create_cspnet(variant, pretrained=False, **kwargs):
cfg_variant = variant.split('_')[0]
return build_model_with_cfg(
CspNet, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs)
CspNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant],
**kwargs)
@register_model

@ -287,8 +287,10 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs):
kwargs['growth_rate'] = growth_rate
kwargs['block_config'] = block_config
return build_model_with_cfg(
DenseNet, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs)
DenseNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained,
**kwargs)
@register_model

@ -338,8 +338,11 @@ class DLA(nn.Module):
def _create_dla(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
DLA, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs)
DLA, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=False,
feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
**kwargs)
@register_model

@ -262,8 +262,10 @@ class DPN(nn.Module):
def _create_dpn(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
DPN, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs)
DPN, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_concat=True, flatten_sequential=True),
**kwargs)
@register_model

@ -452,18 +452,20 @@ class EfficientNetFeatures(nn.Module):
return list(out.values())
def _create_effnet(model_kwargs, variant, pretrained=False):
def _create_effnet(variant, pretrained=False, **kwargs):
features_only = False
model_cls = EfficientNet
if model_kwargs.pop('features_only', False):
kwargs_filter = None
if kwargs.pop('features_only', False):
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
model_cls = EfficientNetFeatures
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=not features_only, **model_kwargs)
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model
@ -501,7 +503,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -537,7 +539,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs)
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -566,7 +568,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs,variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -595,7 +597,7 @@ def _gen_mobilenet_v2(
act_layer=resolve_act_layer(kwargs, 'relu6'),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -625,7 +627,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -660,7 +662,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -706,7 +708,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -735,7 +737,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
act_layer=resolve_act_layer(kwargs, 'relu'),
**kwargs,
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -765,7 +767,7 @@ def _gen_efficientnet_condconv(
act_layer=resolve_act_layer(kwargs, 'swish'),
**kwargs,
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -806,7 +808,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
norm_kwargs=resolve_bn_args(kwargs),
**kwargs,
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -839,7 +841,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model
@ -872,7 +874,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
norm_kwargs=resolve_bn_args(kwargs),
**kwargs
)
model = _create_effnet(model_kwargs, variant, pretrained)
model = _create_effnet(variant, pretrained, **model_kwargs)
return model

@ -1,6 +1,25 @@
from .registry import is_model, is_model_in_modules, model_entrypoint
from .helpers import load_checkpoint
from .layers import set_layer_config
from .hub import load_model_config_from_hf
def split_model_name(model_name):
model_split = model_name.split(':', 1)
if len(model_split) == 1:
return '', model_split[0]
else:
source_name, model_name = model_split
assert source_name in ('timm', 'hf_hub')
return source_name, model_name
def safe_model_name(model_name, remove_source=True):
def make_safe(name):
return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
if remove_source:
model_name = split_model_name(model_name)[-1]
return make_safe(model_name)
def create_model(
@ -26,7 +45,7 @@ def create_model(
global_pool (str): global pool type (default: 'avg')
**: other kwargs are model specific
"""
model_args = dict(pretrained=pretrained)
source_name, model_name = split_model_name(model_name)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
@ -47,12 +66,19 @@ def create_model(
# non-supporting models don't break and default args remain in effect.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if source_name == 'hf_hub':
# For model names specified in the form `hf_hub:path/architecture_name#revision`,
# load model weights + default_cfg from Hugging Face hub.
hf_default_cfg, model_name = load_model_config_from_hf(model_name)
kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday
if is_model(model_name):
create_fn = model_entrypoint(model_name)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
if is_model(model_name):
create_fn = model_entrypoint(model_name)
model = create_fn(**model_args, **kwargs)
else:
raise RuntimeError('Unknown model (%s)' % model_name)
model = create_fn(pretrained=pretrained, **kwargs)
if checkpoint_path:
load_checkpoint(model, checkpoint_path)
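For reference, a hedged sketch of what the new name helpers return (values are what the code above would produce, not captured output):

from timm.models import split_model_name, safe_model_name

source, name = split_model_name('hf_hub:timm/vit_huge_patch14_224_in21k')
# source == 'hf_hub', name == 'timm/vit_huge_patch14_224_in21k'
source, name = split_model_name('resnet50')
# source == '', name == 'resnet50' (plain timm names pass through unchanged)
print(safe_model_name('hf_hub:timm/vit_huge_patch14_224_in21k'))
# timm_vit_huge_patch14_224_in21k (source stripped, non-alphanumeric characters replaced with '_')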

@ -58,7 +58,10 @@ default_cfgs = {
def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
return build_model_with_cfg(
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -233,8 +233,10 @@ class Xception65(nn.Module):
def _create_gluon_xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
Xception65, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), **kwargs)
Xception65, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'),
**kwargs)
@register_model

@ -35,7 +35,6 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
"""
num_features = 1280
act_layer = resolve_act_layer(kwargs, 'hard_swish')
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
@ -43,23 +42,24 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
stem_size=32,
channel_multiplier=1,
norm_kwargs=resolve_bn_args(kwargs),
act_layer=act_layer,
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
**kwargs,
)
features_only = False
model_cls = MobileNetV3
kwargs_filter = None
if model_kwargs.pop('features_only', False):
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=not features_only, **model_kwargs)
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model

@ -7,17 +7,14 @@ import os
import math
from collections import OrderedDict
from copy import deepcopy
from typing import Callable
from typing import Any, Callable, Optional, Tuple
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .layers import Conv2dSame, Linear
@ -92,7 +89,7 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None,
raise FileNotFoundError()
def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False):
def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False):
r"""Loads a custom (read non .pth) weight file
Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
@ -104,7 +101,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
Args:
model: The instantiated model to load weights into
cfg (dict): Default pretrained model cfg
default_cfg (dict): Default pretrained model cfg
load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named
'load_pretrained' on the model will be called if it exists
progress (bool, optional): whether or not to display a progress bar to stderr. Default: False
@ -113,31 +110,12 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
digits of the SHA256 hash of the contents of the file. The hash is used to
ensure unique names and to verify the contents of the file. Default: False
"""
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
pretrained_url = default_cfg.get('url', None)
if not pretrained_url:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
url = cfg['url']
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
model_dir = os.path.join(hub_dir, 'checkpoints')
os.makedirs(model_dir, exist_ok=True)
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(model_dir, filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress)
if load_fn is not None:
load_fn(model, cached_file)
@ -172,17 +150,39 @@ def adapt_input_conv(in_chans, conv_weight):
return conv_weight
def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
cfg = cfg or getattr(model, 'default_cfg')
if cfg is None or not cfg.get('url', None):
def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
""" Load pretrained checkpoint
Args:
model (nn.Module) : PyTorch model module
default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset
num_classes (int): num_classes for model
in_chans (int): in_chans for model
filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
strict (bool): strict load of checkpoint
progress (bool): enable progress bar for weight download
"""
default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {}
pretrained_url = default_cfg.get('url', None)
hf_hub_id = default_cfg.get('hf_hub', None)
if not pretrained_url and not hf_hub_id:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
else:
_logger.info(f'Loading pretrained weights from url ({pretrained_url})')
state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
if filter_fn is not None:
state_dict = filter_fn(state_dict)
# for backwards compat with filter fns that take one arg, try one arg first, then two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
input_convs = cfg.get('first_conv', None)
input_convs = default_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:
if isinstance(input_convs, str):
input_convs = (input_convs,)
@ -198,19 +198,20 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
_logger.warning(
f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
classifier_name = cfg['classifier']
label_offset = cfg.get('label_offset', 0)
if num_classes != cfg['num_classes']:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
elif label_offset > 0:
# special case for pretrained weights with an extra background class in pretrained weights
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
classifier_name = default_cfg.get('classifier', None)
label_offset = default_cfg.get('label_offset', 0)
if classifier_name is not None:
if num_classes != default_cfg['num_classes']:
# completely discard fully connected if model num_classes doesn't match pretrained weights
del state_dict[classifier_name + '.weight']
del state_dict[classifier_name + '.bias']
strict = False
elif label_offset > 0:
# special case for pretrained weights with an extra background class
classifier_weight = state_dict[classifier_name + '.weight']
state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
classifier_bias = state_dict[classifier_name + '.bias']
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
model.load_state_dict(state_dict, strict=strict)
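A small self-contained illustration of the label_offset branch above (shapes are made up; checkpoints that include an extra background class carry label_offset=1 in their default_cfg):

import torch

state_dict = {'fc.weight': torch.zeros(1001, 2048), 'fc.bias': torch.zeros(1001)}  # 1000 classes + background
label_offset = 1  # e.g. default_cfg has classifier='fc', num_classes=1000, label_offset=1
state_dict['fc.weight'] = state_dict['fc.weight'][label_offset:]  # (1001, 2048) -> (1000, 2048)
state_dict['fc.bias'] = state_dict['fc.bias'][label_offset:]
assert state_dict['fc.weight'].shape[0] == 1000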
@ -316,40 +317,123 @@ def adapt_model_from_file(parent_module, model_variant):
def default_cfg_for_features(default_cfg):
default_cfg = deepcopy(default_cfg)
# remove default pretrained cfg fields that don't have much relevance for feature backbone
to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size?
to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size?
for tr in to_remove:
default_cfg.pop(tr, None)
return default_cfg
def overlay_external_default_cfg(default_cfg, kwargs):
""" Overlay 'external_default_cfg' in kwargs on top of default_cfg arg.
"""
external_default_cfg = kwargs.pop('external_default_cfg', None)
if external_default_cfg:
default_cfg.pop('url', None) # url should come from external cfg
default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg
default_cfg.update(external_default_cfg)
def set_default_kwargs(kwargs, names, default_cfg):
for n in names:
# for legacy reasons, model __init__ args use img_size + in_chans as separate args while
# default_cfg has a single input_size=(C, H, W) entry
if n == 'img_size':
input_size = default_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[-2:])
elif n == 'in_chans':
input_size = default_cfg.get('input_size', None)
if input_size is not None:
assert len(input_size) == 3
kwargs.setdefault(n, input_size[0])
else:
default_val = default_cfg.get(n, None)
if default_val is not None:
kwargs.setdefault(n, default_cfg[n])
def filter_kwargs(kwargs, names):
if not kwargs or not names:
return
for n in names:
kwargs.pop(n, None)
def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
""" Update the default_cfg and kwargs before passing to model
FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs
could/should be replaced by an improved configuration mechanism
Args:
default_cfg: input default_cfg (updated in-place)
kwargs: keyword args passed to model build fn (updated in-place)
kwargs_filter: keyword arg keys that must be removed before model __init__
"""
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)
def build_model_with_cfg(
model_cls: Callable,
variant: str,
pretrained: bool,
default_cfg: dict,
model_cfg: dict = None,
feature_cfg: dict = None,
model_cfg: Optional[Any] = None,
feature_cfg: Optional[dict] = None,
pretrained_strict: bool = True,
pretrained_filter_fn: Callable = None,
pretrained_filter_fn: Optional[Callable] = None,
pretrained_custom_load: bool = False,
kwargs_filter: Optional[Tuple[str]] = None,
**kwargs):
""" Build model with specified default_cfg and optional model_cfg
This helper fn aids in the construction of a model including:
* handling default_cfg and associated pretrained weight loading
* passing through optional model_cfg for models with config based arch spec
* features_only model adaptation
* pruning config / model adaptation
Args:
model_cls (nn.Module): model class
variant (str): model variant name
pretrained (bool): load pretrained weights
default_cfg (dict): model's default pretrained/task config
model_cfg (Optional[Dict]): model's architecture config
feature_cfg (Optional[Dict]): feature extraction adapter config
pretrained_strict (bool): load pretrained weights strictly
pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights
kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
**kwargs: model args passed through to model __init__
"""
pruned = kwargs.pop('pruned', False)
features = False
feature_cfg = feature_cfg or {}
default_cfg = deepcopy(default_cfg) if default_cfg else {}
update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter)
default_cfg.setdefault('architecture', variant)
# Setup for feature extraction wrapper done at end of this fn
if kwargs.pop('features_only', False):
features = True
feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
if 'out_indices' in kwargs:
feature_cfg['out_indices'] = kwargs.pop('out_indices')
# Build the model
model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs)
model.default_cfg = deepcopy(default_cfg)
model.default_cfg = default_cfg
if pruned:
model = adapt_model_from_file(model, variant)
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_custom_load:
@ -357,9 +441,12 @@ def build_model_with_cfg(
else:
load_pretrained(
model,
num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict)
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict)
# Wrap the model in a feature extraction module if enabled
if features:
feature_cls = FeatureListNet
if 'feature_cls' in feature_cfg:
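To make the new default_cfg -> kwargs plumbing concrete, a hedged sketch using made-up cfg values (the helpers are the internal ones defined above in helpers.py):

from timm.models.helpers import set_default_kwargs, filter_kwargs

default_cfg = {'num_classes': 21843, 'input_size': (3, 224, 224), 'classifier': 'head'}
kwargs = {'num_classes': 10}  # user override passed down from create_model
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
# kwargs == {'num_classes': 10, 'in_chans': 3}: existing keys win, in_chans is derived from
# input_size[0], and global_pool stays unset because the cfg has no entry for it
filter_kwargs(kwargs, names=('num_classes',))
# kwargs == {'in_chans': 3}: keys listed in kwargs_filter never reach the model __init__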

@ -774,13 +774,18 @@ class HighResolutionNetFeatures(HighResolutionNet):
def _create_hrnet(variant, pretrained, **model_kwargs):
model_cls = HighResolutionNet
features_only = False
kwargs_filter = None
if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures
model_kwargs['num_classes'] = 0
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant],
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model

@ -0,0 +1,96 @@
import json
import logging
import os
from functools import partial
from typing import Union, Optional
import torch
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
from timm import __version__
try:
from huggingface_hub import hf_hub_url
from huggingface_hub import cached_download
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
except ImportError:
hf_hub_url = None
cached_download = None
_logger = logging.getLogger(__name__)
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
"""
# Issue warning to move data if old env is set
if os.getenv('TORCH_MODEL_ZOO'):
_logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
hub_dir = get_dir()
child_dir = () if not child_dir else (child_dir,)
model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
os.makedirs(model_dir, exist_ok=True)
return model_dir
def download_cached_file(url, check_hash=True, progress=False):
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if not os.path.exists(cached_file):
_logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
hash_prefix = None
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
download_url_to_file(url, cached_file, hash_prefix, progress=progress)
return cached_file
def has_hf_hub(necessary=False):
if hf_hub_url is None and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return hf_hub_url is not None
def hf_split(hf_id):
rev_split = hf_id.split('#')
assert 0 < len(rev_split) <= 2, 'hf_hub id should contain at most one # character to identify revision.'
hf_model_id = rev_split[0]
hf_revision = rev_split[-1] if len(rev_split) > 1 else None
return hf_model_id, hf_revision
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
with open(json_file, "r", encoding="utf-8") as reader:
text = reader.read()
return json.loads(text)
def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
url = hf_hub_url(hf_model_id, filename, revision=hf_revision)
return cached_download(url, cache_dir=get_cache_dir('hf'))
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
default_cfg = load_cfg_from_json(cached_file)
default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation
model_name = default_cfg.get('architecture')
return default_cfg, model_name
def load_state_dict_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
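For clarity, the revision syntax hub.py expects (a hedged sketch; both ids are illustrative):

from timm.models.hub import hf_split

print(hf_split('timm/vit_huge_patch14_224_in21k'))
# ('timm/vit_huge_patch14_224_in21k', None)   -- no '#', the hub's default revision is used
print(hf_split('timm/vit_huge_patch14_224_in21k#main'))
# ('timm/vit_huge_patch14_224_in21k', 'main') -- text after '#' is passed as the revision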

@ -336,7 +336,9 @@ class InceptionResnetV2(nn.Module):
def _create_inception_resnet_v2(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
InceptionResnetV2, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -434,8 +434,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs):
model_cls = InceptionV3
load_strict = not default_cfg['has_aux']
return build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **kwargs)
model_cls, variant, pretrained,
default_cfg=default_cfg,
pretrained_strict=load_strict,
**kwargs)
@register_model

@ -305,8 +305,10 @@ class InceptionV4(nn.Module):
def _create_inception_v4(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), **kwargs)
InceptionV4, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model

@ -200,19 +200,20 @@ class MobileNetV3Features(nn.Module):
return list(out.values())
def _create_mnv3(model_kwargs, variant, pretrained=False):
def _create_mnv3(variant, pretrained=False, **kwargs):
features_only = False
model_cls = MobileNetV3
if model_kwargs.pop('features_only', False):
kwargs_filter = None
if kwargs.pop('features_only', False):
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=not features_only, **model_kwargs)
model_cls, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model
@ -252,7 +253,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1),
**kwargs,
)
model = _create_mnv3(model_kwargs, variant, pretrained)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model
@ -348,7 +349,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8),
**kwargs,
)
model = _create_mnv3(model_kwargs, variant, pretrained)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model

@ -553,7 +553,8 @@ class NASNetALarge(nn.Module):
def _create_nasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant],
NASNetALarge, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
**kwargs)

@ -334,7 +334,8 @@ class PNASNet5Large(nn.Module):
def _create_pnasnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant],
PNASNet5Large, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model
**kwargs)

@ -330,7 +330,10 @@ class RegNet(nn.Module):
def _create_regnet(variant, pretrained, **kwargs):
return build_model_with_cfg(
RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs)
RegNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant],
**kwargs)
@register_model

@ -134,7 +134,9 @@ class Bottle2neck(nn.Module):
def _create_res2net(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs)
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -141,7 +141,9 @@ class ResNestBottleneck(nn.Module):
def _create_resnest(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -634,7 +634,9 @@ class ResNet(nn.Module):
def _create_resnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -413,8 +413,11 @@ class ResNetV2(nn.Module):
def _create_resnetv2(variant, pretrained=False, **kwargs):
feature_cfg = dict(flatten_sequential=True)
return build_model_with_cfg(
ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True,
feature_cfg=feature_cfg, **kwargs)
ResNetV2, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=feature_cfg,
pretrained_custom_load=True,
**kwargs)
@register_model

@ -199,7 +199,10 @@ class ReXNetV1(nn.Module):
def _create_rexnet(variant, pretrained, **kwargs):
feature_cfg = dict(flatten_sequential=True)
return build_model_with_cfg(
ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs)
ReXNetV1, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=feature_cfg,
**kwargs)
@register_model

@ -196,7 +196,7 @@ class SelecSLS(nn.Module):
return x
def _create_selecsls(variant, pretrained, model_kwargs):
def _create_selecsls(variant, pretrained, **kwargs):
cfg = {}
feature_info = [dict(num_chs=32, reduction=2, module='stem.2')]
if variant.startswith('selecsls42'):
@ -320,40 +320,43 @@ def _create_selecsls(variant, pretrained, model_kwargs):
# this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises?
return build_model_with_cfg(
SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg,
feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs)
SelecSLS, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=cfg,
feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True),
**kwargs)
@register_model
def selecsls42(pretrained=False, **kwargs):
"""Constructs a SelecSLS42 model.
"""
return _create_selecsls('selecsls42', pretrained, kwargs)
return _create_selecsls('selecsls42', pretrained, **kwargs)
@register_model
def selecsls42b(pretrained=False, **kwargs):
"""Constructs a SelecSLS42_B model.
"""
return _create_selecsls('selecsls42b', pretrained, kwargs)
return _create_selecsls('selecsls42b', pretrained, **kwargs)
@register_model
def selecsls60(pretrained=False, **kwargs):
"""Constructs a SelecSLS60 model.
"""
return _create_selecsls('selecsls60', pretrained, kwargs)
return _create_selecsls('selecsls60', pretrained, **kwargs)
@register_model
def selecsls60b(pretrained=False, **kwargs):
"""Constructs a SelecSLS60_B model.
"""
return _create_selecsls('selecsls60b', pretrained, kwargs)
return _create_selecsls('selecsls60b', pretrained, **kwargs)
@register_model
def selecsls84(pretrained=False, **kwargs):
"""Constructs a SelecSLS84 model.
"""
return _create_selecsls('selecsls84', pretrained, kwargs)
return _create_selecsls('selecsls84', pretrained, **kwargs)

@ -398,7 +398,9 @@ class SENet(nn.Module):
def _create_senet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
SENet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
SENet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -141,7 +141,9 @@ class SelectiveKernelBottleneck(nn.Module):
def _create_skresnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs)
ResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model

@ -253,8 +253,10 @@ class TResNet(nn.Module):
def _create_tresnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained,
feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs)
TResNet, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True),
**kwargs)
@register_model

@ -180,9 +180,9 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
# NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models
out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5))
model = build_model_with_cfg(
VGG, variant, pretrained=pretrained,
model_cfg=cfgs[cfg],
VGG, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=cfgs[cfg],
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
pretrained_filter_fn=_filter_fn,
**kwargs)

@ -21,13 +21,14 @@ import math
import logging
from functools import partial
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2
@ -94,7 +95,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_huge_patch14_224_in21k': _cfg(
url='', # FIXME I have weights for this but > 2GB limit for github release binaries
hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# hybrid models (weights ported from official Google JAX impl)
@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model):
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
default_cfg = default_cfgs[variant]
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-1]
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
@ -475,14 +477,19 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cls = DistilledVisionTransformer if distilled else VisionTransformer
model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
model.default_cfg = default_cfg
model = build_model_with_cfg(
model_cls, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
if pretrained:
load_pretrained(
model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3),
filter_fn=partial(checkpoint_filter_fn, model=model))
return model

@ -338,8 +338,11 @@ class VovNet(nn.Module):
def _create_vovnet(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant],
feature_cfg=dict(flatten_sequential=True), **kwargs)
VovNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model

@ -221,8 +221,10 @@ class Xception(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
Xception, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'), **kwargs)
Xception, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(feature_cls='hook'),
**kwargs)
@register_model

@ -173,8 +173,10 @@ class XceptionAligned(nn.Module):
def _xception(variant, pretrained=False, **kwargs):
return build_model_with_cfg(
XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs)
XceptionAligned, variant, pretrained,
default_cfg=default_cfgs[variant],
feature_cfg=dict(flatten_sequential=True, feature_cls='hook'),
**kwargs)
@register_model

@ -29,7 +29,8 @@ import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm.utils import *
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
from timm.optim import create_optimizer
@ -345,8 +346,8 @@ def main():
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.local_rank == 0:
_logger.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
_logger.info(
f'Model {safe_model_name(args.model)} created, param count: {sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
@ -543,7 +544,7 @@ def main():
output_base = args.output if args.output else './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(output_base, 'train', exp_name)

@ -211,7 +211,7 @@ def validate(args):
model.eval()
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + data_config['input_size']).cuda()
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
model(input)
