Fix #387 so that checkpoint saver works with max history of 1. Add checkpoint-hist arg to train.py.
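For reference, a hypothetical invocation of the new option (the dataset path is a placeholder, not part of this commit); with a history of 1, only the single best checkpoint is retained, alongside the separately written 'last' checkpoint:

    python train.py /path/to/dataset --checkpoint-hist 1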

Ref: pull/401/head
Author: Ross Wightman
Parent: 99b82ae5ab
Commit: 4203efa36d

@@ -66,7 +66,7 @@ class CheckpointSaver:
         last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
         self._save(tmp_save_path, epoch, metric)
         if os.path.exists(last_save_path):
             os.unlink(last_save_path)  # required for Windows support.
         os.rename(tmp_save_path, last_save_path)
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
         if (len(self.checkpoint_files) < self.max_history
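A note on the save path shown above (commentary, not part of the diff): the checkpoint is written to a temporary file and then renamed onto the 'last' checkpoint path. The explicit os.unlink is needed because os.rename raises FileExistsError on Windows when the destination already exists, whereas on POSIX systems the destination is silently replaced.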
@@ -118,7 +118,7 @@ class CheckpointSaver:
     def _cleanup_checkpoints(self, trim=0):
         trim = min(len(self.checkpoint_files), trim)
         delete_index = self.max_history - trim
-        if delete_index <= 0 or len(self.checkpoint_files) <= delete_index:
+        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
             return
         to_delete = self.checkpoint_files[delete_index:]
         for d in to_delete:
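Why the relaxed guard matters for max_history=1 (a sketch with illustrative values; it assumes the saver calls _cleanup_checkpoints(trim=1) to free a slot before registering a replacement checkpoint, which is not shown in this diff):

    # Illustrative values only.
    max_history, trim = 1, 1
    checkpoint_files = [('checkpoint-7.pth.tar', 76.1)]  # hypothetical (path, metric) entry
    delete_index = max_history - trim                    # 0
    # Old guard: delete_index <= 0 is True, so the method returned early and the
    # stale checkpoint was never pruned.
    # New guard: delete_index < 0 is False and len(checkpoint_files) <= 0 is False,
    # so checkpoint_files[0:] (the stale entry) is deleted as intended.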
@@ -147,7 +147,4 @@ class CheckpointSaver:
         recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
         files = glob.glob(recovery_path + '*' + self.extension)
         files = sorted(files)
-        if len(files):
-            return files[0]
-        else:
-            return ''
+        return files[0] if len(files) else ''

@@ -236,6 +236,8 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
+parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
+                    help='number of checkpoints to keep (default: 10)')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--save-images', action='store_true', default=False,
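A quick note on how the new flag is consumed (a standalone sketch, not code from this commit): argparse exposes dashed option names as underscored attributes, so --checkpoint-hist becomes args.checkpoint_hist, which the train.py hunk below passes to CheckpointSaver as max_history.

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
                        help='number of checkpoints to keep (default: 10)')
    args = parser.parse_args(['--checkpoint-hist', '1'])
    print(args.checkpoint_hist)  # -> 1 (dashes in the flag become underscores)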
@@ -547,7 +549,7 @@ def main():
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
-            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing)
+            checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
