diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 96f551e3..562a01c5 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non

     classifier_name = cfg['classifier']
     if num_classes == 1000 and cfg['num_classes'] == 1001:
+        # FIXME this special case is problematic as the number of pretrained weight sources increases
         # special case for imagenet trained models with extra background class in pretrained weights
         classifier_weight = state_dict[classifier_name + '.weight']
         state_dict[classifier_name + '.weight'] = classifier_weight[1:]
diff --git a/train.py b/train.py
index 94c417b4..aa8e6553 100755
--- a/train.py
+++ b/train.py
@@ -337,6 +337,9 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

     if args.local_rank == 0:
         _logger.info('Model %s created, param count: %d' %
diff --git a/validate.py b/validate.py
index be977cc2..5b5f98cf 100755
--- a/validate.py
+++ b/validate.py
@@ -137,6 +137,9 @@ def validate(args):
         in_chans=3,
         global_pool=args.gp,
         scriptable=args.torchscript)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes

     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)
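
Taken together, the train.py and validate.py hunks fall back to the model's own `num_classes` attribute when none is supplied on the command line, and the helpers.py hunk handles pretrained weights carrying an extra background class (1001 outputs instead of 1000). Below is a minimal, self-contained sketch of both behaviours; `resolve_num_classes` and `strip_background_class` are hypothetical helper names used for illustration only, not functions in timm, and the bias slicing mirrors what the surrounding (unshown) helpers.py context does for the classifier bias:

import torch
import torch.nn as nn

def resolve_num_classes(arg_num_classes, model):
    # Sketch of the fallback added in train.py/validate.py: if --num-classes
    # was not given, take the value the created model reports.
    if arg_num_classes is None:
        assert hasattr(model, 'num_classes'), \
            'Model must have `num_classes` attr if not set on cmd line/config.'
        return model.num_classes
    return arg_num_classes

def strip_background_class(state_dict, classifier_name):
    # Sketch of the helpers.py special case: 1001-class pretrained weights
    # (background class at index 0) are sliced down to 1000 ImageNet classes.
    for suffix in ('.weight', '.bias'):
        key = classifier_name + suffix
        state_dict[key] = state_dict[key][1:]  # drop the background row/entry
    return state_dict

if __name__ == '__main__':
    model = nn.Linear(8, 1000)   # toy stand-in for a timm model
    model.num_classes = 1000     # timm models expose this attribute
    print(resolve_num_classes(None, model))  # -> 1000

    sd = {'fc.weight': torch.randn(1001, 8), 'fc.bias': torch.randn(1001)}
    sd = strip_background_class(sd, 'fc')
    print(sd['fc.weight'].shape, sd['fc.bias'].shape)  # -> [1000, 8] [1000]

Note that the fallback writes the resolved value back into `args.num_classes`, so everything downstream (loss setup, evaluation, checkpoint resume) sees a single consistent class count regardless of whether it came from the CLI or the model default.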