diff --git a/sotabench.py b/sotabench.py index 65aca6dd..6a0e10a1 100644 --- a/sotabench.py +++ b/sotabench.py @@ -509,7 +509,7 @@ for m in model_list: model.eval() with torch.no_grad(): # warmup - input = torch.randn((batch_size,) + data_config['input_size']).cuda() + input = torch.randn((batch_size,) + tuple(data_config['input_size'])).cuda() model(input) bar = tqdm(desc="Evaluation", mininterval=5, total=50000) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index b3b08e30..4220304f 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -72,8 +72,8 @@ class RandomResizedCropAndInterpolation: def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation='bilinear'): - if isinstance(size, tuple): - self.size = size + if isinstance(size, (list, tuple)): + self.size = tuple(size) else: self.size = (size, size) if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): diff --git a/timm/data/transforms_factory.py b/timm/data/transforms_factory.py index 01c9fcf2..df6e0de0 100644 --- a/timm/data/transforms_factory.py +++ b/timm/data/transforms_factory.py @@ -78,7 +78,7 @@ def transforms_imagenet_train( secondary_tfl = [] if auto_augment: assert isinstance(auto_augment, str) - if isinstance(img_size, tuple): + if isinstance(img_size, (tuple, list)): img_size_min = min(img_size) else: img_size_min = img_size @@ -136,7 +136,7 @@ def transforms_imagenet_eval( std=IMAGENET_DEFAULT_STD): crop_pct = crop_pct or DEFAULT_CROP_PCT - if isinstance(img_size, tuple): + if isinstance(img_size, (tuple, list)): assert len(img_size) == 2 if img_size[-1] == img_size[-2]: # fall-back to older behaviour so Resize scales to shortest edge if target is square @@ -186,7 +186,7 @@ def create_transform( tf_preprocessing=False, separate=False): - if isinstance(input_size, tuple): + if isinstance(input_size, (tuple, list)): img_size = input_size[-2:] else: img_size = input_size diff --git a/timm/models/__init__.py b/timm/models/__init__.py index c04aad11..aa13bbcc 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -31,7 +31,7 @@ from .xception import * from .xception_aligned import * from .hardcorenas import * -from .factory import create_model +from .factory import create_model, split_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool from .layers import convert_splitbn_model diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index afd1dcd7..39d16200 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -409,8 +409,10 @@ class CspNet(nn.Module): def _create_cspnet(variant, pretrained=False, **kwargs): cfg_variant = variant.split('_')[0] return build_model_with_cfg( - CspNet, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], **kwargs) + CspNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), model_cfg=model_cfgs[cfg_variant], + **kwargs) @register_model diff --git a/timm/models/densenet.py b/timm/models/densenet.py index e4e20564..38a19727 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -287,8 +287,10 @@ def _create_densenet(variant, growth_rate, block_config, pretrained, **kwargs): kwargs['growth_rate'] = growth_rate kwargs['block_config'] = block_config return build_model_with_cfg( - DenseNet, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, **kwargs) + DenseNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), pretrained_filter_fn=_filter_torchvision_pretrained, + **kwargs) @register_model diff --git a/timm/models/dla.py b/timm/models/dla.py index a41ec326..64ad61d6 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -338,8 +338,11 @@ class DLA(nn.Module): def _create_dla(variant, pretrained=False, **kwargs): return build_model_with_cfg( - DLA, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=False, feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), **kwargs) + DLA, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=False, + feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)), + **kwargs) @register_model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index ac9c7755..90ef11cc 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -262,8 +262,10 @@ class DPN(nn.Module): def _create_dpn(variant, pretrained=False, **kwargs): return build_model_with_cfg( - DPN, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(feature_concat=True, flatten_sequential=True), **kwargs) + DPN, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_concat=True, flatten_sequential=True), + **kwargs) @register_model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 4a89590b..086a5e2e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -452,18 +452,20 @@ class EfficientNetFeatures(nn.Module): return list(out.values()) -def _create_effnet(model_kwargs, variant, pretrained=False): +def _create_effnet(variant, pretrained=False, **kwargs): features_only = False model_cls = EfficientNet - if model_kwargs.pop('features_only', False): + kwargs_filter = None + if kwargs.pop('features_only', False): features_only = True - model_kwargs.pop('num_classes', 0) - model_kwargs.pop('num_features', 0) - model_kwargs.pop('head_conv', None) + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool') model_cls = EfficientNetFeatures model = build_model_with_cfg( - model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=not features_only, **model_kwargs) + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) if features_only: model.default_cfg = default_cfg_for_features(model.default_cfg) return model @@ -501,7 +503,7 @@ def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -537,7 +539,7 @@ def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs) norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -566,7 +568,7 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs,variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -595,7 +597,7 @@ def _gen_mobilenet_v2( act_layer=resolve_act_layer(kwargs, 'relu6'), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -625,7 +627,7 @@ def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -660,7 +662,7 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -706,7 +708,7 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -735,7 +737,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 act_layer=resolve_act_layer(kwargs, 'relu'), **kwargs, ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -765,7 +767,7 @@ def _gen_efficientnet_condconv( act_layer=resolve_act_layer(kwargs, 'swish'), **kwargs, ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -806,7 +808,7 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0 norm_kwargs=resolve_bn_args(kwargs), **kwargs, ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -839,7 +841,7 @@ def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model @@ -872,7 +874,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai norm_kwargs=resolve_bn_args(kwargs), **kwargs ) - model = _create_effnet(model_kwargs, variant, pretrained) + model = _create_effnet(variant, pretrained, **model_kwargs) return model diff --git a/timm/models/factory.py b/timm/models/factory.py index a7b6c90e..d040a9ff 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -1,6 +1,25 @@ from .registry import is_model, is_model_in_modules, model_entrypoint from .helpers import load_checkpoint from .layers import set_layer_config +from .hub import load_model_config_from_hf + + +def split_model_name(model_name): + model_split = model_name.split(':', 1) + if len(model_split) == 1: + return '', model_split[0] + else: + source_name, model_name = model_split + assert source_name in ('timm', 'hf_hub') + return source_name, model_name + + +def safe_model_name(model_name, remove_source=True): + def make_safe(name): + return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_') + if remove_source: + model_name = split_model_name(model_name)[-1] + return make_safe(model_name) def create_model( @@ -26,7 +45,7 @@ def create_model( global_pool (str): global pool type (default: 'avg') **: other kwargs are model specific """ - model_args = dict(pretrained=pretrained) + source_name, model_name = split_model_name(model_name) # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) @@ -47,12 +66,19 @@ def create_model( # non-supporting models don't break and default args remain in effect. kwargs = {k: v for k, v in kwargs.items() if v is not None} + if source_name == 'hf_hub': + # For model names specified in the form `hf_hub:path/architecture_name#revision`, + # load model weights + default_cfg from Hugging Face hub. + hf_default_cfg, model_name = load_model_config_from_hf(model_name) + kwargs['external_default_cfg'] = hf_default_cfg # FIXME revamp default_cfg interface someday + + if is_model(model_name): + create_fn = model_entrypoint(model_name) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit): - if is_model(model_name): - create_fn = model_entrypoint(model_name) - model = create_fn(**model_args, **kwargs) - else: - raise RuntimeError('Unknown model (%s)' % model_name) + model = create_fn(pretrained=pretrained, **kwargs) if checkpoint_path: load_checkpoint(model, checkpoint_path) diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index 9bd99d04..027a10b5 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -58,7 +58,10 @@ default_cfgs = { def _create_resnet(variant, pretrained=False, **kwargs): - return build_model_with_cfg(ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + return build_model_with_cfg( + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/gluon_xception.py b/timm/models/gluon_xception.py index 8fc398d6..fbd668a5 100644 --- a/timm/models/gluon_xception.py +++ b/timm/models/gluon_xception.py @@ -233,8 +233,10 @@ class Xception65(nn.Module): def _create_gluon_xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( - Xception65, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(feature_cls='hook'), **kwargs) + Xception65, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), + **kwargs) @register_model diff --git a/timm/models/hardcorenas.py b/timm/models/hardcorenas.py index b820265d..4420172b 100644 --- a/timm/models/hardcorenas.py +++ b/timm/models/hardcorenas.py @@ -35,7 +35,6 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): """ num_features = 1280 - act_layer = resolve_act_layer(kwargs, 'hard_swish') model_kwargs = dict( block_args=decode_arch_def(arch_def), @@ -43,23 +42,24 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs): stem_size=32, channel_multiplier=1, norm_kwargs=resolve_bn_args(kwargs), - act_layer=act_layer, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), **kwargs, ) features_only = False model_cls = MobileNetV3 + kwargs_filter = None if model_kwargs.pop('features_only', False): features_only = True - model_kwargs.pop('num_classes', 0) - model_kwargs.pop('num_features', 0) - model_kwargs.pop('head_conv', None) - model_kwargs.pop('head_bias', None) + kwargs_filter = ('num_classes', 'num_features', 'global_pool', 'head_conv', 'head_bias', 'global_pool') model_cls = MobileNetV3Features model = build_model_with_cfg( - model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=not features_only, **model_kwargs) + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) if features_only: model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 4d9b8a28..2f6e098b 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -7,17 +7,14 @@ import os import math from collections import OrderedDict from copy import deepcopy -from typing import Callable +from typing import Any, Callable, Optional, Tuple import torch import torch.nn as nn -from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX -try: - from torch.hub import get_dir -except ImportError: - from torch.hub import _get_torch_home as get_dir + from .features import FeatureListNet, FeatureDictNet, FeatureHookNet +from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url from .layers import Conv2dSame, Linear @@ -92,7 +89,7 @@ def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, raise FileNotFoundError() -def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_hash=False): +def load_custom_pretrained(model, default_cfg=None, load_fn=None, progress=False, check_hash=False): r"""Loads a custom (read non .pth) weight file Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls @@ -104,7 +101,7 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ Args: model: The instantiated model to load weights into - cfg (dict): Default pretrained model cfg + default_cfg (dict): Default pretrained model cfg load_fn: An external stand alone fn that loads weights into provided model, otherwise a fn named 'laod_pretrained' on the model will be called if it exists progress (bool, optional): whether or not to display a progress bar to stderr. Default: False @@ -113,31 +110,12 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ digits of the SHA256 hash of the contents of the file. The hash is used to ensure unique names and to verify the contents of the file. Default: False """ - cfg = cfg or getattr(model, 'default_cfg') - if cfg is None or not cfg.get('url', None): + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + if not pretrained_url: _logger.warning("No pretrained weights exist for this model. Using random initialization.") return - url = cfg['url'] - - # Issue warning to move data if old env is set - if os.getenv('TORCH_MODEL_ZOO'): - _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') - - hub_dir = get_dir() - model_dir = os.path.join(hub_dir, 'checkpoints') - - os.makedirs(model_dir, exist_ok=True) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - cached_file = os.path.join(model_dir, filename) - if not os.path.exists(cached_file): - _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) - hash_prefix = None - if check_hash: - r = HASH_REGEX.search(filename) # r is Optional[Match[str]] - hash_prefix = r.group(1) if r else None - download_url_to_file(url, cached_file, hash_prefix, progress=progress) + cached_file = download_cached_file(default_cfg['url'], check_hash=check_hash, progress=progress) if load_fn is not None: load_fn(model, cached_file) @@ -172,17 +150,39 @@ def adapt_input_conv(in_chans, conv_weight): return conv_weight -def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): - cfg = cfg or getattr(model, 'default_cfg') - if cfg is None or not cfg.get('url', None): +def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False): + """ Load pretrained checkpoint + + Args: + model (nn.Module) : PyTorch model module + default_cfg (Optional[Dict]): default configuration for pretrained weights / target dataset + num_classes (int): num_classes for model + in_chans (int): in_chans for model + filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args) + strict (bool): strict load of checkpoint + progress (bool): enable progress bar for weight download + + """ + default_cfg = default_cfg or getattr(model, 'default_cfg', None) or {} + pretrained_url = default_cfg.get('url', None) + hf_hub_id = default_cfg.get('hf_hub', None) + if not pretrained_url and not hf_hub_id: _logger.warning("No pretrained weights exist for this model. Using random initialization.") return - - state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') + if hf_hub_id and has_hf_hub(necessary=not pretrained_url): + _logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})') + state_dict = load_state_dict_from_hf(hf_hub_id) + else: + _logger.info(f'Loading pretrained weights from url ({pretrained_url})') + state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu') if filter_fn is not None: - state_dict = filter_fn(state_dict) + # for backwards compat with filter fn that take one arg, try one first, the two + try: + state_dict = filter_fn(state_dict) + except TypeError: + state_dict = filter_fn(state_dict, model) - input_convs = cfg.get('first_conv', None) + input_convs = default_cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: if isinstance(input_convs, str): input_convs = (input_convs,) @@ -198,19 +198,20 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non _logger.warning( f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.') - classifier_name = cfg['classifier'] - label_offset = cfg.get('label_offset', 0) - if num_classes != cfg['num_classes']: - # completely discard fully connected if model num_classes doesn't match pretrained weights - del state_dict[classifier_name + '.weight'] - del state_dict[classifier_name + '.bias'] - strict = False - elif label_offset > 0: - # special case for pretrained weights with an extra background class in pretrained weights - classifier_weight = state_dict[classifier_name + '.weight'] - state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] - classifier_bias = state_dict[classifier_name + '.bias'] - state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] + classifier_name = default_cfg.get('classifier', None) + label_offset = default_cfg.get('label_offset', 0) + if classifier_name is not None: + if num_classes != default_cfg['num_classes']: + # completely discard fully connected if model num_classes doesn't match pretrained weights + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + elif label_offset > 0: + # special case for pretrained weights with an extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:] model.load_state_dict(state_dict, strict=strict) @@ -316,40 +317,123 @@ def adapt_model_from_file(parent_module, model_variant): def default_cfg_for_features(default_cfg): default_cfg = deepcopy(default_cfg) # remove default pretrained cfg fields that don't have much relevance for feature backbone - to_remove = ('num_classes', 'crop_pct', 'classifier') # add default final pool size? + to_remove = ('num_classes', 'crop_pct', 'classifier', 'global_pool') # add default final pool size? for tr in to_remove: default_cfg.pop(tr, None) return default_cfg +def overlay_external_default_cfg(default_cfg, kwargs): + """ Overlay 'external_default_cfg' in kwargs on top of default_cfg arg. + """ + external_default_cfg = kwargs.pop('external_default_cfg', None) + if external_default_cfg: + default_cfg.pop('url', None) # url should come from external cfg + default_cfg.pop('hf_hub', None) # hf hub id should come from external cfg + default_cfg.update(external_default_cfg) + + +def set_default_kwargs(kwargs, names, default_cfg): + for n in names: + # for legacy reasons, model __init__args uses img_size + in_chans as separate args while + # default_cfg has one input_size=(C, H ,W) entry + if n == 'img_size': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[-2:]) + elif n == 'in_chans': + input_size = default_cfg.get('input_size', None) + if input_size is not None: + assert len(input_size) == 3 + kwargs.setdefault(n, input_size[0]) + else: + default_val = default_cfg.get(n, None) + if default_val is not None: + kwargs.setdefault(n, default_cfg[n]) + + +def filter_kwargs(kwargs, names): + if not kwargs or not names: + return + for n in names: + kwargs.pop(n, None) + + +def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): + """ Update the default_cfg and kwargs before passing to model + + FIXME this sequence of overlay default_cfg, set default kwargs, filter kwargs + could/should be replaced by an improved configuration mechanism + + Args: + default_cfg: input default_cfg (updated in-place) + kwargs: keyword args passed to model build fn (updated in-place) + kwargs_filter: keyword arg keys that must be removed before model __init__ + """ + # Overlay default cfg values from `external_default_cfg` if it exists in kwargs + overlay_external_default_cfg(default_cfg, kwargs) + # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) + set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) + # Filter keyword args for task specific model variants (some 'features only' models, etc.) + filter_kwargs(kwargs, names=kwargs_filter) + + def build_model_with_cfg( model_cls: Callable, variant: str, pretrained: bool, default_cfg: dict, - model_cfg: dict = None, - feature_cfg: dict = None, + model_cfg: Optional[Any] = None, + feature_cfg: Optional[dict] = None, pretrained_strict: bool = True, - pretrained_filter_fn: Callable = None, + pretrained_filter_fn: Optional[Callable] = None, pretrained_custom_load: bool = False, + kwargs_filter: Optional[Tuple[str]] = None, **kwargs): + """ Build model with specified default_cfg and optional model_cfg + + This helper fn aids in the construction of a model including: + * handling default_cfg and associated pretained weight loading + * passing through optional model_cfg for models with config based arch spec + * features_only model adaptation + * pruning config / model adaptation + + Args: + model_cls (nn.Module): model class + variant (str): model variant name + pretrained (bool): load pretrained weights + default_cfg (dict): model's default pretrained/task config + model_cfg (Optional[Dict]): model's architecture config + feature_cfg (Optional[Dict]: feature extraction adapter config + pretrained_strict (bool): load pretrained weights strictly + pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights + pretrained_custom_load (bool): use custom load fn, to load numpy or other non PyTorch weights + kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model + **kwargs: model args passed through to model __init__ + """ pruned = kwargs.pop('pruned', False) features = False feature_cfg = feature_cfg or {} + default_cfg = deepcopy(default_cfg) if default_cfg else {} + update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter) + default_cfg.setdefault('architecture', variant) + # Setup for feature extraction wrapper done at end of this fn if kwargs.pop('features_only', False): features = True feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4)) if 'out_indices' in kwargs: feature_cfg['out_indices'] = kwargs.pop('out_indices') + # Build the model model = model_cls(**kwargs) if model_cfg is None else model_cls(cfg=model_cfg, **kwargs) - model.default_cfg = deepcopy(default_cfg) + model.default_cfg = default_cfg if pruned: model = adapt_model_from_file(model, variant) - # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: if pretrained_custom_load: @@ -357,9 +441,12 @@ def build_model_with_cfg( else: load_pretrained( model, - num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), - filter_fn=pretrained_filter_fn, strict=pretrained_strict) - + num_classes=num_classes_pretrained, + in_chans=kwargs.get('in_chans', 3), + filter_fn=pretrained_filter_fn, + strict=pretrained_strict) + + # Wrap the model in a feature extraction module if enabled if features: feature_cls = FeatureListNet if 'feature_cls' in feature_cfg: diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 1c0bc9f0..c56964f6 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -774,13 +774,18 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet features_only = False + kwargs_filter = None if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures - model_kwargs['num_classes'] = 0 + kwargs_filter = ('num_classes', 'global_pool') features_only = True model = build_model_with_cfg( - model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs) + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfg_cls[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **model_kwargs) if features_only: model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/hub.py b/timm/models/hub.py new file mode 100644 index 00000000..69365a0c --- /dev/null +++ b/timm/models/hub.py @@ -0,0 +1,96 @@ +import json +import logging +import os +from functools import partial +from typing import Union, Optional + +import torch +from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX +try: + from torch.hub import get_dir +except ImportError: + from torch.hub import _get_torch_home as get_dir + +from timm import __version__ +try: + from huggingface_hub import hf_hub_url + from huggingface_hub import cached_download + cached_download = partial(cached_download, library_name="timm", library_version=__version__) +except ImportError: + hf_hub_url = None + cached_download = None + +_logger = logging.getLogger(__name__) + + +def get_cache_dir(child_dir=''): + """ + Returns the location of the directory where models are cached (and creates it if necessary). + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + child_dir = () if not child_dir else (child_dir,) + model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir) + os.makedirs(model_dir, exist_ok=True) + return model_dir + + +def download_cached_file(url, check_hash=True, progress=False): + parts = urlparse(url) + filename = os.path.basename(parts.path) + cached_file = os.path.join(get_cache_dir(), filename) + if not os.path.exists(cached_file): + _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) + hash_prefix = None + if check_hash: + r = HASH_REGEX.search(filename) # r is Optional[Match[str]] + hash_prefix = r.group(1) if r else None + download_url_to_file(url, cached_file, hash_prefix, progress=progress) + return cached_file + + +def has_hf_hub(necessary=False): + if hf_hub_url is None and necessary: + # if no HF Hub module installed and it is necessary to continue, raise error + raise RuntimeError( + 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') + return hf_hub_url is not None + + +def hf_split(hf_id): + rev_split = hf_id.split('#') + assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one # character to identify revision.' + hf_model_id = rev_split[0] + hf_revision = rev_split[-1] if len(rev_split) > 1 else None + return hf_model_id, hf_revision + + +def load_cfg_from_json(json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + +def _download_from_hf(model_id: str, filename: str): + hf_model_id, hf_revision = hf_split(model_id) + url = hf_hub_url(hf_model_id, filename, revision=hf_revision) + return cached_download(url, cache_dir=get_cache_dir('hf')) + + +def load_model_config_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'config.json') + default_cfg = load_cfg_from_json(cached_file) + default_cfg['hf_hub'] = model_id # insert hf_hub id for pretrained weight load during model creation + model_name = default_cfg.get('architecture') + return default_cfg, model_name + + +def load_state_dict_from_hf(model_id: str): + assert has_hf_hub(True) + cached_file = _download_from_hf(model_id, 'pytorch_model.bin') + state_dict = torch.load(cached_file, map_location='cpu') + return state_dict diff --git a/timm/models/inception_resnet_v2.py b/timm/models/inception_resnet_v2.py index adfe330e..71672849 100644 --- a/timm/models/inception_resnet_v2.py +++ b/timm/models/inception_resnet_v2.py @@ -336,7 +336,9 @@ class InceptionResnetV2(nn.Module): def _create_inception_resnet_v2(variant, pretrained=False, **kwargs): return build_model_with_cfg( - InceptionResnetV2, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + InceptionResnetV2, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index cdb1f1c0..cbb1107b 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -434,8 +434,10 @@ def _create_inception_v3(variant, pretrained=False, **kwargs): model_cls = InceptionV3 load_strict = not default_cfg['has_aux'] return build_model_with_cfg( - model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **kwargs) + model_cls, variant, pretrained, + default_cfg=default_cfg, + pretrained_strict=load_strict, + **kwargs) @register_model diff --git a/timm/models/inception_v4.py b/timm/models/inception_v4.py index 69f9ff5a..cc899e15 100644 --- a/timm/models/inception_v4.py +++ b/timm/models/inception_v4.py @@ -305,8 +305,10 @@ class InceptionV4(nn.Module): def _create_inception_v4(variant, pretrained=False, **kwargs): return build_model_with_cfg( - InceptionV4, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True), **kwargs) + InceptionV4, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) @register_model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index 8a48ce72..3ec1ab9b 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -200,19 +200,20 @@ class MobileNetV3Features(nn.Module): return list(out.values()) -def _create_mnv3(model_kwargs, variant, pretrained=False): +def _create_mnv3(variant, pretrained=False, **kwargs): features_only = False model_cls = MobileNetV3 - if model_kwargs.pop('features_only', False): + kwargs_filter = None + if kwargs.pop('features_only', False): features_only = True - model_kwargs.pop('num_classes', 0) - model_kwargs.pop('num_features', 0) - model_kwargs.pop('head_conv', None) - model_kwargs.pop('head_bias', None) + kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool') model_cls = MobileNetV3Features model = build_model_with_cfg( - model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=not features_only, **model_kwargs) + model_cls, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_strict=not features_only, + kwargs_filter=kwargs_filter, + **kwargs) if features_only: model.default_cfg = default_cfg_for_features(model.default_cfg) return model @@ -252,7 +253,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=1), **kwargs, ) - model = _create_mnv3(model_kwargs, variant, pretrained) + model = _create_mnv3(variant, pretrained, **model_kwargs) return model @@ -348,7 +349,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg se_kwargs=dict(act_layer=nn.ReLU, gate_fn=hard_sigmoid, reduce_mid=True, divisor=8), **kwargs, ) - model = _create_mnv3(model_kwargs, variant, pretrained) + model = _create_mnv3(variant, pretrained, **model_kwargs) return model diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 1f1a3b75..2afe82c3 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -553,7 +553,8 @@ class NASNetALarge(nn.Module): def _create_nasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - NASNetALarge, variant, pretrained, default_cfg=default_cfgs[variant], + NASNetALarge, variant, pretrained, + default_cfg=default_cfgs[variant], feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 73073009..99918156 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -334,7 +334,8 @@ class PNASNet5Large(nn.Module): def _create_pnasnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - PNASNet5Large, variant, pretrained, default_cfg=default_cfgs[variant], + PNASNet5Large, variant, pretrained, + default_cfg=default_cfgs[variant], feature_cfg=dict(feature_cls='hook', no_rewrite=True), # not possible to re-write this model **kwargs) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 68bb817c..40988946 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -330,7 +330,10 @@ class RegNet(nn.Module): def _create_regnet(variant, pretrained, **kwargs): return build_model_with_cfg( - RegNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], **kwargs) + RegNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 6e51d491..977d872f 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -134,7 +134,9 @@ class Bottle2neck(nn.Module): def _create_res2net(variant, pretrained=False, **kwargs): return build_model_with_cfg( - ResNet, variant, pretrained, default_cfg=default_cfgs[variant], **kwargs) + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 5a8bb348..154e250c 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -141,7 +141,9 @@ class ResNestBottleneck(nn.Module): def _create_resnest(variant, pretrained=False, **kwargs): return build_model_with_cfg( - ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 6dec9d53..656e3a51 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -634,7 +634,9 @@ class ResNet(nn.Module): def _create_resnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 63e17d0b..dbbd2de9 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -413,8 +413,11 @@ class ResNetV2(nn.Module): def _create_resnetv2(variant, pretrained=False, **kwargs): feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( - ResNetV2, variant, pretrained, default_cfg=default_cfgs[variant], pretrained_custom_load=True, - feature_cfg=feature_cfg, **kwargs) + ResNetV2, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=feature_cfg, + pretrained_custom_load=True, + **kwargs) @register_model diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index c21442ba..859b584e 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -199,7 +199,10 @@ class ReXNetV1(nn.Module): def _create_rexnet(variant, pretrained, **kwargs): feature_cfg = dict(flatten_sequential=True) return build_model_with_cfg( - ReXNetV1, variant, pretrained, default_cfg=default_cfgs[variant], feature_cfg=feature_cfg, **kwargs) + ReXNetV1, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=feature_cfg, + **kwargs) @register_model diff --git a/timm/models/selecsls.py b/timm/models/selecsls.py index 73bc7732..1f3379db 100644 --- a/timm/models/selecsls.py +++ b/timm/models/selecsls.py @@ -196,7 +196,7 @@ class SelecSLS(nn.Module): return x -def _create_selecsls(variant, pretrained, model_kwargs): +def _create_selecsls(variant, pretrained, **kwargs): cfg = {} feature_info = [dict(num_chs=32, reduction=2, module='stem.2')] if variant.startswith('selecsls42'): @@ -320,40 +320,43 @@ def _create_selecsls(variant, pretrained, model_kwargs): # this model can do 6 feature levels by default, unlike most others, leave as 0-4 to avoid surprises? return build_model_with_cfg( - SelecSLS, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=cfg, - feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), **model_kwargs) + SelecSLS, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=cfg, + feature_cfg=dict(out_indices=(0, 1, 2, 3, 4), flatten_sequential=True), + **kwargs) @register_model def selecsls42(pretrained=False, **kwargs): """Constructs a SelecSLS42 model. """ - return _create_selecsls('selecsls42', pretrained, kwargs) + return _create_selecsls('selecsls42', pretrained, **kwargs) @register_model def selecsls42b(pretrained=False, **kwargs): """Constructs a SelecSLS42_B model. """ - return _create_selecsls('selecsls42b', pretrained, kwargs) + return _create_selecsls('selecsls42b', pretrained, **kwargs) @register_model def selecsls60(pretrained=False, **kwargs): """Constructs a SelecSLS60 model. """ - return _create_selecsls('selecsls60', pretrained, kwargs) + return _create_selecsls('selecsls60', pretrained, **kwargs) @register_model def selecsls60b(pretrained=False, **kwargs): """Constructs a SelecSLS60_B model. """ - return _create_selecsls('selecsls60b', pretrained, kwargs) + return _create_selecsls('selecsls60b', pretrained, **kwargs) @register_model def selecsls84(pretrained=False, **kwargs): """Constructs a SelecSLS84 model. """ - return _create_selecsls('selecsls84', pretrained, kwargs) + return _create_selecsls('selecsls84', pretrained, **kwargs) diff --git a/timm/models/senet.py b/timm/models/senet.py index 8073229a..8227a453 100644 --- a/timm/models/senet.py +++ b/timm/models/senet.py @@ -398,7 +398,9 @@ class SENet(nn.Module): def _create_senet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - SENet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + SENet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 6c654922..bd9dd393 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -141,7 +141,9 @@ class SelectiveKernelBottleneck(nn.Module): def _create_skresnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - ResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, **kwargs) + ResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) @register_model diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index e371292f..a8c237ed 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -253,8 +253,10 @@ class TResNet(nn.Module): def _create_tresnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - TResNet, variant, default_cfg=default_cfgs[variant], pretrained=pretrained, - feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), **kwargs) + TResNet, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(out_indices=(1, 2, 3, 4), flatten_sequential=True), + **kwargs) @register_model diff --git a/timm/models/vgg.py b/timm/models/vgg.py index ceede650..8bea03e7 100644 --- a/timm/models/vgg.py +++ b/timm/models/vgg.py @@ -180,9 +180,9 @@ def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG: # NOTE: VGG is one of the only models with stride==1 features, so indices are offset from other models out_indices = kwargs.get('out_indices', (0, 1, 2, 3, 4, 5)) model = build_model_with_cfg( - VGG, variant, pretrained=pretrained, - model_cfg=cfgs[cfg], + VGG, variant, pretrained, default_cfg=default_cfgs[variant], + model_cfg=cfgs[cfg], feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), pretrained_filter_fn=_filter_fn, **kwargs) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index acd4d18d..82d4ee49 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -21,13 +21,14 @@ import math import logging from functools import partial from collections import OrderedDict +from copy import deepcopy import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import load_pretrained +from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_ from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2 @@ -94,7 +95,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_huge_patch14_224_in21k': _cfg( - url='', # FIXME I have weights for this but > 2GB limit for github release binaries + hf_hub='timm/vit_huge_patch14_224_in21k', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # hybrid models (weights ported from official Google JAX impl) @@ -462,9 +463,10 @@ def checkpoint_filter_fn(state_dict, model): def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): - default_cfg = default_cfgs[variant] + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-1] + default_img_size = default_cfg['input_size'][-2:] num_classes = kwargs.pop('num_classes', default_num_classes) img_size = kwargs.pop('img_size', default_img_size) @@ -475,14 +477,19 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa _logger.warning("Removing representation layer for fine-tuning.") repr_size = None + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model_cls = DistilledVisionTransformer if distilled else VisionTransformer - model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) - model.default_cfg = default_cfg + model = build_model_with_cfg( + model_cls, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) - if pretrained: - load_pretrained( - model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3), - filter_fn=partial(checkpoint_filter_fn, model=model)) return model diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index f544433c..ec5b3e81 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -338,8 +338,11 @@ class VovNet(nn.Module): def _create_vovnet(variant, pretrained=False, **kwargs): return build_model_with_cfg( - VovNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant], - feature_cfg=dict(flatten_sequential=True), **kwargs) + VovNet, variant, pretrained, + default_cfg=default_cfgs[variant], + model_cfg=model_cfgs[variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) @register_model diff --git a/timm/models/xception.py b/timm/models/xception.py index a61548dc..86f558cb 100644 --- a/timm/models/xception.py +++ b/timm/models/xception.py @@ -221,8 +221,10 @@ class Xception(nn.Module): def _xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( - Xception, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(feature_cls='hook'), **kwargs) + Xception, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(feature_cls='hook'), + **kwargs) @register_model diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index dd7a7a86..ea7f5c05 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -173,8 +173,10 @@ class XceptionAligned(nn.Module): def _xception(variant, pretrained=False, **kwargs): return build_model_with_cfg( - XceptionAligned, variant, pretrained, default_cfg=default_cfgs[variant], - feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), **kwargs) + XceptionAligned, variant, pretrained, + default_cfg=default_cfgs[variant], + feature_cfg=dict(flatten_sequential=True, feature_cls='hook'), + **kwargs) @register_model diff --git a/train.py b/train.py index 9abcfed3..f3da4a36 100755 --- a/train.py +++ b/train.py @@ -29,7 +29,8 @@ import torchvision.utils from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model, model_parameters +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ + convert_splitbn_model, model_parameters from timm.utils import * from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy from timm.optim import create_optimizer @@ -345,8 +346,8 @@ def main(): args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly if args.local_rank == 0: - _logger.info('Model %s created, param count: %d' % - (args.model, sum([m.numel() for m in model.parameters()]))) + _logger.info( + f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}') data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) @@ -543,7 +544,7 @@ def main(): output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, + safe_model_name(args.model), str(data_config['input_size'][-1]) ]) output_dir = get_outdir(output_base, 'train', exp_name) diff --git a/validate.py b/validate.py index a311112d..3f201314 100755 --- a/validate.py +++ b/validate.py @@ -211,7 +211,7 @@ def validate(args): model.eval() with torch.no_grad(): # warmup, reduce variability of first batch time, especially for comparing torchscript vs non - input = torch.randn((args.batch_size,) + data_config['input_size']).cuda() + input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda() if args.channels_last: input = input.contiguous(memory_format=torch.channels_last) model(input)