Improve resolve_pretrained_cfg behaviour when no cfg exists, warn instead of crash. Improve usability ex #1311

pull/1317/head
Ross Wightman 3 years ago
parent 879df47c0a
commit 7d657d2ef4

@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
filter_kwargs(kwargs, names=kwargs_filter) filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None): def resolve_pretrained_cfg(variant: str, **kwargs):
pretrained_cfg = kwargs.pop('pretrained_cfg', None)
if pretrained_cfg and isinstance(pretrained_cfg, dict): if pretrained_cfg and isinstance(pretrained_cfg, dict):
# highest priority, pretrained_cfg available and passed explicitly # highest priority, pretrained_cfg available and passed in args
return deepcopy(pretrained_cfg) return deepcopy(pretrained_cfg)
if kwargs and 'pretrained_cfg' in kwargs: # fallback to looking up pretrained cfg in model registry by variant identifier
# next highest, pretrained_cfg in a kwargs dict, pop and return
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
if pretrained_cfg:
return deepcopy(pretrained_cfg)
# lookup pretrained cfg in model registry by variant
pretrained_cfg = get_pretrained_cfg(variant) pretrained_cfg = get_pretrained_cfg(variant)
assert pretrained_cfg if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {variant} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = dict(
url='',
num_classes=1000,
input_size=(3, 224, 224),
pool_size=None,
crop_pct=.9,
interpolation='bicubic',
first_conv='',
classifier='',
)
return pretrained_cfg return pretrained_cfg

@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
def _create_inception_v3(variant, pretrained=False, **kwargs): def _create_inception_v3(variant, pretrained=False, **kwargs):
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
aux_logits = kwargs.pop('aux_logits', False) aux_logits = kwargs.pop('aux_logits', False)
if aux_logits: if aux_logits:
assert not kwargs.pop('features_only', False) assert not kwargs.pop('features_only', False)

@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
model = build_model_with_cfg( model = build_model_with_cfg(
VisionTransformer, variant, pretrained, VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg, pretrained_cfg=pretrained_cfg,

@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model from .registry import register_model

Loading…
Cancel
Save