diff --git a/train.py b/train.py index 929948d8..3943c7d0 100755 --- a/train.py +++ b/train.py @@ -539,7 +539,7 @@ def main(): loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], - batch_size=args.validation_batch_size, + batch_size=args.validation_batch_size or args.batch_size, is_training=False, use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'],