Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg)

pull/323/head
Ross Wightman 4 years ago
parent 587780e56b
commit 38d8f67570

@@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None
     classifier_name = cfg['classifier']
     if num_classes == 1000 and cfg['num_classes'] == 1001:
+        # FIXME this special case is problematic as number of pretrained weight sources increases
         # special case for imagenet trained models with extra background class in pretrained weights
         classifier_weight = state_dict[classifier_name + '.weight']
         state_dict[classifier_name + '.weight'] = classifier_weight[1:]
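The slice drops row 0 (the extra background class) so a 1001-way pretrained classifier fits a standard 1000-class head. A minimal standalone sketch of the same trim; the tensor shapes and the 'classifier' key here are illustrative, and the real helper typically trims the matching bias the same way:

import torch

# Fake checkpoint entries for a classifier trained with 1001 outputs,
# where index 0 is the extra background class.
state_dict = {
    'classifier.weight': torch.randn(1001, 2048),
    'classifier.bias': torch.randn(1001),
}

classifier_name = 'classifier'
# Drop row 0 so the shapes match a 1000-class ImageNet head.
state_dict[classifier_name + '.weight'] = state_dict[classifier_name + '.weight'][1:]
state_dict[classifier_name + '.bias'] = state_dict[classifier_name + '.bias'][1:]

assert state_dict['classifier.weight'].shape[0] == 1000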

@@ -337,6 +337,9 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
     if args.local_rank == 0:
         _logger.info('Model %s created, param count: %d' %
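train.py only applies the fallback when the flag was left unset. A minimal sketch of the pattern, assuming --num-classes now defaults to None rather than 1000; DummyModel is a hypothetical stand-in for the object returned by create_model:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num-classes', type=int, default=None,
                    help='number of label classes (default: None, use model default)')
args = parser.parse_args([])  # simulate no CLI override

class DummyModel:
    # hypothetical stand-in; timm models expose num_classes set from
    # their constructor args / pretrained default_cfg
    num_classes = 1000

model = DummyModel()
if args.num_classes is None:
    assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
    args.num_classes = model.num_classes  # rely on the model definition

print(args.num_classes)  # 1000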

@@ -137,6 +137,9 @@ def validate(args):
         in_chans=3,
         global_pool=args.gp,
         scriptable=args.torchscript)
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes
     if args.checkpoint:
         load_checkpoint(model, args.checkpoint, args.use_ema)
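validate.py mirrors the same guard, minus the FIXME note. Assuming timm is installed, the value the fallback would pick up can be inspected directly (resnet50 is an arbitrary example):

import timm

model = timm.create_model('resnet50', pretrained=False)
print(model.num_classes)                 # 1000 for standard ImageNet models
print(model.default_cfg['num_classes'])  # class count from the pretrained config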
