avoid anomalies

pull/5/head
Deniz Ugur 3 years ago
parent 31ad4ba044
commit 70ac4b8ced

@ -139,6 +139,7 @@ class TerrainDataset(Dataset):
# * Dataset state
self.current_file = None
self.current_blocks = None
self.idx_offset = 0
def get_len(self):
key = list(self.sample_dict.keys())[-1]
@ -150,6 +151,16 @@ class TerrainDataset(Dataset):
return self.get_len()
def __getitem__(self, idx):
while True:
target, mask, file_name = self.__internal_getitem__(idx + self.idx_offset)
if not np.all(mask.numpy() == 0):
break
else:
self.idx_offset += 1
print(f"Skipped block #{self.idx_offset}")
return target, mask, file_name
def __internal_getitem__(self, idx):
"""
returns (x, (ox, oy, oz)), y
"""

@ -19,6 +19,9 @@ class Trainer():
self.args = args
self.iteration = 0
# Detect anomalies
torch.autograd.set_detect_anomaly(True)
# setup data set and data loader
self.dataloader = create_loader(args)

Loading…
Cancel
Save