pull/9/merge
Dominic-ZZ 4 years ago committed by GitHub
commit 46787b117a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,10 +30,10 @@ class InpaintingData(Dataset):
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
transforms.ToTensor()]) transforms.ToTensor()])
self.mask_trans = transforms.Compose([ self.mask_trans = transforms.Compose([
transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.Resize(args.image_size),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.RandomRotation( transforms.RandomRotation(
(0, 45), interpolation=transforms.InterpolationMode.NEAREST), (0, 45)),
]) ])

@ -33,9 +33,9 @@ def main_worker(args, use_gpu=True):
# prepare dataset # prepare dataset
image_paths = [] image_paths = []
for ext in ['.jpg', '.png']: for ext in ['.jpg', '.png']:
image_paths.extend(glob(os.path.join(args.dir_image, '*'+ext))) image_paths.extend(glob(os.path.join(args.dir_test, '*'+ext)))
image_paths.sort() image_paths.sort()
mask_paths = sorted(glob(os.path.join(args.dir_mask, '*.png'))) mask_paths = sorted(glob(os.path.join(args.dir_mask,args.mask_type,'*.png')))
os.makedirs(args.outputs, exist_ok=True) os.makedirs(args.outputs, exist_ok=True)
# iteration through datasets # iteration through datasets

@ -9,8 +9,8 @@ parser.add_argument('--dir_mask', type=str, default='../../dataset',
help='mask dataset directory') help='mask dataset directory')
parser.add_argument('--data_train', type=str, default='places2', parser.add_argument('--data_train', type=str, default='places2',
help='dataname used for training') help='dataname used for training')
parser.add_argument('--data_test', type=str, default='places2', parser.add_argument('--dir_test', type=str, default='../datasets/test_imgs/',
help='dataname used for testing') help='test image dataset directory')
parser.add_argument('--image_size', type=int, default=512, parser.add_argument('--image_size', type=int, default=512,
help='image size used during training') help='image size used during training')
parser.add_argument('--mask_type', type=str, default='pconv', parser.add_argument('--mask_type', type=str, default='pconv',

Loading…
Cancel
Save