From 3581affb7769ec3554b2ea3d242c83db3f92a960 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 5 Sep 2021 16:05:31 -0700
Subject: [PATCH] Update train.py with some flags related to scheduler tweaks, fix best checkpoint bug.

---
 timm/bits/checkpoint_manager.py |  2 +-
 train.py                        | 24 ++++++++++++++++++------
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/timm/bits/checkpoint_manager.py b/timm/bits/checkpoint_manager.py
index b2c692cb..a867e229 100644
--- a/timm/bits/checkpoint_manager.py
+++ b/timm/bits/checkpoint_manager.py
@@ -193,7 +193,7 @@ class CheckpointManager:
             best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
             self._duplicate(last_save_path, best_save_path)
 
-        return None if self.best_checkpoint is None else curr_checkpoint
+        return curr_checkpoint if self.best_checkpoint is None else self.best_checkpoint
 
     def save_recovery(self, train_state: TrainState):
         tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
diff --git a/train.py b/train.py
index cad41bca..43a8108a 100755
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@ from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, M
 from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
     PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
 from timm.models import create_model, safe_model_name, convert_splitbn_model
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
+from timm.loss import *
 from timm.optim import optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model
@@ -121,8 +121,12 @@ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                     help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
+                    help='amount to decay each learning rate cycle (default: 0.5)')
 parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
-                    help='learning rate cycle limit')
+                    help='learning rate cycle limit, cycles enabled if > 1')
+parser.add_argument('--lr-k-decay', type=float, default=1.0,
+                    help='learning rate k-decay for cosine/poly (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
@@ -161,8 +165,10 @@ parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
 parser.add_argument('--aug-splits', type=int, default=0,
                     help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
-parser.add_argument('--jsd', action='store_true', default=False,
+parser.add_argument('--jsd-loss', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
+parser.add_argument('--bce-loss', action='store_true', default=False,
+                    help='Enable BCE loss w/ Mixup/CutMix use.')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
@@ -448,14 +454,20 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         lr_scheduler.step(train_state.epoch)
 
     # setup loss function
-    if args.jsd:
+    if args.jsd_loss:
         assert args.aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
     elif mixup_active:
         # smoothing is handled with mixup target transform
-        train_loss_fn = SoftTargetCrossEntropy()
+        if args.bce_loss:
+            train_loss_fn = nn.BCEWithLogitsLoss()
+        else:
+            train_loss_fn = SoftTargetCrossEntropy()
     elif args.smoothing:
-        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
+        if args.bce_loss:
+            train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
+        else:
+            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else:
         train_loss_fn = nn.CrossEntropyLoss()
     eval_loss_fn = nn.CrossEntropyLoss()
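
For quick reference, the loss selection installed by the train.py hunk above can be written as a small standalone helper. This is only an illustrative sketch of the branching in the diff, not part of the patch; the explicit import list is an assumption, on the basis that JsdCrossEntropy, SoftTargetCrossEntropy, LabelSmoothingCrossEntropy, and DenseBinaryCrossEntropy are all exported by timm.loss (the wildcard import above pulls them in).

import torch.nn as nn
from timm.loss import (
    JsdCrossEntropy, SoftTargetCrossEntropy, LabelSmoothingCrossEntropy, DenseBinaryCrossEntropy)


def resolve_train_loss(args, mixup_active: bool) -> nn.Module:
    # Mirrors the branching added to setup_train_task() above.
    if args.jsd_loss:
        # JSD + CE is only valid when the batch is split via --aug-splits.
        assert args.aug_splits > 1
        return JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
    if mixup_active:
        # Mixup/CutMix already produce soft (smoothed) targets.
        return nn.BCEWithLogitsLoss() if args.bce_loss else SoftTargetCrossEntropy()
    if args.smoothing:
        if args.bce_loss:
            return DenseBinaryCrossEntropy(smoothing=args.smoothing)
        return LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    return nn.CrossEntropyLoss()

In short, --bce-loss swaps the soft-target cross-entropy for nn.BCEWithLogitsLoss when Mixup/CutMix is active, and label smoothing without mixup goes through DenseBinaryCrossEntropy; without the flag, the previous behaviour is unchanged. The evaluation loss stays nn.CrossEntropyLoss in all cases.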