parent
58d30b58d9
commit
5b0a3a0234
@ -1,18 +1,27 @@
|
|||||||
from .dataset import InpaintingData
|
from .dataset import TerrainDataset
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
def sample_data(loader):
    """Yield batches from ``loader`` forever, restarting it after every pass.

    Args:
        loader: Any re-iterable collection of batches (e.g. a ``DataLoader``).

    Yields:
        The batches of ``loader`` in iteration order, cycling indefinitely.

    Raises:
        ValueError: If a complete pass over ``loader`` yields nothing. Without
            this check an empty loader would spin in the outer loop forever
            and the consumer's ``next()`` call would never return.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("loader produced no batches; cannot sample from it")
|
||||||
|
|
||||||
|
|
||||||
def create_loader(args, dataset_glob="data/download/MDRS/data/*.tif"):
    """Build the training data stream used by one process.

    Args:
        args: Object exposing ``batch_size`` (global batch size),
            ``world_size`` (number of processes sharing that batch) and
            ``num_workers`` (DataLoader worker count).
        dataset_glob: Glob selecting the ``*.tif`` rasters to train on.
            Defaults to the path that used to be hard-coded here.

    Returns:
        An endless generator of batches (see ``sample_data``).

    Raises:
        ValueError: If ``batch_size // world_size`` is smaller than 1.
    """
    # Validate before building the dataset: indexing every raster is the
    # expensive step and DataLoader would reject a zero batch size anyway.
    per_process_batch = args.batch_size // args.world_size
    if per_process_batch < 1:
        raise ValueError(
            "batch_size ({}) must be at least world_size ({})".format(
                args.batch_size, args.world_size
            )
        )

    dataset = TerrainDataset(
        dataset_glob,
        dataset_type="train",
        randomize=True,
        block_variance=1,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=per_process_batch,
        # NOTE(review): presumably kept sequential because TerrainDataset
        # caches one file's blocks at a time and random access would reload
        # rasters constantly - confirm before enabling shuffling.
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    return sample_data(data_loader)
|
@ -1,80 +1,302 @@
|
|||||||
import os
|
import rasterio
|
||||||
import math
|
import torch
|
||||||
|
from scipy.ndimage import zoom
|
||||||
|
from skimage.draw import rectangle_perimeter, line
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from tqdm import tqdm
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from scipy import stats
|
||||||
from glob import glob
|
from glob import glob
|
||||||
|
import random
|
||||||
|
|
||||||
from random import shuffle
|
|
||||||
from PIL import Image, ImageFilter
|
|
||||||
|
|
||||||
import torch
|
class Helper:
    """Stateless statistics computed over stacks of square terrain blocks."""

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def get_ranges(x):
        """Return ``max - min`` of every block in ``x``.

        Args:
            x: Array of shape ``(n, side, side)``.

        Returns:
            1-D array of length ``n`` holding each block's value range.
        """
        # Flatten each block once and reuse it for both reductions.
        flat = x.reshape(-1, x.shape[2] ** 2)
        return np.max(flat, axis=1) - np.min(flat, axis=1)

    @staticmethod
    def get_tri(x):
        """Return the ruggedness score of every block in ``x``.

        Args:
            x: Array of shape ``(n, side, side)``.

        Returns:
            1-D array of length ``n``; see ``_get_tri`` for the score.
        """
        return np.apply_along_axis(Helper._get_tri, 1, x.reshape(-1, x.shape[2] ** 2))

    @staticmethod
    def _get_tri(x):
        """Return the 20 % trimmed mean ruggedness of one flattened block.

        For every interior cell the ruggedness is the root of the summed
        squared differences between the cell and its 3x3 neighbourhood;
        border cells count as zero.

        Args:
            x: Flattened square block of ``side * side`` values. Any side
                length is accepted (the previous version required 30).

        Returns:
            ``scipy.stats.trim_mean`` of the per-cell ruggedness, cutting
            20 % from each tail.

        Raises:
            ValueError: If ``x`` does not hold a square number of values.
        """
        side = int(round(np.sqrt(x.size)))
        if side * side != x.size:
            raise ValueError(
                "expected a flattened square block, got {} values".format(x.size)
            )

        now = x.reshape(side, side)
        # Always accumulate in float so integer input is neither truncated
        # nor wrapped when differences are taken.
        tri = np.zeros(now.shape, dtype=float)

        if side >= 3:
            center = now[1:-1, 1:-1].astype(float)
            squared = np.zeros_like(center)
            # Nine shifted views replace the per-cell Python loop.
            for dy in (-1, 0, 1):
                for dx in (-1, 0, 1):
                    neighbour = now[1 + dy : side - 1 + dy, 1 + dx : side - 1 + dx]
                    squared += (neighbour - center) ** 2
            tri[1:-1, 1:-1] = np.sqrt(squared)

        return stats.trim_mean(tri.flatten(), 0.2)
|
||||||
|
|
||||||
|
|
||||||
|
class TerrainDataset(Dataset):
|
||||||
|
NAN = 0
|
||||||
|
|
||||||
|
def __init__(
    self,
    dataset_glob,
    dataset_type,
    patch_size=30,
    sample_size=256,
    observer_pad=50,
    block_variance=4,
    observer_height=0.75,
    limit_samples=None,
    randomize=True,
    random_state=42,
    usable_portion=1.0,
    fast_load=False,
    transform=None,
):
    """Index every usable terrain block reachable through ``dataset_glob``.

    Args:
        dataset_glob: Glob for the ``*.tif`` rasters
            (e.g. ``"data/MDRS/data/*.tif"``).
        dataset_type: ``"train"`` keeps the leading ``usable_portion`` of
            each file's blocks; any other value keeps the remainder.
        patch_size: Side, in raster cells, of the square blocks cut from
            each raster (the original notes call this the 1 m^2 area).
        sample_size: Side of the upsampled crop returned per sample (the
            original notes call this the 0.1 m^2 resolution area).
        observer_pad: Total number of cells excluded from each axis when
            picking a random observer position.
        block_variance: How many times each block is repeated; every
            repetition gets its own index and therefore its own observer.
        observer_height: Observer height, scaled by the same data range
            as the elevations.
        limit_samples: Optional cap on the length reported by ``len()``.
        randomize: Shuffle files and blocks reproducibly. Forced off when
            ``fast_load`` is set.
        random_state: Seed used for every shuffle and observer placement.
        usable_portion: Fraction of each file's blocks given to ``"train"``.
        fast_load: Stop after the first file that yields usable blocks and
            silence the progress bar.
            NOTE(review): the original notes mention initialising from an
            npy file; nothing here reads one - confirm the intent.
        transform: Stored on ``self.transform``; not applied by the code
            visible in this class.

    Raises:
        ValueError: If no file yields a usable block.
        AssertionError: If ``limit_samples`` exceeds the dataset size.
    """
    # Divisions by zero are expected downstream (e.g. the observer's own
    # cell in viewshed), so silence NumPy's warnings for them.
    np.seterr(divide="ignore", invalid="ignore")

    # * Set Dataset attributes
    self.observer_height = observer_height
    self.patch_size = patch_size
    self.sample_size = sample_size
    self.block_variance = block_variance
    self.observer_pad = observer_pad

    # * PyTorch Related Variables
    self.transform = transform

    # * Gather files
    # glob order depends on the file system; sort so the seeded shuffle
    # below starts from the same list on every machine.
    self.files = sorted(glob(dataset_glob))
    self.dataset_type = dataset_type
    self.usable_portion = usable_portion
    self.limit_samples = limit_samples

    self.randomize = False if fast_load else randomize
    self.random_state = random_state
    if self.randomize:
        # A private, seeded generator keeps the order reproducible and
        # leaves the global ``random`` state untouched.
        random.Random(random_state).shuffle(self.files)

    # * Build dataset dictionary
    self.sample_dict = dict()
    start = 0
    for file in tqdm(self.files, ncols=100, disable=fast_load):
        blocks, mask = self.get_blocks(file, return_mask=True)
        kept = blocks[mask] if len(blocks) else blocks
        del blocks

        # Skip files that contribute nothing; np.min / np.max below
        # would raise on an empty selection.
        if len(kept) == 0:
            continue

        self.sample_dict[file] = {
            "start": start,
            "end": start + len(kept),
            "mask": mask,
            "min": np.min(kept),
            "max": np.max(kept),
            "range": np.max(Helper.get_ranges(kept)),
        }
        start += len(kept)

        del kept
        if fast_load:
            break

    if not self.sample_dict:
        raise ValueError(
            "no usable terrain blocks found for glob {!r}".format(dataset_glob)
        )

    self.data_min = min(info["min"] for info in self.sample_dict.values())
    self.data_max = max(info["max"] for info in self.sample_dict.values())
    self.data_range = max(info["range"] for info in self.sample_dict.values())

    # * Check if limit_samples is enough for this dataset
    # Raised explicitly so the check survives ``python -O``; the exception
    # type matches the assert it replaces.
    if limit_samples is not None and limit_samples > self.get_len():
        raise AssertionError("limit_samples cannot be bigger than dataset size")

    # * Dataset state
    self.current_file = None
    self.current_blocks = None
|
||||||
|
|
||||||
|
def get_len(self):
    """Return the number of indexable samples, ignoring ``limit_samples``.

    Returns:
        The largest ``"end"`` offset stored in ``self.sample_dict``, or 0
        when no file has been indexed. Offsets grow with insertion order,
        so this equals the last file's ``"end"``.
    """
    return max((info["end"] for info in self.sample_dict.values()), default=0)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
    """Return ``limit_samples`` when it is set, else the full dataset size."""
    if self.limit_samples is not None:
        return self.limit_samples
    return self.get_len()
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
    """Return the training sample stored at global index ``idx``.

    Args:
        idx: Index in ``[0, self.get_len())``.

    Returns:
        ``(target, mask, name)`` where ``target`` is the normalised,
        upsampled terrain crop as a float tensor with a leading channel
        axis, ``mask`` is a float tensor of the same shape that is 1 where
        the terrain is hidden from the observer, and ``name`` is
        ``"<file>-<idx>"``.

    Raises:
        IndexError: If ``idx`` is not covered by any indexed file.
    """
    for file, info in self.sample_dict.items():
        if info["start"] <= idx < info["end"]:
            # Only one file's blocks are cached; reload on a file change.
            if self.current_file != file:
                self.current_blocks = self.get_blocks(file)[info["mask"]]
                self.current_file = file
            rel_idx = idx - info["start"]
            break
    else:
        # Previously this fell through with rel_idx = None and indexed the
        # cached blocks with it.
        raise IndexError("sample index {} is out of range".format(idx))

    # Shift the block to start at zero and scale by the dataset-wide range.
    current = np.copy(self.current_blocks[rel_idx])
    current -= np.min(current)
    current /= self.data_range
    oh = self.observer_height / self.data_range

    adjusted = self.get_adjusted(current)
    viewshed, _ = self.viewshed(adjusted, oh, idx)
    # viewshed marks hidden cells with NaN.
    mask = np.isnan(viewshed).astype(np.uint8)

    mask = torch.from_numpy(mask).float()
    mask = mask.unsqueeze(0)

    target = torch.from_numpy(adjusted).float()
    target = target.unsqueeze(0)

    return target, mask, f"{self.current_file}-{idx}"
|
||||||
index = np.random.randint(0, len(self.mask_path))
|
|
||||||
mask = Image.open(self.mask_path[index])
|
def viewshed(self, dem, oh, seed):
    """Mask the cells of ``dem`` hidden from a pseudo-randomly placed observer.

    Args:
        dem: 2-D elevation array.
        oh: Observer height above the cell the observer stands on.
        seed: Added to ``self.random_state`` to seed the observer placement,
            so the same ``seed`` always gives the same observer.

    Returns:
        ``(viewshed, observer)``: a copy of ``dem`` with every hidden cell
        set to NaN, and the observer as ``(x, y, z)``.
    """
    h, w = dem.shape
    pad = self.observer_pad

    # A private generator draws the same numbers as seeding the global one
    # but does not disturb other users of np.random.
    rng = np.random.RandomState(seed + self.random_state)
    rands = rng.rand(h - pad, w - pad)

    # Random scores only inside the padded window; the arg-max of the
    # template is the observer cell, so it never lands in the margin.
    template = np.zeros_like(dem)
    template[pad - pad // 2 : h - pad // 2, pad - pad // 2 : w - pad // 2] = rands
    yp, xp = np.unravel_index(np.argmax(template), template.shape)

    zp = dem[yp, xp] + oh
    observer = (xp, yp, zp)
    viewshed = np.copy(dem)

    # * Find perimeter
    rr, cc = rectangle_perimeter((1, 1), end=(h - 2, w - 2), shape=dem.shape)

    # * Iterate through perimeter
    for yc, xc in zip(rr, cc):
        # * Form the line
        ray_y, ray_x = line(yp, xp, yc, xc)
        ray_z = dem[ray_y, ray_x]

        # Slope from the observer to every cell on the ray. The first cell
        # is the observer itself (distance 0), hence the local errstate
        # instead of relying on a process-wide np.seterr.
        with np.errstate(divide="ignore", invalid="ignore"):
            m = (ray_z - zp) / np.hypot(ray_y - yp, ray_x - xp)

        # A cell is hidden when its slope is below the previous cell's
        # running maximum.
        max_so_far = -np.inf
        for yi, xi, mi in zip(ray_y, ray_x, m):
            if mi < max_so_far:
                viewshed[yi, xi] = np.nan
            else:
                max_so_far = mi

    return viewshed, observer
|
||||||
|
|
||||||
|
def blockshaped(self, arr, nside):
    """Split a 2-D array into non-overlapping ``nside x nside`` blocks.

    Args:
        arr: 2-D array whose height and width are multiples of ``nside``.
        nside: Side length of each block.

    Returns:
        Array of shape ``(n, nside, nside)`` with
        ``n * nside * nside == arr.size``. Blocks are ordered row by row
        and each keeps the "physical" layout it had in ``arr``.

    Raises:
        AssertionError: If either dimension is not a multiple of ``nside``.
    """
    h, w = arr.shape
    # Raised explicitly so the checks survive ``python -O``; the exception
    # type matches the asserts they replace.
    if h % nside != 0:
        raise AssertionError(
            "{} rows is not evenly divisible by {}".format(h, nside)
        )
    if w % nside != 0:
        raise AssertionError(
            "{} cols is not evenly divisible by {}".format(w, nside)
        )
    return (
        arr.reshape(h // nside, nside, -1, nside)
        .swapaxes(1, 2)
        .reshape(-1, nside, nside)
    )
|
||||||
|
|
||||||
|
def get_adjusted(self, block, factor=10):
    """Upsample ``block`` and return its central ``sample_size`` square.

    Args:
        block: 2-D array to upsample.
        factor: Zoom factor passed to ``scipy.ndimage.zoom``. Defaults to
            the previously hard-coded 10.

    Returns:
        The centred ``sample_size x sample_size`` crop of the linearly
        interpolated (``order=1``) upsampled block.

    Raises:
        ValueError: If the upsampled block is smaller than ``sample_size``
            along either axis. The crop start would go negative and the
            slice would silently return the wrong region.
    """
    zoomed = zoom(block, factor, order=1)
    rows, cols = zoomed.shape
    size = self.sample_size
    if rows < size or cols < size:
        raise ValueError(
            "upsampled block {}x{} is smaller than sample_size {}".format(
                rows, cols, size
            )
        )

    top = rows // 2 - size // 2
    left = cols // 2 - size // 2
    return zoomed[top : top + size, left : left + size]
|
||||||
|
|
||||||
|
def get_blocks(self, file, return_mask=False):
    """Cut one raster into NaN-free blocks for this dataset split.

    Args:
        file: Path of the raster to read (band 1 is used).
        return_mask: When true, also return a boolean mask selecting the
            blocks whose value range is neither an outlier nor nearly
            average.

    Returns:
        ``blocks`` of shape ``(n, patch_size, patch_size)``, or
        ``(blocks, mask)`` when ``return_mask`` is true. ``mask`` has one
        entry per returned block. The block order is deterministic for a
        given ``random_state``, so a mask from one call can be applied to
        the blocks of a later call on the same file.
    """
    # Close the raster handle as soon as the band has been read.
    with rasterio.open(file) as raster:
        grid = raster.read(1)

    # NaN marks missing data below; integer rasters cannot store it.
    if not np.issubdtype(grid.dtype, np.floating):
        grid = grid.astype(np.float64)

    # Remove minimum
    grid[grid == np.min(grid)] = np.nan

    # Find the edges to cut from
    NL = np.count_nonzero(np.isnan(grid[:, 0]))
    NR = np.count_nonzero(np.isnan(grid[:, -1]))
    NT = np.count_nonzero(np.isnan(grid[0, :]))
    NB = np.count_nonzero(np.isnan(grid[-1, :]))

    # Trim each axis to a multiple of patch_size.
    # NOTE(review): grid.shape is (rows, cols), so the left/right NaN counts
    # steer the row trim and the top/bottom counts the column trim - kept
    # as written, but confirm this pairing is intended.
    w, h = grid.shape
    if NL > NR:
        grid = grid[w % self.patch_size : w, 0:h]
    else:
        grid = grid[0 : w - (w % self.patch_size), 0:h]

    w, h = grid.shape
    if NT > NB:
        grid = grid[0:w, h % self.patch_size : h]
    else:
        grid = grid[0:w, 0 : h - (h % self.patch_size)]

    blocks = self.blockshaped(grid, self.patch_size)

    # * Randomize
    if self.randomize:
        np.random.seed(self.random_state)
        # shuffle works in place and returns None; assigning its result
        # back used to replace ``blocks`` with None.
        np.random.shuffle(blocks)

    # * Remove blocks that contain nans
    blocks = blocks[~np.isnan(blocks).any(axis=(1, 2))]

    # "train" takes the leading portion, every other split the remainder.
    split = int(len(blocks) * self.usable_portion)
    if self.dataset_type == "train":
        blocks = blocks[:split]
    else:
        blocks = blocks[split:]

    # * Add Variance
    blocks = np.repeat(blocks, self.block_variance, axis=0)

    if not return_mask:
        return blocks

    if len(blocks) == 0:
        return blocks, np.zeros(0, dtype=bool)

    # * Further filter remaining data in relation to z-score
    # Keep blocks whose range is within 2 sigma of the mean but more than
    # 0.2 sigma away from it. Filtering on the terrain ruggedness index
    # (Helper.get_tri) is currently disabled.
    z_scores = np.abs(stats.zscore(Helper.get_ranges(blocks)))
    mask = (z_scores < 2) & (z_scores > 0.2)
    return blocks, mask
|
||||||
'mask_type': 'pconv',
|
|
||||||
'image_size': 512
|
|
||||||
}
|
|
||||||
args = AttrDict(args)
|
|
||||||
|
|
||||||
data = InpaintingData(args)
|
|
||||||
print(len(data), len(data.mask_path))
|
|
||||||
img, mask, filename = data[0]
|
|
||||||
print(img.size(), mask.size(), filename)
|
|
||||||
|
Loading…
Reference in new issue