You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
96 lines
4.2 KiB
96 lines
4.2 KiB
import argparse

# Command-line options for the image-inpainting scripts. This section only
# builds the parser; parsing and post-processing of the values happen below.
parser = argparse.ArgumentParser(description='Image Inpainting')

# data specifications
parser.add_argument('--dir_image', type=str, default='../../dataset',
                    help='image dataset directory')
parser.add_argument('--dir_mask', type=str, default='../../dataset',
                    help='mask dataset directory')
parser.add_argument('--data_train', type=str, default='places2',
                    help='dataname used for training')
parser.add_argument('--data_test', type=str, default='places2',
                    help='dataname used for testing')
parser.add_argument('--image_size', type=int, default=512,
                    help='image size used during training')
parser.add_argument('--mask_type', type=str, default='pconv',
                    help='mask used during training')

# model specifications
parser.add_argument('--model', type=str, default='aotgan',
                    help='model name')
parser.add_argument('--block_num', type=int, default=8,
                    help='number of AOT blocks')
# '+'-separated list of integers; split into a list after parsing.
parser.add_argument('--rates', type=str, default='1+2+4+8',
                    help='dilation rates used in AOT block')
parser.add_argument('--gan_type', type=str, default='smgan',
                    help='discriminator types')

# hardware specifications
parser.add_argument('--seed', type=int, default=2021,
                    help='random seed')
parser.add_argument('--num_workers', type=int, default=4,
                    help='number of workers used in data loader')

# optimization specifications
parser.add_argument('--lrg', type=float, default=1e-4,
                    help='learning rate for generator')
parser.add_argument('--lrd', type=float, default=1e-4,
                    help='learning rate for discriminator')
parser.add_argument('--optimizer', default='ADAM',
                    choices=('SGD', 'ADAM', 'RMSprop'),
                    help='optimizer to use (SGD | ADAM | RMSprop)')
parser.add_argument('--beta1', type=float, default=0.5,
                    help='beta1 in optimizer')
parser.add_argument('--beta2', type=float, default=0.999,
                    help='beta2 in optimizer')

# loss specifications
# '+'-separated terms of the form '<weight>*<name>'; turned into a dict
# after parsing.
parser.add_argument('--rec_loss', type=str, default='1*L1+250*Style+0.1*Perceptual',
                    help='losses for reconstruction')
parser.add_argument('--adv_weight', type=float, default=0.01,
                    help='loss weight for adversarial loss')

# training specifications
# NOTE: argparse applies `type` only to string defaults, so a bare 1e6 would
# reach callers as a float; wrap it in int() to honour the declared type.
parser.add_argument('--iterations', type=int, default=int(1e6),
                    help='the number of iterations for training')
parser.add_argument('--batch_size', type=int, default=8,
                    help='batch size in each mini-batch')
parser.add_argument('--port', type=int, default=22334,
                    help='tcp port for distributed training')
parser.add_argument('--resume', action='store_true',
                    help='resume from previous iteration')

# log specifications
parser.add_argument('--print_every', type=int, default=10,
                    help='frequency for updating progress bar')
# Same int() wrapping as --iterations: keep the default an actual int.
parser.add_argument('--save_every', type=int, default=int(1e4),
                    help='frequency for saving models')
parser.add_argument('--save_dir', type=str, default='../experiments',
                    help='directory for saving models and logs')
parser.add_argument('--tensorboard', action='store_true',
                    help='default: false, since it will slow training. use it for debugging')

# test and demo specifications
parser.add_argument('--pre_train', type=str, default=None,
                    help='path to pretrained models')
parser.add_argument('--outputs', type=str, default='../outputs',
                    help='path to save results')
parser.add_argument('--thick', type=int, default=15,
                    help='the thick of pen for free-form drawing')
parser.add_argument('--painter', default='freeform', choices=('freeform', 'bbox'),
                    help='different painters for demo ')

# ----------------------------------
# Parse the command line once at import time and normalise the values that
# are passed as compact strings, so that other code can use `args` directly.
args = parser.parse_args()

# Defaults written in scientific notation (e.g. 1e6) are floats, and argparse
# does not run `type` on non-string defaults, so coerce them explicitly.
args.iterations = int(args.iterations)
args.save_every = int(args.save_every)

# '1+2+4+8' -> [1, 2, 4, 8]
args.rates = [int(rate) for rate in args.rates.split('+')]

# '1*L1+250*Style' -> {'L1': 1.0, 'Style': 250.0}
losses = args.rec_loss.split('+')
args.rec_loss = {}
for term in losses:
    # Each term must be exactly '<weight>*<name>'; anything else raises
    # ValueError from the unpacking or from float().
    weight, name = term.split('*')
    args.rec_loss[name] = float(weight)