You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

113 lines
3.6 KiB

import cv2
import os
import importlib
import numpy as np
from glob import glob
import torch
from torchvision.transforms import ToTensor
from utils.option import args
from utils.painter import Sketcher
def postprocess(image):
    """Turn a tensor with values around [-1, 1] into a uint8 NumPy array.

    Values are clamped to [-1, 1] and rescaled linearly to [0, 255]. The
    first axis is then moved to the last position (``permute(1, 2, 0)``),
    and the data is copied to the CPU and cast to ``np.uint8``. The cast
    truncates the fractional part, so 127.5 becomes 127.
    """
    clamped = image.clamp(-1., 1.)
    # Keep the exact arithmetic order so float results are unchanged.
    scaled = (clamped + 1) / 2.0 * 255.0
    reordered = scaled.permute(1, 2, 0)
    return reordered.cpu().numpy().astype(np.uint8)
def demo(args):
    """Run the interactive inpainting demo on every image in ``args.dir_image``.

    All ``*.jpg`` and ``*.png`` files in ``args.dir_image`` are processed in
    sorted order. Each one is resized to 512x512 and handed to a ``Sketcher``
    window together with an empty mask, after which the function waits for
    key presses:

    * ``Esc``   - quit the demo entirely
    * ``Space`` - inpaint the image using the current mask and show the result
    * ``r``     - restore the original image and clear the mask
    * ``n``     - move on to the next image
    * ``k``     - keep the last composite as the new working image
    * ``+``/``-`` - call ``sketch.large_thick()`` / ``sketch.small_thick()``
      (presumably adjust brush thickness - see ``utils.painter``)
    * ``s``     - write masked / pred / comp / mask PNGs to ``args.outputs``

    Args:
        args: options object; this function reads ``dir_image``, ``model``,
            ``pre_train``, ``thick``, ``painter`` and ``outputs`` from it and
            passes the whole object to ``InpaintGenerator``.
    """
    # load images
    img_list = []
    for ext in ['*.jpg', '*.png']:
        img_list.extend(glob(os.path.join(args.dir_image, ext)))
    img_list.sort()

    # Model and version
    net = importlib.import_module('model.' + args.model)
    model = net.InpaintGenerator(args)
    model.load_state_dict(torch.load(args.pre_train, map_location='cpu'))
    model.eval()

    for fn in img_list:
        # splitext drops only the final extension, so "a.1.jpg" and "a.2.jpg"
        # no longer collapse to the same output name.
        filename = os.path.splitext(os.path.basename(fn))[0]

        # cv2.imread signals failure by returning None instead of raising.
        raw_img = cv2.imread(fn, cv2.IMREAD_COLOR)
        if raw_img is None:
            print(f'[**] cannot read image {fn}, skipping')
            continue
        orig_img = cv2.resize(raw_img, (512, 512))

        # ToTensor scales uint8 pixels to [0, 1]; "* 2 - 1" maps them to [-1, 1].
        # NOTE(review): OpenCV loads BGR and the channels are passed on as-is;
        # confirm the checkpoint expects this ordering.
        img_tensor = (ToTensor()(orig_img) * 2.0 - 1.0).unsqueeze(0)
        h, w, c = orig_img.shape
        mask = np.zeros([h, w, 1], np.uint8)
        image_copy = orig_img.copy()
        sketch = Sketcher(
            'input', [image_copy, mask],
            lambda: ((255, 255, 255), (255, 255, 255)),
            args.thick, args.painter)

        # Results of the most recent inpainting for THIS image. Resetting them
        # here prevents 's'/'k' from crashing before the first inpaint and from
        # reusing the previous image's results.
        masked_np = pred_np = comp_np = comp_tensor = None

        while True:
            ch = cv2.waitKey()
            if ch == 27:
                print("quit!")
                break

            # inpaint by deep model
            elif ch == ord(' '):
                print('[**] inpainting ... ')
                with torch.no_grad():
                    mask_tensor = (ToTensor()(mask)).unsqueeze(0)
                    # Fill the masked area with 1.0 before feeding the model.
                    masked_tensor = (img_tensor * (1 - mask_tensor).float()) + mask_tensor
                    pred_tensor = model(masked_tensor, mask_tensor)
                    # Keep original pixels outside the mask, prediction inside.
                    comp_tensor = (pred_tensor * mask_tensor + img_tensor * (1 - mask_tensor))

                    pred_np = postprocess(pred_tensor[0])
                    masked_np = postprocess(masked_tensor[0])
                    comp_np = postprocess(comp_tensor[0])

                    cv2.imshow('pred_images', comp_np)
                    print('inpainting finish!')

            # reset mask
            elif ch == ord('r'):
                img_tensor = (ToTensor()(orig_img) * 2.0 - 1.0).unsqueeze(0)
                # In-place writes: the Sketcher was given these same arrays.
                image_copy[:] = orig_img.copy()
                mask[:] = 0
                sketch.show()
                print("[**] reset!")

            # next case
            elif ch == ord('n'):
                print('[**] move to next image')
                cv2.destroyAllWindows()
                break

            elif ch == ord('k'):
                if comp_tensor is None:
                    print('[**] nothing to keep yet, press space to inpaint first')
                else:
                    print('[**] apply existing processing to images, and keep editing!')
                    img_tensor = comp_tensor
                    image_copy[:] = comp_np.copy()
                    mask[:] = 0
                    sketch.show()
                    print("reset!")

            elif ch == ord('+'):
                sketch.large_thick()

            elif ch == ord('-'):
                sketch.small_thick()

            # save results
            if ch == ord('s'):
                if comp_np is None:
                    print('[**] nothing to save yet, press space to inpaint first')
                    continue
                # cv2.imwrite returns False (no exception) when the target
                # directory is missing, so create it up front.
                if args.outputs:
                    os.makedirs(args.outputs, exist_ok=True)
                results = {
                    'masked': masked_np,
                    'pred': pred_np,
                    'comp': comp_np,
                    'mask': mask,
                }
                written = [
                    cv2.imwrite(os.path.join(args.outputs, f'{filename}_{tag}.png'), array)
                    for tag, array in results.items()
                ]
                if all(written):
                    print('[**] save successfully!')
                else:
                    print('[**] failed to save some results!')

        cv2.destroyAllWindows()

        # Esc leaves the per-image loop above; propagate it to stop the demo.
        if ch == 27:
            break
# Script entry point: start the interactive demo with the options object
# imported from utils.option.
if __name__ == "__main__":
    demo(args)