diff --git a/src/data/__init__.py b/src/data/__init__.py index 3c0e642..6a159ad 100644 --- a/src/data/__init__.py +++ b/src/data/__init__.py @@ -11,7 +11,7 @@ def sample_data(loader): def create_loader(args): dataset = TerrainDataset( - "data/download/MDRS/data/*.tif", + args.dir_train, dataset_type="train", randomize=True, block_variance=1, diff --git a/src/utils/option.py b/src/utils/option.py index ace1e49..6443f16 100644 --- a/src/utils/option.py +++ b/src/utils/option.py @@ -3,6 +3,8 @@ import argparse parser = argparse.ArgumentParser(description='Image Inpainting') # data specifications +parser.add_argument('--dir_train', type=str, default='data/download/*/data/*.tif', + help='train dataset directory') parser.add_argument('--data_train', type=str, default='tds', help='dataname used for training') parser.add_argument('--data_test', type=str, default='tds',