From 7d657d2ef45fc841f3a987ca0f18868686dbecf5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 24 Jun 2022 14:55:25 -0700
Subject: [PATCH] Improve resolve_pretrained_cfg behaviour when no cfg exists,
 warn instead of crash. Improve usability ex #1311

---
 timm/models/helpers.py                   | 27 ++++++++++++++++++---------
 timm/models/inception_v3.py              |  2 +-
 timm/models/vision_transformer.py        |  2 +-
 timm/models/vision_transformer_relpos.py |  2 +-
 4 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 1276b68e..11630bb6 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     filter_kwargs(kwargs, names=kwargs_filter)
 
 
-def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
+def resolve_pretrained_cfg(variant: str, **kwargs):
+    pretrained_cfg = kwargs.pop('pretrained_cfg', None)
     if pretrained_cfg and isinstance(pretrained_cfg, dict):
-        # highest priority, pretrained_cfg available and passed explicitly
+        # highest priority, pretrained_cfg available and passed in args
         return deepcopy(pretrained_cfg)
-    if kwargs and 'pretrained_cfg' in kwargs:
-        # next highest, pretrained_cfg in a kwargs dict, pop and return
-        pretrained_cfg = kwargs.pop('pretrained_cfg', {})
-        if pretrained_cfg:
-            return deepcopy(pretrained_cfg)
-    # lookup pretrained cfg in model registry by variant
+    # fallback to looking up pretrained cfg in model registry by variant identifier
     pretrained_cfg = get_pretrained_cfg(variant)
-    assert pretrained_cfg
+    if not pretrained_cfg:
+        _logger.warning(
+            f"No pretrained configuration specified for {variant} model. Using a default."
+            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
+        pretrained_cfg = dict(
+            url='',
+            num_classes=1000,
+            input_size=(3, 224, 224),
+            pool_size=None,
+            crop_pct=.9,
+            interpolation='bicubic',
+            first_conv='',
+            classifier='',
+        )
     return pretrained_cfg
 
 
diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py
index e34de657..2c6e7eb7 100644
--- a/timm/models/inception_v3.py
+++ b/timm/models/inception_v3.py
@@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
 
 
 def _create_inception_v3(variant, pretrained=False, **kwargs):
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     aux_logits = kwargs.pop('aux_logits', False)
     if aux_logits:
         assert not kwargs.pop('features_only', False)
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 59fd7849..8551feae 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')
 
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
         pretrained_cfg=pretrained_cfg,
diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py
index 0c2ac376..0c9ac989 100644
--- a/timm/models/vision_transformer_relpos.py
+++ b/timm/models/vision_transformer_relpos.py
@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model
 
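
A minimal usage sketch of the revised call pattern (an illustration, not part of the patch; the variant names below are examples only, assuming a timm checkout that includes this change):

    from timm.models.helpers import resolve_pretrained_cfg

    # An explicitly passed pretrained_cfg has highest priority and is
    # returned as a deep copy.
    cfg = resolve_pretrained_cfg('vit_base_patch16_224', pretrained_cfg={'num_classes': 10})

    # Otherwise the cfg is looked up in the registry by variant name.
    cfg = resolve_pretrained_cfg('vit_base_patch16_224')

    # An unregistered variant previously failed the `assert pretrained_cfg`;
    # it now logs a warning and returns the generic default cfg defined above.
    cfg = resolve_pretrained_cfg('my_unregistered_model')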