prepare for tds

pull/5/head
Deniz Ugur 3 years ago
parent a96e18a5fe
commit 58d30b58d9

1
.gitignore vendored

@ -4,3 +4,4 @@ __pycache__
*.png
events.*
outputs
src/data/download

@ -12,7 +12,7 @@ class InpaintGenerator(BaseNetwork):
self.encoder = nn.Sequential(
nn.ReflectionPad2d(3),
nn.Conv2d(4, 64, 7),
nn.Conv2d(2, 64, 7),
nn.ReLU(True),
nn.Conv2d(64, 128, 4, stride=2, padding=1),
nn.ReLU(True),
@ -27,7 +27,7 @@ class InpaintGenerator(BaseNetwork):
nn.ReLU(True),
UpConv(128, 64),
nn.ReLU(True),
nn.Conv2d(64, 3, 3, stride=1, padding=1)
nn.Conv2d(64, 1, 3, stride=1, padding=1)
)
self.init_weights()
@ -92,7 +92,7 @@ def my_layer_norm(feat):
class Discriminator(BaseNetwork):
def __init__(self, ):
super(Discriminator, self).__init__()
inc = 3
inc = 1
self.conv = nn.Sequential(
spectral_norm(nn.Conv2d(inc, 64, 4, stride=2, padding=1, bias=False)),
nn.LeakyReLU(0.2, inplace=True),

@ -3,15 +3,11 @@ import argparse
parser = argparse.ArgumentParser(description='Image Inpainting')
# data specifications
parser.add_argument('--dir_image', type=str, default='../../dataset',
help='image dataset directory')
parser.add_argument('--dir_mask', type=str, default='../../dataset',
help='mask dataset directory')
parser.add_argument('--data_train', type=str, default='places2',
parser.add_argument('--data_train', type=str, default='tds',
help='dataname used for training')
parser.add_argument('--data_test', type=str, default='places2',
parser.add_argument('--data_test', type=str, default='tds',
help='dataname used for testing')
parser.add_argument('--image_size', type=int, default=512,
parser.add_argument('--image_size', type=int, default=256,
help='image size used during training')
parser.add_argument('--mask_type', type=str, default='pconv',
help='mask used during training')
@ -54,7 +50,7 @@ parser.add_argument('--adv_weight', type=float, default=0.01,
# training specifications
parser.add_argument('--iterations', type=int, default=1e6,
help='the number of iterations for training')
parser.add_argument('--batch_size', type=int, default=8,
parser.add_argument('--batch_size', type=int, default=1,
help='batch size in each mini-batch')
parser.add_argument('--port', type=int, default=22334,
help='tcp port for distributed training')
@ -65,7 +61,7 @@ parser.add_argument('--resume', action='store_true',
# log specifications
parser.add_argument('--print_every', type=int, default=10,
help='frequency for updating progress bar')
parser.add_argument('--save_every', type=int, default=1e4,
parser.add_argument('--save_every', type=int, default=1e3,
help='frequency for saving models')
parser.add_argument('--save_dir', type=str, default='../experiments',
help='directory for saving models and logs')

Loading…
Cancel
Save