You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

63 lines
2.0 KiB

import os
import argparse
import importlib
import numpy as np
from PIL import Image
from glob import glob
import torch
import torch.nn as nn
from torchvision.transforms import ToTensor
from utils.option import args
def postprocess(image):
image = torch.clamp(image, -1., 1.)
image = (image + 1) / 2.0 * 255.0
image = image.permute(1, 2, 0)
image = image.cpu().numpy().astype(np.uint8)
return Image.fromarray(image)
def main_worker(args, use_gpu=True):
device = torch.device('cuda') if use_gpu else torch.device('cpu')
# Model and version
net = importlib.import_module('model.'+args.model)
model = net.InpaintGenerator(args).cuda()
model.load_state_dict(torch.load(args.pre_train, map_location='cuda'))
model.eval()
# prepare dataset
image_paths = []
for ext in ['.jpg', '.png']:
image_paths.extend(glob(os.path.join(args.dir_image, '*'+ext)))
image_paths.sort()
mask_paths = sorted(glob(os.path.join(args.dir_mask, '*.png')))
os.makedirs(args.outputs, exist_ok=True)
# iteration through datasets
for ipath, mpath in zip(image_paths, mask_paths):
image = ToTensor()(Image.open(ipath).convert('RGB'))
image = (image * 2.0 - 1.0).unsqueeze(0)
mask = ToTensor()(Image.open(mpath).convert('L'))
mask = mask.unsqueeze(0)
image, mask = image.cuda(), mask.cuda()
image_masked = image * (1 - mask.float()) + mask
with torch.no_grad():
pred_img = model(image_masked, mask)
comp_imgs = (1 - mask) * image + mask * pred_img
image_name = os.path.basename(ipath).split('.')[0]
postprocess(image_masked[0]).save(os.path.join(args.outputs, f'{image_name}_masked.png'))
postprocess(pred_img[0]).save(os.path.join(args.outputs, f'{image_name}_pred.png'))
postprocess(comp_imgs[0]).save(os.path.join(args.outputs, f'{image_name}_comp.png'))
print(f'saving to {os.path.join(args.outputs, image_name)}')
if __name__ == '__main__':
main_worker(args)