parent 4180346273
commit 000b2223c4
@@ -0,0 +1,326 @@
import os
import io
import scipy
import torch
import random
import time
import lmdb
import math
import argparse
import pickle
import numpy as np
from PIL import Image, ImageDraw
from skimage.feature import canny
from xml.etree import ElementTree as ET

import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader

def random_bbox(img_width, img_height):
    """Generate a random tlhw regular mask
    """
    vertical_margin = horizontal_margin = 0
    mask_height = img_height//2
    mask_width = img_width//2
    max_delta_height = img_height // 8
    max_delta_width = img_width // 8
    maxt = img_height - vertical_margin - mask_height
    maxl = img_width - horizontal_margin - mask_width
    mask = np.zeros((img_height, img_width), np.uint8)

    t = np.random.randint(vertical_margin, maxt)
    l = np.random.randint(horizontal_margin, maxl)
    h = np.random.randint(max_delta_height//2+1)
    w = np.random.randint(max_delta_width//2+1)
    mask[t+h:t+mask_height-h,
         l+w:l+mask_width-w] = 1
    return mask

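# Illustrative sketch (not part of the original commit): a quick sanity check
# on random_bbox(). The hole is a half-width by half-height rectangle shrunk
# by the random deltas, so its area fraction is at most a quarter of the image.
# The helper name is hypothetical.
def _check_random_bbox(size=256, trials=100):
    fractions = [random_bbox(size, size).mean() for _ in range(trials)]
    return min(fractions), max(fractions)   # both should be <= 0.25
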
def center_bbox(img_width, img_height):
    """Generate a center square mask
    """
    mask = np.zeros((img_height, img_width), np.uint8)
    # mask is 2-D (H, W): index rows first, then columns.
    mask[img_height//4:img_height//4*3,
         img_width//4:img_width//4*3] = 1
    return mask

def random_stroke(img_width, img_height):
    min_num_vertex = 4
    max_num_vertex = 12
    mean_angle = 2*math.pi / 5
    angle_range = 2*math.pi / 15
    min_width = 12
    max_width = 40
    average_radius = math.sqrt(img_height*img_height+img_width*img_width) / 8
    mask = Image.new('L', (img_width, img_height), 0)

    steps = 6
    for _ in range(np.random.randint(1, steps)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        for i in range(num_vertex):
            if i % 2 == 0:
                angles.append(
                    2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        # PIL's Image.size is (width, height).
        w, h = mask.size
        vertex.append((int(np.random.randint(0, w)),
                       int(np.random.randint(0, h))))
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius,
                                 scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)

    # Image.transpose returns a new image; assign it back so the flip takes effect.
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
    if np.random.normal() > 0:
        mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    mask = np.asarray(mask, np.uint8)
    return mask

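# Illustrative sketch (not part of the original commit): dump a few generated
# stroke masks to PNG for visual inspection. The helper name and output
# filenames are hypothetical.
def _preview_stroke_masks(size=256, count=3):
    for i in range(count):
        m = random_stroke(size, size)              # uint8 array with values {0, 1}
        Image.fromarray(m * 255).save(f'stroke_mask_{i}.png')
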
def bbox2np(img_width, img_height, bbox, pad_rate=0):
    mask = np.zeros((img_height, img_width), np.uint8)
    for (xmin, ymin, xmax, ymax) in bbox:
        pad = int(pad_rate*min(xmax-xmin, ymax-ymin))
        mask[max(0, ymin-pad):min(ymax+pad, img_height),
             max(0, xmin-pad):min(xmax+pad, img_width)] = 1
    return Image.fromarray(mask*255)

def outside_xml(oriw, orih, bbox):
    if bbox is None:
        return random_bbox(oriw, orih)

    mask = np.zeros((orih, oriw))
    logo_mask = np.array(bbox2np(oriw, orih, bbox))//255
    random_mask = small_block(oriw, orih)

    if np.sum(logo_mask) < 0.6*oriw*orih:
        # keep holes outside the annotated (logo) regions; retry until the
        # hole covers at least 10% of the image or the retry budget runs out
        mask = ((mask + random_mask) > 0) * (1-logo_mask)
        iters = 5
        while np.sum(mask) < 0.1*oriw*orih and iters > 0:
            random_mask = small_block(oriw, orih)
            mask = ((mask+random_mask) > 0) * (1-logo_mask)
            iters -= 1
    else:
        mask = small_block(oriw, orih)
    return np.array(mask).astype(np.uint8)

def small_block(width, height, nums=3):
    mask = np.zeros((height, width)).astype(np.uint8)
    margin_width = width//8
    margin_height = height//8
    for i in range(nums):
        x = random.randint(margin_width, width - margin_width)
        y = random.randint(margin_height, height - margin_height)
        w = random.randint(margin_width, width//2)
        h = random.randint(margin_height, height//2)
        mask[y:min(y+h, height), x:min(x+w, width)] = 1
    return mask

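# Illustrative sketch (not part of the original commit): outside_xml() keeps
# random holes away from annotated boxes. The (xmin, ymin, xmax, ymax) box
# below is a made-up annotation for a 512x512 image; the helper name is
# hypothetical.
def _demo_outside_xml():
    fake_bbox = [(100, 100, 300, 200)]
    m = outside_xml(512, 512, fake_bbox)           # uint8 array with values {0, 1}
    assert m[100:200, 100:300].sum() == 0          # no hole inside the protected box
    return m
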
class LmdbReader(object):
    # Shared, lazily-initialised read-only LMDB environment reused across calls.
    lmdb_env = None

    def __init__(self):
        super(LmdbReader, self).__init__()

    @staticmethod
    def build_lmdb_env(lmdb_path):
        if LmdbReader.lmdb_env is None:
            LmdbReader.lmdb_env = lmdb.open(lmdb_path, max_readers=64, readonly=True,
                                            lock=False, readahead=False, meminit=False,)
        return LmdbReader.lmdb_env

    @staticmethod
    def read(path, key, val_type='img'):
        env = LmdbReader.build_lmdb_env(path)
        with env.begin(write=False) as txn:
            try:
                if val_type == 'int':
                    val = txn.get(key).decode('utf-8')
                    val = int(val)
                elif val_type == 'img':
                    val = io.BytesIO(txn.get(key))
                    val = Image.open(val)
                elif val_type == 'list':
                    val = io.BytesIO(txn.get(key))
                    val = pickle.load(val)
                    val = list(val)
            except Exception:
                val = None
            return val

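# Illustrative sketch (not part of the original commit): how an LMDB compatible
# with LmdbReader/Dataset might be written. The key layout ('<split>-total',
# '<split>-<7-digit index>-image', '<split>-<7-digit index>-label') is inferred
# from the reader code; the helper name and map_size are assumptions.
def _write_example_lmdb(path, pil_images, split='train'):
    env = lmdb.open(path, map_size=1 << 30)
    with env.begin(write=True) as txn:
        for i, im in enumerate(pil_images):
            buf = io.BytesIO()
            im.save(buf, format='PNG')
            txn.put(f'{split}-{str(i).zfill(7)}-image'.encode('utf-8'), buf.getvalue())
        txn.put(f'{split}-total'.encode('utf-8'), str(len(pil_images)).encode('utf-8'))
    env.close()
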
class Dataset(torch.utils.data.Dataset):
    def __init__(self, path, size, mask_type='bbox', split='train'):
        super(Dataset, self).__init__()
        self.lmdb_path = path
        self.mask_type = mask_type
        self.split = split
        self.size = size
        self.train_total = LmdbReader.read(path, 'train-total'.encode('utf-8'), val_type='int')
        self.test_total = LmdbReader.read(path, 'test-total'.encode('utf-8'), val_type='int')

        self._train_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
        )

        self._test_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ]
        )
    def set_split(self, split='train'):
        if split == 'train':
            self.split = split
        else:
            self.split = 'test'

    def __len__(self):
        if self.split == 'train':
            return self.train_total
        else:
            return self.test_total

    def set_subset(self, start, end):
        self.data = self.data[start:end]
    def __getitem__(self, index):
        key = f'{self.split}-{str(index).zfill(7)}-image'
        orig_img = LmdbReader.read(self.lmdb_path, key.encode('utf-8'), val_type='img')
        if self.split == 'train':
            # obtain mask
            if self.mask_type != 'xml':
                img = transforms.Compose([transforms.RandomHorizontalFlip(),
                                          transforms.Resize(self.size),
                                          transforms.RandomCrop(self.size), ])(orig_img)
                if self.mask_type == 'bbox':
                    mask = random_bbox(self.size, self.size)
                elif self.mask_type == 'stroke':
                    mask = random_stroke(self.size, self.size)
                elif self.mask_type == 'comp':
                    mask = np.array(np.logical_or(random_bbox(self.size, self.size),
                                                  random_stroke(self.size, self.size))).astype(np.uint8)
                mask = Image.fromarray(mask*255)
            else:  # self.mask_type == 'xml'
                oriw, orih = orig_img.size
                label_name = f'{self.split}-{str(index).zfill(7)}-label'
                bbox = LmdbReader.read(self.lmdb_path, label_name.encode('utf-8'), val_type='list')
                mask = Image.fromarray(outside_xml(oriw, orih, bbox)*255)

                # resize so the short side is 512, then take a random crop
                rate = 512.0 / min(orih, oriw)
                neww, newh = int(oriw*rate), int(orih*rate)
                img = orig_img.resize((neww, newh), Image.BILINEAR)
                mask = mask.resize((neww, newh), Image.NEAREST)

                x = random.randint(0, np.maximum(0, img.size[0] - self.size))
                y = random.randint(0, np.maximum(0, img.size[1] - self.size))
                img = img.crop((x, y, x+self.size, y+self.size))
                mask = mask.crop((x, y, x+self.size, y+self.size))

                if np.random.normal() > 0:
                    mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
                    img = img.transpose(Image.FLIP_LEFT_RIGHT)

                # fall back to random blocks if the crop lost the whole hole
                if np.sum(np.array(mask)) == 0:
                    mask = Image.fromarray(small_block(self.size, self.size)*255)

            img = self._train_transform(img)
            return key, img, transforms.ToTensor()(mask)
        else:
            oriw, orih = orig_img.size
            rate = (self.size + 0.0) / max(orih, oriw)
            neww, newh = int(oriw*rate), int(orih*rate)
            img = orig_img.resize((neww, newh), Image.BILINEAR)

            label_name = f'{self.split}-{str(index).zfill(7)}-label'
            bbox = LmdbReader.read(self.lmdb_path, label_name.encode('utf-8'), val_type='list')
            if bbox is not None:
                orig_mask = bbox2np(oriw, orih, bbox)
                mask = orig_mask.resize((neww, newh), Image.NEAREST)
            else:
                mask = random_bbox(neww, newh)
                mask = Image.fromarray(mask*255)

            # reflection-pad up to 512x512 so test batches have a fixed size
            img = F.pad(img, (0, 0, max(512 - neww, 0), max(512 - newh, 0)), fill=0, padding_mode='reflect')
            mask = F.pad(mask, (0, 0, max(512 - neww, 0), max(512 - newh, 0)), fill=0, padding_mode='reflect')

            orig_img = self._test_transform(orig_img)
            img = self._test_transform(img)
            mask = transforms.ToTensor()(mask)
            return key, orig_img, img, mask
    def create_iterator(self, batch_size):
        while True:
            sample_loader = DataLoader(
                dataset=self, batch_size=batch_size, drop_last=True)
            for item in sample_loader:
                yield item

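# Illustrative sketch (not part of the original commit): how this Dataset is
# presumably consumed during training. The helper name, LMDB path, and batch
# size are hypothetical placeholders.
def _example_training_feed(lmdb_path, size=512, batch_size=4):
    dataset = Dataset(lmdb_path, size, mask_type='bbox', split='train')
    feed = dataset.create_iterator(batch_size)
    keys, imgs, masks = next(feed)                 # imgs in [-1, 1], masks in {0, 1}
    return imgs.shape, masks.shape                 # (B, 3, size, size), (B, 1, size, size)
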
if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="ca")
    parser.add_argument('--lmdb_path', type=str, default='/data07/t-yazen/lsun_data/logos')
    parser.add_argument('--size', type=int, default=512)
    args = parser.parse_args()

    d = Dataset(args.lmdb_path, args.size)
    d.set_split('test')
    print(len(d), ' for testing')
    for i in range(5):
        key, orig_img, img, mask = d[i]
        print(orig_img.size(), np.unique(mask.numpy()))
        orig_img = (orig_img.permute(1, 2, 0).numpy()+1)/2*255
        orig_img = Image.fromarray(orig_img.astype(np.uint8))

        oriw, orih = orig_img.size
        rate = (args.size+0.0) / max(orih, oriw)
        neww, newh = int(oriw*rate), int(orih*rate)

        img = img*(1.-mask) + mask
        img = (img.permute(1, 2, 0).numpy()+1)/2*255
        img = Image.fromarray(img.astype(np.uint8))
        mask = mask.squeeze().numpy()
        mask = Image.fromarray(mask.astype(np.uint8))
        img = F.crop(img, 0, 0, newh, neww)
        mask = F.crop(mask, 0, 0, newh, neww)
        print(oriw, orih, neww, newh)
        mask = np.expand_dims(mask, axis=-1)
        img = np.array(np.array(img)*(1.-mask) + mask*255).astype(np.uint8)
        Image.fromarray(img).save(f'comp-{key}.png')