You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
80 lines
2.5 KiB
80 lines
2.5 KiB
3 years ago
|
import os
|
||
|
import math
|
||
|
import numpy as np
|
||
|
from glob import glob
|
||
|
|
||
|
from random import shuffle
|
||
|
from PIL import Image, ImageFilter
|
||
|
|
||
|
import torch
|
||
|
import torchvision.transforms.functional as F
|
||
|
import torchvision.transforms as transforms
|
||
|
from torch.utils.data import Dataset, DataLoader
|
||
|
|
||
|
class InpaintingData(Dataset):
|
||
|
def __init__(self, args):
|
||
|
super(Dataset, self).__init__()
|
||
|
self.w = self.h = args.image_size
|
||
|
self.mask_type = args.mask_type
|
||
|
|
||
|
# image and mask
|
||
|
self.image_path = []
|
||
|
for ext in ['*.jpg', '*.png']:
|
||
|
self.image_path.extend(glob(os.path.join(args.dir_image, args.data_train, ext)))
|
||
|
self.mask_path = glob(os.path.join(args.dir_mask, args.mask_type, '*.png'))
|
||
|
|
||
|
# augmentation
|
||
|
self.img_trans = transforms.Compose([
|
||
|
transforms.RandomResizedCrop(args.image_size),
|
||
|
transforms.RandomHorizontalFlip(),
|
||
|
transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
|
||
|
transforms.ToTensor()])
|
||
|
self.mask_trans = transforms.Compose([
|
||
|
transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST),
|
||
|
transforms.RandomHorizontalFlip(),
|
||
|
transforms.RandomRotation(
|
||
|
(0, 45), interpolation=transforms.InterpolationMode.NEAREST),
|
||
|
])
|
||
|
|
||
|
|
||
|
def __len__(self):
|
||
|
return len(self.image_path)
|
||
|
|
||
|
def __getitem__(self, index):
|
||
|
# load image
|
||
|
image = Image.open(self.image_path[index]).convert('RGB')
|
||
|
filename = os.path.basename(self.image_path[index])
|
||
|
|
||
|
if self.mask_type == 'pconv':
|
||
|
index = np.random.randint(0, len(self.mask_path))
|
||
|
mask = Image.open(self.mask_path[index])
|
||
|
mask = mask.convert('L')
|
||
|
else:
|
||
|
mask = np.zeros((self.h, self.w)).astype(np.uint8)
|
||
|
mask[self.h//4:self.h//4*3, self.w//4:self.w//4*3] = 1
|
||
|
mask = Image.fromarray(m).convert('L')
|
||
|
|
||
|
# augment
|
||
|
image = self.img_trans(image) * 2. - 1.
|
||
|
mask = F.to_tensor(self.mask_trans(mask))
|
||
|
|
||
|
return image, mask, filename
|
||
|
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
|
||
|
|
||
|
from attrdict import AttrDict
|
||
|
args = {
|
||
|
'dir_image': '../../../dataset',
|
||
|
'data_train': 'places2',
|
||
|
'dir_mask': '../../../dataset',
|
||
|
'mask_type': 'pconv',
|
||
|
'image_size': 512
|
||
|
}
|
||
|
args = AttrDict(args)
|
||
|
|
||
|
data = InpaintingData(args)
|
||
|
print(len(data), len(data.mask_path))
|
||
|
img, mask, filename = data[0]
|
||
|
print(img.size(), mask.size(), filename)
|