diff --git a/src/data/dataset.py b/src/data/dataset.py index fda42f0..eff50d0 100644 --- a/src/data/dataset.py +++ b/src/data/dataset.py @@ -139,6 +139,7 @@ class TerrainDataset(Dataset): # * Dataset state self.current_file = None self.current_blocks = None + self.idx_offset = 0 def get_len(self): key = list(self.sample_dict.keys())[-1] @@ -150,6 +151,16 @@ class TerrainDataset(Dataset): return self.get_len() def __getitem__(self, idx): + while True: + target, mask, file_name = self.__internal_getitem__(idx + self.idx_offset) + if not np.all(mask.numpy() == 0): + break + else: + self.idx_offset += 1 + print(f"Skipped block #{self.idx_offset}") + return target, mask, file_name + + def __internal_getitem__(self, idx): """ returns (x, (ox, oy, oz)), y """ diff --git a/src/trainer/trainer.py b/src/trainer/trainer.py index 0288e6f..90a3a87 100644 --- a/src/trainer/trainer.py +++ b/src/trainer/trainer.py @@ -19,6 +19,9 @@ class Trainer(): self.args = args self.iteration = 0 + # Detect anomalies + torch.autograd.set_detect_anomaly(True) + # setup data set and data loader self.dataloader = create_loader(args)