diff --git a/train.py b/train.py index 44a0e292..42985e12 100755 --- a/train.py +++ b/train.py @@ -381,7 +381,7 @@ def main(): bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, - ml_decoder_head=args.use_ml_decoder_head) + use_ml_decoder_head=args.use_ml_decoder_head) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly