Default to img_size in model default_cfg, defer output folder creation until later in the init sequence

pull/12/head
Ross Wightman 5 years ago
parent 9bcd65181b
commit 7dab6d1ec7

@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--img-size', type=int, default=224, metavar='N',
help='Image patch size (default: 224)')
parser.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image patch size (default: None => model default)')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@ -159,15 +159,6 @@ def main():
torch.manual_seed(args.seed + args.rank)
output_dir = ''
if args.local_rank == 0:
output_base = args.output if args.output else './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
str(args.img_size)])
output_dir = get_outdir(output_base, 'train', exp_name)
model = create_model(
args.model,
pretrained=args.pretrained,
@ -291,13 +282,21 @@ def main():
validate_loss_fn = train_loss_fn
eval_metric = args.eval_metric
best_metric = None
best_epoch = None
saver = None
if output_dir:
# only set if process is rank 0
output_dir = ''
if args.local_rank == 0:
output_base = args.output if args.output else './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
str(data_config['input_size'][-1])
])
output_dir = get_outdir(output_base, 'train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
best_metric = None
best_epoch = None
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed:

@ -253,9 +253,9 @@ class ModelEma:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
print("=> loaded state_dict_ema")
print("=> Loaded state_dict_ema")
else:
print("=> failed to find state_dict_ema, starting from loaded model weights)")
print("=> Failed to find state_dict_ema, starting from loaded model weights")
def update(self, model):
# correct a mismatch in state dict keys

Loading…
Cancel
Save