diff --git a/timm/utils/checkpoint_saver.py b/timm/utils/checkpoint_saver.py index 51896e78..6aad74ee 100644 --- a/timm/utils/checkpoint_saver.py +++ b/timm/utils/checkpoint_saver.py @@ -66,7 +66,7 @@ class CheckpointSaver: last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension) self._save(tmp_save_path, epoch, metric) if os.path.exists(last_save_path): - os.unlink(last_save_path) # required for Windows support. + os.unlink(last_save_path) # required for Windows support. os.rename(tmp_save_path, last_save_path) worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None if (len(self.checkpoint_files) < self.max_history @@ -118,7 +118,7 @@ class CheckpointSaver: def _cleanup_checkpoints(self, trim=0): trim = min(len(self.checkpoint_files), trim) delete_index = self.max_history - trim - if delete_index <= 0 or len(self.checkpoint_files) <= delete_index: + if delete_index < 0 or len(self.checkpoint_files) <= delete_index: return to_delete = self.checkpoint_files[delete_index:] for d in to_delete: @@ -147,7 +147,4 @@ class CheckpointSaver: recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix) files = glob.glob(recovery_path + '*' + self.extension) files = sorted(files) - if len(files): - return files[0] - else: - return '' + return files[0] if len(files) else '' diff --git a/train.py b/train.py index aa8e6553..f0fcd2af 100755 --- a/train.py +++ b/train.py @@ -236,6 +236,8 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') +parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', + help='number of checkpoints to keep (default: 10)') parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') parser.add_argument('--save-images', action='store_true', default=False, @@ -547,7 +549,7 @@ def main(): decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler, - checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing) + checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist) with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text)