Fix pretrained override logic for validate, checkpoint always trump pretrained flag during model create

pull/13/head
Ross Wightman 6 years ago
parent 0e1fd11ad8
commit b9f8d40b10

@ -36,7 +36,7 @@ def create_model(
else: else:
raise RuntimeError('Unknown model (%s)' % model_name) raise RuntimeError('Unknown model (%s)' % model_name)
if checkpoint_path and not pretrained: if checkpoint_path:
load_checkpoint(model, checkpoint_path) load_checkpoint(model, checkpoint_path)
return model return model

@ -54,6 +54,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
def validate(args): def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
# create model # create model
model = create_model( model = create_model(
@ -62,10 +64,8 @@ def validate(args):
in_chans=3, in_chans=3,
pretrained=args.pretrained) pretrained=args.pretrained)
if args.checkpoint and not args.pretrained: if args.checkpoint:
load_checkpoint(model, args.checkpoint, args.use_ema) load_checkpoint(model, args.checkpoint, args.use_ema)
else:
args.pretrained = True # might as well try to validate something...
param_count = sum([m.numel() for m in model.parameters()]) param_count = sum([m.numel() for m in model.parameters()])
print('Model %s created, param count: %d' % (args.model, param_count)) print('Model %s created, param count: %d' % (args.model, param_count))

Loading…
Cancel
Save