Update training script and loader factory to allow use of new LR scheduler options, repeated augmentation, and BCE loss

pull/821/head
Ross Wightman 3 years ago
parent f262137ff2
commit fb94350896

@@ -11,7 +11,7 @@ import numpy as np
 from .transforms_factory import create_transform
 from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from .distributed_sampler import OrderedDistributedSampler
+from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
 from .random_erasing import RandomErasing
 from .mixup import FastCollateMixup
@@ -142,6 +142,7 @@ def create_loader(
         vflip=0.,
         color_jitter=0.4,
         auto_augment=None,
+        num_aug_repeats=0,
         num_aug_splits=0,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
@@ -186,11 +187,16 @@ def create_loader(
     sampler = None
     if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
         if is_training:
-            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+            if num_aug_repeats:
+                sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
+            else:
+                sampler = torch.utils.data.distributed.DistributedSampler(dataset)
         else:
             # This will add extra duplicate entries to result in equal num
             # of samples per-process, will slightly alter validation results
             sampler = OrderedDistributedSampler(dataset)
+    else:
+        assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"

     if collate_fn is None:
         collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
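
For reference, RepeatAugSampler only takes effect on the distributed training path (the assert above guards everything else). A minimal usage sketch of the updated factory, assuming the remaining create_dataset/create_loader arguments keep their usual timm meanings, with purely illustrative values:

```python
from timm.data import create_dataset, create_loader

# Illustrative values only; 'path/to/imagenet' is a placeholder.
dataset_train = create_dataset('', root='path/to/imagenet', split='train', is_training=True)
loader_train = create_loader(
    dataset_train,
    input_size=(3, 224, 224),
    batch_size=128,
    is_training=True,
    num_aug_repeats=3,    # new: 3 differently-augmented repeats of each sample
    interpolation='bilinear',
    distributed=True,     # RepeatAugSampler is only selected on the distributed training path
)
```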

@@ -1,3 +1,4 @@
+from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
+from .binary_cross_entropy import DenseBinaryCrossEntropy
 from .cross_entropy import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from .jsd import JsdCrossEntropy
-from .asymmetric_loss import AsymmetricLossMultiLabel, AsymmetricLossSingleLabel
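
The body of DenseBinaryCrossEntropy is not part of this diff; judging by how it is used below (logits plus integer targets, with a smoothing kwarg), it applies BCE-with-logits to densified, optionally smoothed one-hot targets. A minimal stand-alone sketch of that idea, not the timm implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DenseBCELossSketch(nn.Module):
    """Sketch: BCE-with-logits over densified (one-hot, optionally smoothed) targets."""
    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing

    def forward(self, logits, target):
        num_classes = logits.shape[-1]
        if target.ndim == 1:  # integer class labels -> dense smoothed one-hot
            off = self.smoothing / num_classes
            on = 1.0 - self.smoothing + off
            dense = torch.full_like(logits, off)
            dense.scatter_(1, target.unsqueeze(1), on)
            target = dense
        return F.binary_cross_entropy_with_logits(logits, target)
```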

@@ -32,7 +32,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
     convert_splitbn_model, model_parameters
 from timm.utils import *
-from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
+from timm.loss import *
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -140,8 +140,12 @@ parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                     help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
                     help='learning rate cycle len multiplier (default: 1.0)')
+parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
+                    help='amount to decay each learning rate cycle (default: 0.5)')
 parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
-                    help='learning rate cycle limit')
+                    help='learning rate cycle limit, cycles enabled if > 1')
+parser.add_argument('--lr-k-decay', type=float, default=1.0,
+                    help='learning rate k-decay for cosine/poly (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
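
As a rough illustration of the two new knobs, and assuming the k-decay cosine form from the k-decay paper rather than timm's exact scheduler code: --lr-k-decay reshapes the cosine curve within a cycle, while --lr-cycle-decay scales down the peak LR on each restart cycle.

```python
import math

def cosine_lr_sketch(step, steps_per_cycle, base_lr, min_lr=1e-5, k_decay=1.0, cycle_decay=0.5):
    """Sketch of a k-decay cosine schedule with decayed restart cycles (illustrative only)."""
    cycle = step // steps_per_cycle
    t = step % steps_per_cycle
    peak = base_lr * (cycle_decay ** cycle)                 # --lr-cycle-decay: shrink each cycle's peak
    frac = (t ** k_decay) / (steps_per_cycle ** k_decay)    # --lr-k-decay: k=1 is plain cosine
    return min_lr + 0.5 * (peak - min_lr) * (1 + math.cos(math.pi * frac))

# e.g. k_decay > 1 holds the LR high longer before the final drop
print([round(cosine_lr_sketch(s, 100, 0.1, k_decay=2.0), 4) for s in (0, 50, 90, 99)])
```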
@@ -178,10 +182,14 @@ parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                     help='Color jitter factor (default: 0.4)')
 parser.add_argument('--aa', type=str, default=None, metavar='NAME',
                     help='Use AutoAugment policy. "v0" or "original". (default: None)'),
+parser.add_argument('--aug-repeats', type=int, default=0,
+                    help='Number of augmentation repetitions (distributed training only) (default: 0)')
 parser.add_argument('--aug-splits', type=int, default=0,
                     help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
-parser.add_argument('--jsd', action='store_true', default=False,
+parser.add_argument('--jsd-loss', action='store_true', default=False,
                     help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
+parser.add_argument('--bce-loss', action='store_true', default=False,
+                    help='Enable BCE loss w/ Mixup/CutMix use.')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
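
Taken together, the new flags can be exercised as below; parser is the ArgumentParser built earlier in train.py, and the positional dataset path plus all argument values are illustrative assumptions.

```python
# Illustrative only; 'path/to/imagenet' stands in for train.py's positional dataset path.
# --jsd-loss would additionally require --aug-splits >= 2; --aug-repeats needs distributed training.
args = parser.parse_args([
    'path/to/imagenet',
    '--aug-repeats', '3',
    '--bce-loss',
    '--lr-cycle-decay', '0.5',
    '--lr-k-decay', '1.5',
])
assert args.aug_repeats == 3 and args.bce_loss
```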
@@ -516,6 +524,7 @@ def main():
         vflip=args.vflip,
         color_jitter=args.color_jitter,
         auto_augment=args.aa,
+        num_aug_repeats=args.aug_repeats,
         num_aug_splits=num_aug_splits,
         interpolation=train_interpolation,
         mean=data_config['mean'],
@@ -543,16 +552,23 @@ def main():
     )

     # setup loss function
-    if args.jsd:
+    if args.jsd_loss:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
-        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
+        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
     elif mixup_active:
-        # smoothing is handled with mixup target transform
-        train_loss_fn = SoftTargetCrossEntropy().cuda()
+        # smoothing is handled with mixup target transform which outputs sparse, soft targets
+        if args.bce_loss:
+            train_loss_fn = nn.BCEWithLogitsLoss()
+        else:
+            train_loss_fn = SoftTargetCrossEntropy()
     elif args.smoothing:
-        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
+        if args.bce_loss:
+            train_loss_fn = DenseBinaryCrossEntropy(smoothing=args.smoothing)
+        else:
+            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
     else:
-        train_loss_fn = nn.CrossEntropyLoss().cuda()
+        train_loss_fn = nn.CrossEntropyLoss()
+    train_loss_fn = train_loss_fn.cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()

     # setup checkpoint saver and eval metric tracking
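
Why plain nn.BCEWithLogitsLoss works on the mixup branch: the Mixup/CutMix target transform emits per-class soft target vectors (the "sparse, soft targets" in the comment), and BCE-with-logits scores every class independently, so it can consume those targets directly. A small self-contained check with made-up shapes:

```python
import torch
import torch.nn as nn

logits = torch.randn(4, 1000)        # model outputs for a batch of 4
soft_targets = torch.zeros(4, 1000)
soft_targets[:, 0] = 0.7             # e.g. mixup weight 0.7 on class 0
soft_targets[:, 1] = 0.3             # and 0.3 on class 1
loss = nn.BCEWithLogitsLoss()(logits, soft_targets)
print(loss.item())
```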
@@ -692,7 +708,7 @@ def train_one_epoch(
         if args.local_rank == 0:
             _logger.info(
                 'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
-                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
+                'Loss: {loss.val:#.4g} ({loss.avg:#.3g}) '
                 'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
                 '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                 'LR: {lr:.3e} '
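
For reference, the new '#.4g' spec keeps four significant digits whatever the loss magnitude, where the old fixed-width float spec padded small values; a quick comparison:

```python
loss = 0.0123456
print('{:>9.6f}'.format(loss))  # ' 0.012346'  (old format: fixed width, 6 decimals)
print('{:#.4g}'.format(loss))   # '0.01235'    (new format: 4 significant digits)
```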
