|
|
|
@ -31,9 +31,7 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP
|
|
|
|
|
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
|
|
|
|
|
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
|
|
|
|
|
convert_splitbn_model, model_parameters
|
|
|
|
|
from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\
|
|
|
|
|
get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\
|
|
|
|
|
dispatch_clip_grad, reduce_tensor
|
|
|
|
|
from timm import utils
|
|
|
|
|
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
|
|
|
|
|
LabelSmoothingCrossEntropy
|
|
|
|
|
from timm.optim import create_optimizer_v2, optimizer_kwargs
|
|
|
|
@ -346,7 +344,7 @@ def _parse_args():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
setup_default_logging()
|
|
|
|
|
utils.setup_default_logging()
|
|
|
|
|
args, args_text = _parse_args()
|
|
|
|
|
|
|
|
|
|
if args.log_wandb:
|
|
|
|
@ -391,10 +389,10 @@ def main():
|
|
|
|
|
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
|
|
|
|
|
"Install NVIDA apex or upgrade to PyTorch 1.6")
|
|
|
|
|
|
|
|
|
|
random_seed(args.seed, args.rank)
|
|
|
|
|
utils.random_seed(args.seed, args.rank)
|
|
|
|
|
|
|
|
|
|
if args.fuser:
|
|
|
|
|
set_jit_fuser(args.fuser)
|
|
|
|
|
utils.set_jit_fuser(args.fuser)
|
|
|
|
|
|
|
|
|
|
model = create_model(
|
|
|
|
|
args.model,
|
|
|
|
@ -492,7 +490,7 @@ def main():
|
|
|
|
|
model_ema = None
|
|
|
|
|
if args.model_ema:
|
|
|
|
|
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
|
|
|
|
|
model_ema = ModelEmaV2(
|
|
|
|
|
model_ema = utils.ModelEmaV2(
|
|
|
|
|
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
|
|
|
|
|
if args.resume:
|
|
|
|
|
load_checkpoint(model_ema.module, args.resume, use_ema=True)
|
|
|
|
@ -640,9 +638,9 @@ def main():
|
|
|
|
|
safe_model_name(args.model),
|
|
|
|
|
str(data_config['input_size'][-1])
|
|
|
|
|
])
|
|
|
|
|
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
|
|
|
|
|
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
|
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
|
saver = CheckpointSaver(
|
|
|
|
|
saver = utils.CheckpointSaver(
|
|
|
|
|
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
|
|
|
|
|
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
|
|
|
|
|
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
|
|
|
@ -661,13 +659,13 @@ def main():
|
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
_logger.info("Distributing BatchNorm running means and vars")
|
|
|
|
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
|
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
|
|
|
|
|
|
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
|
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
ema_eval_metrics = validate(
|
|
|
|
|
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
|
|
|
|
|
eval_metrics = ema_eval_metrics
|
|
|
|
@ -677,7 +675,7 @@ def main():
|
|
|
|
|
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
|
|
|
|
|
|
|
|
|
|
if output_dir is not None:
|
|
|
|
|
update_summary(
|
|
|
|
|
utils.update_summary(
|
|
|
|
|
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
|
|
|
|
|
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
|
|
|
|
|
|
|
|
|
@ -704,9 +702,9 @@ def train_one_epoch(
|
|
|
|
|
mixup_fn.mixup_enabled = False
|
|
|
|
|
|
|
|
|
|
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
|
|
|
|
|
batch_time_m = AverageMeter()
|
|
|
|
|
data_time_m = AverageMeter()
|
|
|
|
|
losses_m = AverageMeter()
|
|
|
|
|
batch_time_m = utils.AverageMeter()
|
|
|
|
|
data_time_m = utils.AverageMeter()
|
|
|
|
|
losses_m = utils.AverageMeter()
|
|
|
|
|
|
|
|
|
|
model.train()
|
|
|
|
|
|
|
|
|
@ -740,7 +738,7 @@ def train_one_epoch(
|
|
|
|
|
else:
|
|
|
|
|
loss.backward(create_graph=second_order)
|
|
|
|
|
if args.clip_grad is not None:
|
|
|
|
|
dispatch_clip_grad(
|
|
|
|
|
utils.dispatch_clip_grad(
|
|
|
|
|
model_parameters(model, exclude_head='agc' in args.clip_mode),
|
|
|
|
|
value=args.clip_grad, mode=args.clip_mode)
|
|
|
|
|
optimizer.step()
|
|
|
|
@ -756,7 +754,7 @@ def train_one_epoch(
|
|
|
|
|
lr = sum(lrl) / len(lrl)
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
|
reduced_loss = reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
losses_m.update(reduced_loss.item(), input.size(0))
|
|
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
@ -801,10 +799,10 @@ def train_one_epoch(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
|
|
|
|
|
batch_time_m = AverageMeter()
|
|
|
|
|
losses_m = AverageMeter()
|
|
|
|
|
top1_m = AverageMeter()
|
|
|
|
|
top5_m = AverageMeter()
|
|
|
|
|
batch_time_m = utils.AverageMeter()
|
|
|
|
|
losses_m = utils.AverageMeter()
|
|
|
|
|
top1_m = utils.AverageMeter()
|
|
|
|
|
top5_m = utils.AverageMeter()
|
|
|
|
|
|
|
|
|
|
model.eval()
|
|
|
|
|
|
|
|
|
@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
|
|
|
|
|
target = target[0:target.size(0):reduce_factor]
|
|
|
|
|
|
|
|
|
|
loss = loss_fn(output, target)
|
|
|
|
|
acc1, acc5 = accuracy(output, target, topk=(1, 5))
|
|
|
|
|
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
|
reduced_loss = reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
acc1 = reduce_tensor(acc1, args.world_size)
|
|
|
|
|
acc5 = reduce_tensor(acc5, args.world_size)
|
|
|
|
|
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
|
|
|
|
|
acc1 = utils.reduce_tensor(acc1, args.world_size)
|
|
|
|
|
acc5 = utils.reduce_tensor(acc5, args.world_size)
|
|
|
|
|
else:
|
|
|
|
|
reduced_loss = loss.data
|
|
|
|
|
|
|
|
|
|