pull/5/head
Deniz Ugur 3 years ago
parent 58d30b58d9
commit 5b0a3a0234

@@ -1,18 +1,27 @@
from .dataset import TerrainDataset
from torch.utils.data import DataLoader


def sample_data(loader):
    while True:
        for batch in loader:
            yield batch


def create_loader(args):
    dataset = TerrainDataset(
        "data/download/MDRS/data/*.tif",
        dataset_type="train",
        randomize=True,
        block_variance=1,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size // args.world_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return sample_data(data_loader)
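# Usage sketch (assumes an `args` object providing batch_size, world_size and
# num_workers, as used above); the infinite sampler is consumed with next():
#   loader = create_loader(args)
#   target, mask, name = next(loader)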

@@ -1,80 +1,302 @@
import os
import math
import random
from glob import glob

import numpy as np
import rasterio
import torch
from scipy import stats
from scipy.ndimage import zoom
from skimage.draw import rectangle_perimeter, line
from torch.utils.data import Dataset
from tqdm import tqdm
class Helper:
def __init__(self) -> None:
super().__init__()
@staticmethod
def get_ranges(x):
bmax = np.max(x.reshape(-1, x.shape[2] ** 2), axis=1)
bmin = np.min(x.reshape(-1, x.shape[2] ** 2), axis=1)
return bmax - bmin
@staticmethod
def get_tri(x):
return np.apply_along_axis(Helper._get_tri, 1, x.reshape(-1, x.shape[2] ** 2))
@staticmethod
    def _get_tri(x):
        # Ruggedness of one flattened 30x30 block: RMS elevation difference between
        # each interior cell and its 3x3 neighbourhood, summarized by a 20%-trimmed mean
        now = x.reshape(-1, 30)
tri = np.zeros_like(now)
d = 1
for i in range(1, 29):
for j in range(1, 29):
tri[i, j] = np.sqrt(
np.sum(
np.power(
now[i - d : i + d + 1, j - d : j + d + 1].flatten()
- now[i, j],
2,
)
)
)
return stats.trim_mean(tri.flatten(), 0.2)
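
# Sketch (hypothetical toy input): on a stack of 30x30 blocks, Helper.get_ranges
# returns the per-block elevation span and Helper.get_tri the per-block ruggedness:
#   blocks = np.random.rand(4, 30, 30)
#   Helper.get_ranges(blocks).shape   # -> (4,)
#   Helper.get_tri(blocks).shape      # -> (4,)
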
class TerrainDataset(Dataset):
NAN = 0
def __init__(
self,
dataset_glob,
dataset_type,
patch_size=30,
sample_size=256,
observer_pad=50,
block_variance=4,
observer_height=0.75,
limit_samples=None,
randomize=True,
random_state=42,
usable_portion=1.0,
fast_load=False,
transform=None,
):
"""
dataset_glob -> glob to *.tif files (i.e. "data/MDRS/data/*.tif")
dataset_type -> train or validation
patch_size -> the 1m^2 area to read from .TIF
sample_size -> the 0.1m^2 res area to be trained sample size
observer_pad -> n pixels to pad before getting a random observer
block_variance -> how many different observer points
observer_height -> Observer Height
limit_samples -> Limit number of samples returned
randomize -> predictable randomize
random_state -> a value that gets added to seed
usable_portion -> What % of the data will be used
fast_load -> initialize from npy file, Warning: Dragons be aware
transform -> if there is any, PyTorch Transforms
"""
np.seterr(divide="ignore", invalid="ignore")
# * Set Dataset attributes
self.observer_height = observer_height
self.patch_size = patch_size
self.sample_size = sample_size
self.block_variance = block_variance
self.observer_pad = observer_pad
# * PyTorch Related Variables
self.transform = transform
# * Gather files
self.files = glob(dataset_glob)
self.dataset_type = dataset_type
self.usable_portion = usable_portion
self.limit_samples = limit_samples
self.randomize = False if fast_load else randomize
self.random_state = random_state
        if self.randomize:
            # Seed so the file order is reproducible, as the docstring promises
            random.seed(self.random_state)
            random.shuffle(self.files)
# * Build dataset dictionary
self.sample_dict = dict()
start = 0
for file in tqdm(self.files, ncols=100, disable=fast_load):
blocks, mask = self.get_blocks(file, return_mask=True)
if len(blocks) == 0:
continue
self.sample_dict[file] = {
"start": start,
"end": start + len(blocks[mask]),
"mask": mask,
"min": np.min(blocks[mask]),
"max": np.max(blocks[mask]),
"range": np.max(Helper.get_ranges(blocks[mask])),
}
start += len(blocks[mask])
del blocks
if fast_load:
break
self.data_min = min(self.sample_dict.values(), key=lambda x: x["min"])["min"]
self.data_max = max(self.sample_dict.values(), key=lambda x: x["max"])["max"]
self.data_range = max(self.sample_dict.values(), key=lambda x: x["range"])[
"range"
]
# * Check if limit_samples is enough for this dataset
if limit_samples is not None:
assert (
limit_samples <= self.get_len()
), "limit_samples cannot be bigger than dataset size"
# * Dataset state
self.current_file = None
self.current_blocks = None
def get_len(self):
key = list(self.sample_dict.keys())[-1]
return self.sample_dict[key]["end"]
    def __len__(self):
        if self.limit_samples is not None:
            return self.limit_samples
        return self.get_len()
    def __getitem__(self, idx):
        """
        returns target, mask, sample_id ("<file>-<idx>")
        """
        # Map the global index to a file and cache that file's blocks
        rel_idx = None
        for file, info in self.sample_dict.items():
            if idx >= info["start"] and idx < info["end"]:
                rel_idx = idx - info["start"]
                if self.current_file != file:
                    b = self.get_blocks(file)
                    self.current_blocks = b[info["mask"]]
                    self.current_file = file
                break

        # Normalize the block to the dataset-wide elevation range
        current = np.copy(self.current_blocks[rel_idx])
        current -= np.min(current)
        current /= self.data_range
        oh = self.observer_height / self.data_range

        # Upsample/crop the block and mask the cells hidden from a random observer
        adjusted = self.get_adjusted(current)
        viewshed, _ = self.viewshed(adjusted, oh, idx)

        mask = np.isnan(viewshed).astype(np.uint8)
        mask = torch.from_numpy(mask).float()
        mask = mask.unsqueeze(0)

        target = torch.from_numpy(adjusted).float()
        target = target.unsqueeze(0)

        return target, mask, f"{self.current_file}-{idx}"
def viewshed(self, dem, oh, seed):
h, w = dem.shape
np.random.seed(seed + self.random_state)
rands = np.random.rand(h - self.observer_pad, w - self.observer_pad)
template = np.zeros_like(dem)
template[
self.observer_pad - self.observer_pad // 2 : h - self.observer_pad // 2,
self.observer_pad - self.observer_pad // 2 : w - self.observer_pad // 2,
] = rands
observer = tuple(np.argwhere(template == np.max(template))[0])
yp, xp = observer
zp = dem[observer] + oh
observer = (xp, yp, zp)
viewshed = np.copy(dem)
        # * Find the block perimeter
        rr, cc = rectangle_perimeter((1, 1), end=(h - 2, w - 2), shape=dem.shape)
        # * Walk the perimeter; each cell defines one line of sight from the observer
for yc, xc in zip(rr, cc):
# * Form the line
ray_y, ray_x = line(yp, xp, yc, xc)
ray_z = dem[ray_y, ray_x]
            # Slope from the observer to each cell on the ray; a cell stays visible
            # only while its slope is not below the running maximum seen so far
            m = (ray_z - zp) / np.hypot(ray_y - yp, ray_x - xp)
            max_so_far = -np.inf
for yi, xi, mi in zip(ray_y, ray_x, m):
if mi < max_so_far:
viewshed[yi, xi] = np.nan
else:
max_so_far = mi
return viewshed, observer
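    # Sketch of intended use (hypothetical `ds` instance; `oh` already divided by
    # data_range, as in __getitem__ above):
    #   vs, (ox, oy, oz) = ds.viewshed(adjusted, oh, seed=idx)
    #   occluded = np.isnan(vs)   # cells not visible from the observer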
def blockshaped(self, arr, nside):
"""
Return an array of shape (n, nside, nside) where
n * nside * nside = arr.size
If arr is a 2D array, the returned array should look like n subblocks with
each subblock preserving the "physical" layout of arr.
"""
h, w = arr.shape
        assert h % nside == 0, "{} rows is not evenly divisible by {}".format(h, nside)
        assert w % nside == 0, "{} cols is not evenly divisible by {}".format(w, nside)
return (
arr.reshape(h // nside, nside, -1, nside)
.swapaxes(1, 2)
.reshape(-1, nside, nside)
)
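    # Sketch (toy input): a 4x4 array split into 2x2 sub-blocks keeps each
    # sub-block's spatial layout:
    #   a = np.arange(16).reshape(4, 4)
    #   ds.blockshaped(a, 2).shape   # -> (4, 2, 2)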
def get_adjusted(self, block):
zoomed = zoom(block, 10, order=1)
y, x = zoomed.shape
startx = x // 2 - (self.sample_size // 2)
starty = y // 2 - (self.sample_size // 2)
return zoomed[
starty : starty + self.sample_size, startx : startx + self.sample_size
]
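    # With the defaults (patch_size=30, sample_size=256) this upsamples a 30x30 block
    # 10x with linear interpolation to 300x300 and returns the centered 256x256 crop.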
def get_blocks(self, file, return_mask=False):
raster = rasterio.open(file)
grid = raster.read(1)
# Remove minimum
grid[grid == np.min(grid)] = np.nan
# Find the edges to cut from
NL = np.count_nonzero(np.isnan(grid[:, 0]))
NR = np.count_nonzero(np.isnan(grid[:, -1]))
NT = np.count_nonzero(np.isnan(grid[0, :]))
NB = np.count_nonzero(np.isnan(grid[-1, :]))
w, h = grid.shape
if NL > NR:
grid = grid[w % self.patch_size : w, 0:h]
else:
grid = grid[0 : w - (w % self.patch_size), 0:h]
w, h = grid.shape
if NT > NB:
grid = grid[0:w, h % self.patch_size : h]
else:
grid = grid[0:w, 0 : h - (h % self.patch_size)]
blocks = self.blockshaped(grid, self.patch_size)
        # * Randomize (np.random.shuffle shuffles in place and returns None)
        if self.randomize:
            np.random.seed(self.random_state)
            np.random.shuffle(blocks)
# * Remove blocks that contain nans
mask = ~np.isnan(blocks).any(axis=1).any(axis=1)
blocks = blocks[mask]
if self.dataset_type == "train":
blocks = blocks[: int(len(blocks) * self.usable_portion)]
        else:
            blocks = blocks[int(len(blocks) * self.usable_portion) :]
# * Add Variance
blocks = np.repeat(blocks, self.block_variance, axis=0)
if return_mask:
            # * Further filter remaining data by the z-score of per-block elevation ranges
ranges = Helper.get_ranges(blocks)
mask_ru = np.abs(stats.zscore(ranges)) < 2
mask_rl = np.abs(stats.zscore(ranges)) > 0.2
# # * Terrain Ruggedness Index
# tri = Helper.get_tri(blocks)
# mask_tu = np.abs(stats.zscore(tri)) < 2
# mask_tl = np.abs(stats.zscore(tri)) > 0.05
mask = mask_ru & mask_rl # & mask_tu & mask_tl
return blocks, mask
return blocks
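

# Usage sketch (hypothetical quick check; paths and options mirror create_loader above):
if __name__ == "__main__":
    ds = TerrainDataset(
        "data/download/MDRS/data/*.tif",
        dataset_type="train",
        randomize=True,
        block_variance=1,
    )
    target, mask, name = ds[0]
    print(len(ds), target.shape, mask.shape, name)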
