pull/9/merge
Dominic-ZZ 3 years ago committed by GitHub
commit 46787b117a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,10 +30,10 @@ class InpaintingData(Dataset):
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05), transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
transforms.ToTensor()]) transforms.ToTensor()])
self.mask_trans = transforms.Compose([ self.mask_trans = transforms.Compose([
transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST), transforms.Resize(args.image_size),
transforms.RandomHorizontalFlip(), transforms.RandomHorizontalFlip(),
transforms.RandomRotation( transforms.RandomRotation(
(0, 45), interpolation=transforms.InterpolationMode.NEAREST), (0, 45)),
]) ])
@ -77,4 +77,4 @@ if __name__ == '__main__':
data = InpaintingData(args) data = InpaintingData(args)
print(len(data), len(data.mask_path)) print(len(data), len(data.mask_path))
img, mask, filename = data[0] img, mask, filename = data[0]
print(img.size(), mask.size(), filename) print(img.size(), mask.size(), filename)

@ -33,9 +33,9 @@ def main_worker(args, use_gpu=True):
# prepare dataset # prepare dataset
image_paths = [] image_paths = []
for ext in ['.jpg', '.png']: for ext in ['.jpg', '.png']:
image_paths.extend(glob(os.path.join(args.dir_image, '*'+ext))) image_paths.extend(glob(os.path.join(args.dir_test, '*'+ext)))
image_paths.sort() image_paths.sort()
mask_paths = sorted(glob(os.path.join(args.dir_mask, '*.png'))) mask_paths = sorted(glob(os.path.join(args.dir_mask,args.mask_type,'*.png')))
os.makedirs(args.outputs, exist_ok=True) os.makedirs(args.outputs, exist_ok=True)
# iteration through datasets # iteration through datasets

@ -9,8 +9,8 @@ parser.add_argument('--dir_mask', type=str, default='../../dataset',
help='mask dataset directory') help='mask dataset directory')
parser.add_argument('--data_train', type=str, default='places2', parser.add_argument('--data_train', type=str, default='places2',
help='dataname used for training') help='dataname used for training')
parser.add_argument('--data_test', type=str, default='places2', parser.add_argument('--dir_test', type=str, default='../datasets/test_imgs/',
help='dataname used for testing') help='test image dataset directory')
parser.add_argument('--image_size', type=int, default=512, parser.add_argument('--image_size', type=int, default=512,
help='image size used during training') help='image size used during training')
parser.add_argument('--mask_type', type=str, default='pconv', parser.add_argument('--mask_type', type=str, default='pconv',
@ -93,4 +93,4 @@ losses = list(args.rec_loss.split('+'))
args.rec_loss = {} args.rec_loss = {}
for l in losses: for l in losses:
weight, name = l.split('*') weight, name = l.split('*')
args.rec_loss[name] = float(weight) args.rec_loss[name] = float(weight)

Loading…
Cancel
Save