parent
58d30b58d9
commit
5b0a3a0234
@ -1,18 +1,27 @@
|
||||
from .dataset import InpaintingData
|
||||
from .dataset import TerrainDataset
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def sample_data(loader):
    """Yield batches from *loader* forever, restarting it whenever exhausted."""
    while True:
        yield from loader
|
||||
|
||||
|
||||
def create_loader(args):
    """Build the terrain training loader and return an endless batch generator.

    *args* must provide: batch_size, world_size, num_workers.
    The per-process batch size is the global batch size divided by world_size.
    """
    dataset = TerrainDataset(
        "data/download/MDRS/data/*.tif",
        dataset_type="train",
        randomize=True,
        block_variance=1,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=args.batch_size // args.world_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    return sample_data(data_loader)
|
@ -1,80 +1,302 @@
|
||||
import os
|
||||
import math
|
||||
import rasterio
|
||||
import torch
|
||||
from scipy.ndimage import zoom
|
||||
from skimage.draw import rectangle_perimeter, line
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from scipy import stats
|
||||
from glob import glob
|
||||
import random
|
||||
|
||||
from random import shuffle
|
||||
from PIL import Image, ImageFilter
|
||||
|
||||
import torch
|
||||
import torchvision.transforms.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
class InpaintingData(Dataset):
    """Image + mask dataset for inpainting training.

    *args* must provide: image_size, mask_type, dir_image, data_train, dir_mask.
    Collects jpg/png images from dir_image/data_train and pconv-style masks
    from dir_mask/mask_type.
    """

    def __init__(self, args):
        # BUG FIX: the original `super(Dataset, self).__init__()` starts the MRO
        # walk *after* Dataset, skipping Dataset's own initializer. Plain
        # super() initializes the direct base class correctly.
        super().__init__()
        self.w = self.h = args.image_size
        self.mask_type = args.mask_type

        # image and mask file lists
        self.image_path = []
        for ext in ['*.jpg', '*.png']:
            self.image_path.extend(glob(os.path.join(args.dir_image, args.data_train, ext)))
        self.mask_path = glob(os.path.join(args.dir_mask, args.mask_type, '*.png'))

        # augmentation pipelines (NEAREST keeps mask values binary)
        self.img_trans = transforms.Compose([
            transforms.RandomResizedCrop(args.image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.05, 0.05, 0.05, 0.05),
            transforms.ToTensor()])
        self.mask_trans = transforms.Compose([
            transforms.Resize(args.image_size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(
                (0, 45), interpolation=transforms.InterpolationMode.NEAREST),
        ])
|
||||
|
||||
|
||||
class Helper:
    """Static helpers computing per-block terrain statistics."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def get_ranges(x):
        """Return the elevation range (max - min) of each square block.

        x has shape (n, s, s); returns an array of n ranges.
        """
        flat = x.reshape(-1, x.shape[2] ** 2)
        bmax = np.max(flat, axis=1)
        bmin = np.min(flat, axis=1)
        return bmax - bmin

    @staticmethod
    def get_tri(x):
        """Return the Terrain Ruggedness Index of each (s, s) block in x."""
        return np.apply_along_axis(Helper._get_tri, 1, x.reshape(-1, x.shape[2] ** 2))

    @staticmethod
    def _get_tri(x):
        """TRI of one flattened square block: 20%-trimmed mean of the RMS
        elevation difference between each interior cell and its 3x3
        neighbourhood.
        """
        # GENERALIZED: the side length was hard-coded to 30 (and the loop
        # bounds to range(1, 29)); derive both from the data instead so any
        # square block size works. Identical results for the 30x30 default.
        side = math.isqrt(x.size)
        now = x.reshape(-1, side)
        tri = np.zeros_like(now)

        d = 1  # neighbourhood radius
        for i in range(1, side - 1):
            for j in range(1, side - 1):
                tri[i, j] = np.sqrt(
                    np.sum(
                        np.power(
                            now[i - d : i + d + 1, j - d : j + d + 1].flatten()
                            - now[i, j],
                            2,
                        )
                    )
                )
        # Trim 20% from each tail so outlier cells do not dominate.
        return stats.trim_mean(tri.flatten(), 0.2)
|
||||
|
||||
|
||||
class TerrainDataset(Dataset):
    """PyTorch dataset of square terrain patches cut from GeoTIFF DEM files."""

    # value used to stand in for NaN cells
    NAN = 0

    def __init__(
        self,
        dataset_glob,
        dataset_type,
        patch_size=30,
        sample_size=256,
        observer_pad=50,
        block_variance=4,
        observer_height=0.75,
        limit_samples=None,
        randomize=True,
        random_state=42,
        usable_portion=1.0,
        fast_load=False,
        transform=None,
    ):
        """
        dataset_glob -> glob to *.tif files (i.e. "data/MDRS/data/*.tif")
        dataset_type -> train or validation
        patch_size -> the 1m^2 area to read from .TIF
        sample_size -> the 0.1m^2 res area to be trained sample size
        observer_pad -> n pixels to pad before getting a random observer
        block_variance -> how many different observer points
        observer_height -> Observer Height
        limit_samples -> Limit number of samples returned
        randomize -> predictable randomize
        random_state -> a value that gets added to seed
        usable_portion -> What % of the data will be used
        fast_load -> initialize from npy file, Warning: Dragons be aware
        transform -> if there is any, PyTorch Transforms
        """
        np.seterr(divide="ignore", invalid="ignore")

        # * Dataset attributes
        self.observer_height = observer_height
        self.patch_size = patch_size
        self.sample_size = sample_size
        self.block_variance = block_variance
        self.observer_pad = observer_pad

        # * PyTorch-related state
        self.transform = transform

        # * Gather input files
        self.files = glob(dataset_glob)
        self.dataset_type = dataset_type
        self.usable_portion = usable_portion
        self.limit_samples = limit_samples

        # fast_load skips shuffling so index ranges stay reproducible
        self.randomize = False if fast_load else randomize
        self.random_state = random_state
        if self.randomize:
            # NOTE(review): this shuffle is not seeded, so file order is not
            # reproducible even though random_state exists — confirm intent.
            random.shuffle(self.files)

        # * Build the per-file index: global [start, end) sample ranges plus
        # * the keep-mask and min/max/range statistics of each file's blocks.
        self.sample_dict = dict()
        start = 0
        for file in tqdm(self.files, ncols=100, disable=fast_load):
            blocks, mask = self.get_blocks(file, return_mask=True)

            if len(blocks) == 0:
                continue

            kept = blocks[mask]
            self.sample_dict[file] = {
                "start": start,
                "end": start + len(kept),
                "mask": mask,
                "min": np.min(kept),
                "max": np.max(kept),
                "range": np.max(Helper.get_ranges(kept)),
            }
            start += len(kept)

            del blocks
            if fast_load:
                # only the first file is indexed in fast-load mode
                break

        # * Global normalisation statistics across every indexed file
        self.data_min = min(self.sample_dict.values(), key=lambda x: x["min"])["min"]
        self.data_max = max(self.sample_dict.values(), key=lambda x: x["max"])["max"]
        self.data_range = max(self.sample_dict.values(), key=lambda x: x["range"])[
            "range"
        ]

        # * Check if limit_samples is enough for this dataset
        if limit_samples is not None:
            assert (
                limit_samples <= self.get_len()
            ), "limit_samples cannot be bigger than dataset size"

        # * Cache of the most recently loaded file's blocks
        self.current_file = None
        self.current_blocks = None
||||
|
||||
def get_len(self):
    """Total number of usable samples: the end index of the last indexed file."""
    last_file = next(reversed(self.sample_dict))
    return self.sample_dict[last_file]["end"]
|
||||
|
||||
def __len__(self):
    """Number of samples exposed to the DataLoader (honours limit_samples)."""
    # BUG FIX: a leftover `return len(self.image_path)` from the old
    # InpaintingData class made the lines below unreachable and referenced an
    # attribute TerrainDataset never defines. Also `not x is None` rewritten
    # to the idiomatic `x is not None`.
    if self.limit_samples is not None:
        return self.limit_samples
    return self.get_len()
|
||||
|
||||
def __getitem__(self, idx):
    """Return (target, mask, name) for global sample index *idx*.

    target -> (1, sample_size, sample_size) float tensor, the upsampled DEM patch
    mask   -> (1, sample_size, sample_size) float tensor, 1 where the viewshed
              from a seeded random observer is NOT visible (NaN cells)
    name   -> "<source file>-<idx>" identifier string

    (The old docstring claimed "returns (x, (ox, oy, oz)), y", which does not
    match the actual return value; interleaved leftovers of the removed
    InpaintingData.__getitem__ have been dropped.)
    """
    # Locate the file whose [start, end) range contains idx and cache its
    # blocks so consecutive indices into the same file reload nothing.
    rel_idx = None
    for file, info in self.sample_dict.items():
        if info["start"] <= idx < info["end"]:
            rel_idx = idx - info["start"]
            if self.current_file != file:
                b = self.get_blocks(file)
                self.current_blocks = b[info["mask"]]
                self.current_file = file
            break

    # Normalise the patch into [0, 1] using the global data range.
    current = np.copy(self.current_blocks[rel_idx])
    current -= np.min(current)
    current /= self.data_range
    oh = self.observer_height / self.data_range  # observer height, same units

    # Upsample/crop, then compute the viewshed; idx seeds the observer.
    adjusted = self.get_adjusted(current)
    viewshed, _ = self.viewshed(adjusted, oh, idx)
    mask = np.isnan(viewshed).astype(np.uint8)

    mask = torch.from_numpy(mask).float()
    mask = mask.unsqueeze(0)

    target = torch.from_numpy(adjusted).float()
    target = target.unsqueeze(0)

    return target, mask, f"{self.current_file}-{idx}"
|
||||
|
||||
def viewshed(self, dem, oh, seed):
    """Compute the viewshed of *dem* from a pseudo-random observer.

    oh is the observer's height above the terrain (same units as dem);
    seed (+ self.random_state) makes the observer position reproducible.
    Returns (viewshed, (ox, oy, oz)) where hidden cells of viewshed are NaN.
    """
    h, w = dem.shape
    np.random.seed(seed + self.random_state)

    # Pick the observer away from the borders: random scores are written
    # only into the padded interior and the arg-max cell wins.
    rands = np.random.rand(h - self.observer_pad, w - self.observer_pad)
    template = np.zeros_like(dem)
    inner = self.observer_pad // 2
    template[
        self.observer_pad - inner : h - inner,
        self.observer_pad - inner : w - inner,
    ] = rands
    observer = tuple(np.argwhere(template == np.max(template))[0])

    yp, xp = observer
    zp = dem[observer] + oh
    observer = (xp, yp, zp)
    viewshed = np.copy(dem)

    # * Rays are cast to every cell of the image perimeter
    rr, cc = rectangle_perimeter((1, 1), end=(h - 2, w - 2), shape=dem.shape)

    for yc, xc in zip(rr, cc):
        # * Rasterize the ray from observer to perimeter cell
        ray_y, ray_x = line(yp, xp, yc, xc)
        ray_z = dem[ray_y, ray_x]

        # Slope from the observer to every cell along the ray.
        slopes = (ray_z - zp) / np.hypot(ray_y - yp, ray_x - xp)

        # A cell is hidden when its slope drops below the running maximum
        # seen closer to the observer.
        max_so_far = -np.inf
        for yi, xi, mi in zip(ray_y, ray_x, slopes):
            if mi < max_so_far:
                viewshed[yi, xi] = np.nan
            else:
                max_so_far = mi

    return viewshed, observer
|
||||
|
||||
def blockshaped(self, arr, nside):
    """
    Split a 2D array into a stack of square tiles of shape (n, nside, nside),
    where n * nside * nside = arr.size and each tile preserves the "physical"
    layout of arr (row-major tiling).
    """
    h, w = arr.shape
    assert h % nside == 0, "{} rows is not evenly divisble by {}".format(h, nside)
    assert w % nside == 0, "{} cols is not evenly divisble by {}".format(w, nside)
    tiled = arr.reshape(h // nside, nside, -1, nside).swapaxes(1, 2)
    return tiled.reshape(-1, nside, nside)
|
||||
|
||||
def get_adjusted(self, block):
    """Upsample *block* 10x with bilinear interpolation and return the
    centred sample_size x sample_size crop."""
    zoomed = zoom(block, 10, order=1)
    rows, cols = zoomed.shape
    half = self.sample_size // 2
    left = cols // 2 - half
    top = rows // 2 - half
    return zoomed[top : top + self.sample_size, left : left + self.sample_size]
|
||||
|
||||
def get_blocks(self, file, return_mask=False):
    """Read one DEM file and cut it into patch_size x patch_size blocks.

    Edge rows/columns are trimmed (from the side with more nodata) so both
    dimensions become multiples of patch_size. NaN-containing blocks are
    dropped, the train/validation split by usable_portion is applied, and
    each surviving block is repeated block_variance times.

    Returns blocks, or (blocks, mask) when return_mask=True where mask keeps
    blocks whose elevation-range z-score lies in (0.2, 2).
    """
    raster = rasterio.open(file)
    grid = raster.read(1)

    # The file's nodata value is its global minimum; mark it as NaN.
    # (NOTE(review): assumes grid is float — confirm the band dtype.)
    grid[grid == np.min(grid)] = np.nan

    # Count NaNs along each edge to decide which side to trim from.
    NL = np.count_nonzero(np.isnan(grid[:, 0]))
    NR = np.count_nonzero(np.isnan(grid[:, -1]))
    NT = np.count_nonzero(np.isnan(grid[0, :]))
    NB = np.count_nonzero(np.isnan(grid[-1, :]))

    w, h = grid.shape
    if NL > NR:
        grid = grid[w % self.patch_size : w, 0:h]
    else:
        grid = grid[0 : w - (w % self.patch_size), 0:h]

    w, h = grid.shape
    if NT > NB:
        grid = grid[0:w, h % self.patch_size : h]
    else:
        grid = grid[0:w, 0 : h - (h % self.patch_size)]

    blocks = self.blockshaped(grid, self.patch_size)

    # * Randomize
    if self.randomize:
        np.random.seed(self.random_state)
        # BUG FIX: np.random.shuffle shuffles in place and returns None, so
        # the old `blocks = np.random.shuffle(blocks)` replaced the data with
        # None and the np.isnan() below would crash. Do not rebind.
        np.random.shuffle(blocks)

    # * Remove blocks that contain nans
    mask = ~np.isnan(blocks).any(axis=1).any(axis=1)
    blocks = blocks[mask]

    # * Train gets the first usable_portion of the blocks, validation the rest
    if self.dataset_type == "train":
        blocks = blocks[: int(len(blocks) * self.usable_portion)]
    else:
        blocks = blocks[int(len(blocks) * self.usable_portion) :]

    # * Add variance: each block appears block_variance times so it can be
    # * paired with different observer positions downstream.
    blocks = np.repeat(blocks, self.block_variance, axis=0)

    if return_mask:
        # * Further filter remaining data in relation to z-score
        ranges = Helper.get_ranges(blocks)
        mask_ru = np.abs(stats.zscore(ranges)) < 2
        mask_rl = np.abs(stats.zscore(ranges)) > 0.2

        # # * Terrain Ruggedness Index (disabled)
        # tri = Helper.get_tri(blocks)
        # mask_tu = np.abs(stats.zscore(tri)) < 2
        # mask_tl = np.abs(stats.zscore(tri)) > 0.05

        mask = mask_ru & mask_rl  # & mask_tu & mask_tl
        return blocks, mask
    return blocks
|
||||
|
Loading…
Reference in new issue