diff --git a/models/model_factory.py b/models/model_factory.py index 611e1728..09a8bb95 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -36,7 +36,7 @@ def create_model( else: raise RuntimeError('Unknown model (%s)' % model_name) - if checkpoint_path and not pretrained: + if checkpoint_path: load_checkpoint(model, checkpoint_path) return model diff --git a/validate.py b/validate.py index e0d0a538..5a09f0cd 100644 --- a/validate.py +++ b/validate.py @@ -54,6 +54,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true', def validate(args): + # might as well try to validate something + args.pretrained = args.pretrained or not args.checkpoint # create model model = create_model( @@ -62,10 +64,8 @@ def validate(args): in_chans=3, pretrained=args.pretrained) - if args.checkpoint and not args.pretrained: + if args.checkpoint: load_checkpoint(model, args.checkpoint, args.use_ema) - else: - args.pretrained = True # might as well try to validate something... param_count = sum([m.numel() for m in model.parameters()]) print('Model %s created, param count: %d' % (args.model, param_count))