diff --git a/timm/data/loader.py b/timm/data/loader.py
index 76144669..99cf132f 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -11,7 +11,7 @@
 import numpy as np
 
 from .transforms_factory import create_transform
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .distributed_sampler import OrderedDistributedSampler
+from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
 from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
@@ -142,6 +142,7 @@ def create_loader(
         vflip=0.,
         color_jitter=0.4,
         auto_augment=None,
+        num_aug_repeats=0,
         num_aug_splits=0,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
@@ -186,11 +187,16 @@
     sampler = None
     if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
-            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+            if num_aug_repeats:
+                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
+            else:
+                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         else:
             # This will add extra duplicate entries to result in equal num
             # of samples per-process, will slightly alter validation results
             sampler = OrderedDistributedSampler(dataset)
+    else:
+        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
 
     if collate_fn is None:
         collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
diff --git a/timm/loss/__init__.py b/timm/loss/__init__.py
index 28a686ce..a74bcb88 100644
--- a/timm/loss/__init__.py
+++ b/timm/loss/__init__.py
@@ -1,3 +1,4 @@
+from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
+from .binary_cross_entropy import DenseBinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
-from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
\ No newline at end of file
diff --git a/train.py b/train.py
index f1c1581e..07c5b1a8 100755
--- a/train.py
+++ b/train.py
@@ -32,7 +32,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
     convert_splitbn_model, model_parameters
 from timm.utils import *
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
+from timm.loss import *
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -140,8 +140,12 @@ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                     help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
+                    help='amount to decay each learning rate cycle (default: 0.5)')
 parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
-                    help='learning rate cycle limit')
+                    help='learning rate cycle limit, cycles enabled if > 1')
+parser.add_argument('--lr-k-decay', type=float, default=1.0,
+                    help='learning rate k-decay for cosine/poly (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
@@ -178,10 +182,14 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
+parser.add_argument('--aug-repeats', type=int, default=0,
+                    help='Number of augmentation repetitions (distributed training only) (default: 0)')
 parser.add_argument('--aug-splits', type=int, default=0,
                     help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
-parser.add_argument('--jsd', action='store_true', default=False,
+parser.add_argument('--jsd-loss', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
+parser.add_argument('--bce-loss', action='store_true', default=False,
+                    help='Enable BCE loss w/ Mixup/CutMix use.')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
@@ -516,6 +524,7 @@ def main():
         vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
+        num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
         mean=data_config['mean'],
@@ -543,16 +552,23 @@
     )
 
     # setup loss function
-    if args.jsd:
+    if args.jsd_loss:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
-        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
+        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
     elif mixup_active:
-        # smoothing is handled with mixup target transform
-        train_loss_fn = SoftTargetCrossEntropy().cuda()
+        # smoothing is handled with mixup target transform which outputs sparse, soft targets
+        if args.bce_loss:
+            train_loss_fn = nn.BCEWithLogitsLoss()
+        else:
+            train_loss_fn = SoftTargetCrossEntropy()
     elif args.smoothing:
-        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
+        if args.bce_loss:
+            train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
+        else:
+            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else:
-        train_loss_fn = nn.CrossEntropyLoss().cuda()
+        train_loss_fn = nn.CrossEntropyLoss()
+    train_loss_fn = train_loss_fn.cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()
 
     # setup checkpoint saver and eval metric tracking
@@ -692,7 +708,7 @@ def train_one_epoch(
             if args.local_rank == 0:
                 _logger.info(
                     'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                    'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
+                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
                     'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                     '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                     'LR: {lr:.3e} '
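
For reference, a usage sketch of the repeated-augmentation plumbing added above. It is not part of the patch; the dataset root, input size, and batch size are illustrative placeholders, while the keyword names follow the create_loader and argparse changes in this diff (num_aug_repeats in create_loader, --aug-repeats on the train.py command line).

# Sketch (assumptions noted above): RepeatAugSampler is only selected in the
# distributed branch of create_loader, and only when num_aug_repeats > 0;
# otherwise the standard torch DistributedSampler is used.
from timm.data import create_dataset, create_loader

dataset_train = create_dataset('', root='/data/imagenet', split='train', is_training=True)
loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    distributed=True,       # non-distributed use asserts num_aug_repeats == 0
    num_aug_repeats=3,      # new argument in this patch; 0 disables repeated augmentation
)

# Roughly equivalent train.py invocation (illustrative):
#   ./distributed_train.sh 8 /data/imagenet --model resnet50 --aug-repeats 3 --bce-loss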