Fix a few small things.

pull/419/head
Ross Wightman 4 years ago
parent dc85e5a237
commit b4e216e377

@ -185,7 +185,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
state_dict = filter_fn(state_dict) state_dict = filter_fn(state_dict)
input_convs = cfg.get('first_conv', None) input_convs = cfg.get('first_conv', None)
if input_convs is not None: if input_convs is not None and in_chans != 3:
if isinstance(input_convs, str): if isinstance(input_convs, str):
input_convs = (input_convs,) input_convs = (input_convs,)
for input_conv_name in input_convs: for input_conv_name in input_convs:

@ -32,12 +32,12 @@ default_cfgs = {
# my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz) # my port of Tensorflow SLIM weights (http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)
'tf_inception_v3': _cfg( 'tf_inception_v3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_inception_v3-e0069de4.pth',
num_classes=1001, has_aux=False), num_classes=1000, has_aux=False, label_offset=1),
# my port of Tensorflow adversarially trained Inception V3 from # my port of Tensorflow adversarially trained Inception V3 from
# http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz # http://download.tensorflow.org/models/adv_inception_v3_2017_08_18.tar.gz
'adv_inception_v3': _cfg( 'adv_inception_v3': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/adv_inception_v3-9e27bd63.pth',
num_classes=1001, has_aux=False), num_classes=1000, has_aux=False, label_offset=1),
# from gluon pretrained models, best performing in terms of accuracy/loss metrics # from gluon pretrained models, best performing in terms of accuracy/loss metrics
# https://gluon-cv.mxnet.io/model_zoo/classification.html # https://gluon-cv.mxnet.io/model_zoo/classification.html
'gluon_inception_v3': _cfg( 'gluon_inception_v3': _cfg(

@ -284,7 +284,7 @@ def main():
if args.model == 'all': if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints # validate all models in a list of names with pretrained checkpoints
args.pretrained = True args.pretrained = True
model_names = list_models(pretrained=True) model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
model_cfgs = [(n, '') for n in model_names] model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model): elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter # model name doesn't exist, try as wildcard filter

Loading…
Cancel
Save