You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

48 lines
1.5 KiB

3 years ago
import argparse
import numpy as np
from tqdm import tqdm
from glob import glob
from PIL import Image
from multiprocessing import Pool
from metric import metric as module_metric
parser = argparse.ArgumentParser(description='Image Inpainting')
parser.add_argument('--real_dir', required=True, type=str)
parser.add_argument('--fake_dir', required=True, type=str)
parser.add_argument("--metric", type=str, nargs="+")
args = parser.parse_args()
def read_img(name_pair):
rname, fname = name_pair
rimg = Image.open(rname)
fimg = Image.open(fname)
return np.array(rimg), np.array(fimg)
def main(num_worker=8):
real_names = sorted(list(glob(f'{args.real_dir}/*.png')))
fake_names = sorted(list(glob(f'{args.fake_dir}/*.png')))
print(f'real images: {len(real_names)}, fake images: {len(fake_names)}')
real_images = []
fake_images = []
pool = Pool(num_worker)
for rimg, fimg in tqdm(pool.imap_unordered(read_img, zip(real_names, fake_names)), total=len(real_names), desc='loading images'):
real_images.append(rimg)
fake_images.append(fimg)
# metrics prepare for image assesments
metrics = {met: getattr(module_metric, met) for met in args.metric}
evaluation_scores = {key: 0 for key,val in metrics.items()}
for key, val in metrics.items():
evaluation_scores[key] = val(real_images, fake_images, num_worker=num_worker)
print(' '.join(['{}: {:6f},'.format(key, val) for key,val in evaluation_scores.items()]))
if __name__ == '__main__':
main()