From b4e216e377cd748cafc3fd24722bbece9779d1af Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 9 Feb 2021 17:33:43 -0800 Subject: [PATCH] Fix a few small things. --- timm/models/helpers.py | 2 +- timm/models/inception_v3.py | 4 ++-- validate.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d56cdc57..33744eb5 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -185,7 +185,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non state_dict = filter_fn(state_dict) input_convs = cfg.get('first_conv', None) - if input_convs is not None: + if input_convs is not None and in_chans != 3: if isinstance(input_convs, str): input_convs = (input_convs,) for input_conv_name in input_convs: diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index 9ae7105f..cdb1f1c0 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -32,12 +32,12 @@ default_cfgs = { # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) 'tf_inception_v3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', - num_classes=1001, has_aux=False), + num_classes=1000, has_aux=False, label_offset=1), # my port of Tensorflow adversarially trained Inception V3 from # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz 'adv_inception_v3': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', - num_classes=1001, has_aux=False), + num_classes=1000, has_aux=False, label_offset=1), # from gluon pretrained models, best performing in terms of accuracy/loss metrics # https://gluon-cv.mxnet.io/model_zoo/classification.html 'gluon_inception_v3': _cfg( diff --git a/validate.py b/validate.py index 8ad9cb1f..83f66fa5 100755 --- a/validate.py +++ b/validate.py @@ -284,7 +284,7 @@ def main(): if args.model == 'all': # validate all models in a list of names with pretrained checkpoints args.pretrained = True - model_names = list_models(pretrained=True) + model_names = list_models(pretrained=True, exclude_filters=['*in21k']) model_cfgs = [(n, '') for n in model_names] elif not is_model(args.model): # model name doesn't exist, try as wildcard filter