diff --git a/src/test.py b/src/test.py index 872493b..bcca0e3 100644 --- a/src/test.py +++ b/src/test.py @@ -35,7 +35,7 @@ def main_worker(args, use_gpu=True): for ext in ['.jpg', '.png']: image_paths.extend(glob(os.path.join(args.dir_image, '*'+ext))) image_paths.sort() - mask_paths = sorted(glob(os.path.join(args.dir_mask, '*.png'))) + mask_paths = sorted(glob(os.path.join(args.dir_mask,args.mask_type,'*.png'))) os.makedirs(args.outputs, exist_ok=True) # iteration through datasets