diff --git a/.gitignore b/.gitignore index a564698..038cda4 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ __pycache__ *.png events.* outputs +src/data/download diff --git a/src/loss/__init__.py b/src/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/metric/__init__.py b/src/metric/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/aotgan.py b/src/model/aotgan.py index 518b76c..31ceab9 100644 --- a/src/model/aotgan.py +++ b/src/model/aotgan.py @@ -12,7 +12,7 @@ class InpaintGenerator(BaseNetwork): self.encoder = nn.Sequential( nn.ReflectionPad2d(3), - nn.Conv2d(4, 64, 7), + nn.Conv2d(2, 64, 7), nn.ReLU(True), nn.Conv2d(64, 128, 4, stride=2, padding=1), nn.ReLU(True), @@ -27,7 +27,7 @@ class InpaintGenerator(BaseNetwork): nn.ReLU(True), UpConv(128, 64), nn.ReLU(True), - nn.Conv2d(64, 3, 3, stride=1, padding=1) + nn.Conv2d(64, 1, 3, stride=1, padding=1) ) self.init_weights() @@ -92,7 +92,7 @@ def my_layer_norm(feat): class Discriminator(BaseNetwork): def __init__(self, ): super(Discriminator, self).__init__() - inc = 3 + inc = 1 self.conv = nn.Sequential( spectral_norm(nn.Conv2d(inc, 64, 4, stride=2, padding=1, bias=False)), nn.LeakyReLU(0.2, inplace=True), diff --git a/src/trainer/__init__.py b/src/trainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/option.py b/src/utils/option.py index f71d55c..ace1e49 100644 --- a/src/utils/option.py +++ b/src/utils/option.py @@ -3,15 +3,11 @@ import argparse parser = argparse.ArgumentParser(description='Image Inpainting') # data specifications -parser.add_argument('--dir_image', type=str, default='../../dataset', - help='image dataset directory') -parser.add_argument('--dir_mask', type=str, default='../../dataset', - help='mask dataset directory') -parser.add_argument('--data_train', type=str, default='places2', +parser.add_argument('--data_train', type=str, default='tds', help='dataname used for training') -parser.add_argument('--data_test', type=str, default='places2', +parser.add_argument('--data_test', type=str, default='tds', help='dataname used for testing') -parser.add_argument('--image_size', type=int, default=512, +parser.add_argument('--image_size', type=int, default=256, help='image size used during training') parser.add_argument('--mask_type', type=str, default='pconv', help='mask used during training') @@ -54,7 +50,7 @@ parser.add_argument('--adv_weight', type=float, default=0.01, # training specifications parser.add_argument('--iterations', type=int, default=1e6, help='the number of iterations for training') -parser.add_argument('--batch_size', type=int, default=8, +parser.add_argument('--batch_size', type=int, default=1, help='batch size in each mini-batch') parser.add_argument('--port', type=int, default=22334, help='tcp port for distributed training') @@ -65,7 +61,7 @@ parser.add_argument('--resume', action='store_true', # log specifications parser.add_argument('--print_every', type=int, default=10, help='frequency for updating progress bar') -parser.add_argument('--save_every', type=int, default=1e4, +parser.add_argument('--save_every', type=int, default=1e3, help='frequency for saving models') parser.add_argument('--save_dir', type=str, default='../experiments', help='directory for saving models and logs')