|
|
@ -33,7 +33,7 @@ from timm.bits import initialize_device, setup_model_and_optimizer, DeviceEnv, M
|
|
|
|
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
|
|
|
|
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config,\
|
|
|
|
PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
|
|
|
|
PreprocessCfg, AugCfg, MixupCfg, AugMixDataset
|
|
|
|
from timm.models import create_model, safe_model_name, convert_splitbn_model
|
|
|
|
from timm.models import create_model, safe_model_name, convert_splitbn_model
|
|
|
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
|
|
|
|
from timm.loss import *
|
|
|
|
from timm.optim import optimizer_kwargs
|
|
|
|
from timm.optim import optimizer_kwargs
|
|
|
|
from timm.scheduler import create_scheduler
|
|
|
|
from timm.scheduler import create_scheduler
|
|
|
|
from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model
|
|
|
|
from timm.utils import setup_default_logging, random_seed, get_outdir, unwrap_model
|
|
|
@ -121,8 +121,12 @@ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
|
|
|
|
help='learning rate noise std-dev (default: 1.0)')
|
|
|
|
help='learning rate noise std-dev (default: 1.0)')
|
|
|
|
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
|
|
|
|
parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
|
|
|
|
help='learning rate cycle len multiplier (default: 1.0)')
|
|
|
|
help='learning rate cycle len multiplier (default: 1.0)')
|
|
|
|
|
|
|
|
parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
|
|
|
|
|
|
|
|
help='amount to decay each learning rate cycle (default: 0.5)')
|
|
|
|
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
|
|
|
|
parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
|
|
|
|
help='learning rate cycle limit')
|
|
|
|
help='learning rate cycle limit, cycles enabled if > 1')
|
|
|
|
|
|
|
|
parser.add_argument('--lr-k-decay', type=float, default=1.0,
|
|
|
|
|
|
|
|
help='learning rate k-decay for cosine/poly (default: 1.0)')
|
|
|
|
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
|
|
|
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
|
|
|
help='warmup learning rate (default: 0.0001)')
|
|
|
|
help='warmup learning rate (default: 0.0001)')
|
|
|
|
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
|
|
|
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
|
|
@ -161,8 +165,10 @@ parser.add_argument('--aa', type=str, default=None, metavar='NAME',
|
|
|
|
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
|
|
|
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
|
|
|
|
parser.add_argument('--aug-splits', type=int, default=0,
|
|
|
|
parser.add_argument('--aug-splits', type=int, default=0,
|
|
|
|
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
|
|
|
|
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
|
|
|
|
parser.add_argument('--jsd', action='store_true', default=False,
|
|
|
|
parser.add_argument('--jsd-loss', action='store_true', default=False,
|
|
|
|
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
|
|
|
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
|
|
|
|
|
|
|
|
parser.add_argument('--bce-loss', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Enable BCE loss w/ Mixup/CutMix use.')
|
|
|
|
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
|
|
|
parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
|
|
|
|
help='Random erase prob (default: 0.)')
|
|
|
|
help='Random erase prob (default: 0.)')
|
|
|
|
parser.add_argument('--remode', type=str, default='const',
|
|
|
|
parser.add_argument('--remode', type=str, default='const',
|
|
|
@ -448,14 +454,20 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
|
|
|
|
lr_scheduler.step(train_state.epoch)
|
|
|
|
lr_scheduler.step(train_state.epoch)
|
|
|
|
|
|
|
|
|
|
|
|
# setup loss function
|
|
|
|
# setup loss function
|
|
|
|
if args.jsd:
|
|
|
|
if args.jsd_loss:
|
|
|
|
assert args.aug_splits > 1 # JSD only valid with aug splits set
|
|
|
|
assert args.aug_splits > 1 # JSD only valid with aug splits set
|
|
|
|
train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
|
|
|
|
train_loss_fn = JsdCrossEntropy(num_splits=args.aug_splits, smoothing=args.smoothing)
|
|
|
|
elif mixup_active:
|
|
|
|
elif mixup_active:
|
|
|
|
# smoothing is handled with mixup target transform
|
|
|
|
# smoothing is handled with mixup target transform
|
|
|
|
train_loss_fn = SoftTargetCrossEntropy()
|
|
|
|
if args.bce_loss:
|
|
|
|
|
|
|
|
train_loss_fn = nn.BCEWithLogitsLoss()
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
train_loss_fn = SoftTargetCrossEntropy()
|
|
|
|
elif args.smoothing:
|
|
|
|
elif args.smoothing:
|
|
|
|
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
|
|
|
if args.bce_loss:
|
|
|
|
|
|
|
|
train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
train_loss_fn = nn.CrossEntropyLoss()
|
|
|
|
train_loss_fn = nn.CrossEntropyLoss()
|
|
|
|
eval_loss_fn = nn.CrossEntropyLoss()
|
|
|
|
eval_loss_fn = nn.CrossEntropyLoss()
|
|
|
|