From 9e12530433f38c536e9a5fdfe5c9f455638a3e8a Mon Sep 17 00:00:00 2001
From: Jakub Kaczmarzyk
Date: Thu, 26 May 2022 08:57:47 -0400
Subject: [PATCH] use utils namespace instead of function/classnames

This fixes buggy behavior introduced by
https://github.com/rwightman/pytorch-image-models/pull/1266.

Related to https://github.com/rwightman/pytorch-image-models/pull/1273.
---
 train.py | 48 +++++++++++++++++++++++-------------------------
 1 file changed, 23 insertions(+), 25 deletions(-)

diff --git a/train.py b/train.py
index c95ec150..047a8256 100755
--- a/train.py
+++ b/train.py
@@ -31,9 +31,7 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP
 from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
     convert_splitbn_model, model_parameters
-from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\
-    get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\
-    dispatch_clip_grad, reduce_tensor
+from timm import utils
 from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
     LabelSmoothingCrossEntropy
 from timm.optim import create_optimizer_v2, optimizer_kwargs
@@ -346,7 +344,7 @@ def _parse_args():
 
 
 def main():
-    setup_default_logging()
+    utils.setup_default_logging()
     args, args_text = _parse_args()
 
     if args.log_wandb:
@@ -391,10 +389,10 @@ def main():
         _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
                         "Install NVIDA apex or upgrade to PyTorch 1.6")
 
-    random_seed(args.seed, args.rank)
+    utils.random_seed(args.seed, args.rank)
 
     if args.fuser:
-        set_jit_fuser(args.fuser)
+        utils.set_jit_fuser(args.fuser)
 
     model = create_model(
         args.model,
@@ -492,7 +490,7 @@ def main():
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
-        model_ema = ModelEmaV2(
+        model_ema = utils.ModelEmaV2(
             model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
         if args.resume:
             load_checkpoint(model_ema.module, args.resume, use_ema=True)
@@ -640,9 +638,9 @@ def main():
             safe_model_name(args.model),
             str(data_config['input_size'][-1])
         ])
-        output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
+        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
-        saver = CheckpointSaver(
+        saver = utils.CheckpointSaver(
             model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
             checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@@ -661,13 +659,13 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
-                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
+                utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
 
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
-                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
+                    utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                 ema_eval_metrics = validate(
                     model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
@@ -677,7 +675,7 @@ def main():
                 lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
             if output_dir is not None:
-                update_summary(
+                utils.update_summary(
                     epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                     write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
 
@@ -704,9 +702,9 @@ def train_one_epoch(
             mixup_fn.mixup_enabled = False
 
     second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
-    batch_time_m = AverageMeter()
-    data_time_m = AverageMeter()
-    losses_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    data_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()
 
     model.train()
 
@@ -740,7 +738,7 @@ def train_one_epoch(
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
-                dispatch_clip_grad(
+                utils.dispatch_clip_grad(
                     model_parameters(model, exclude_head='agc' in args.clip_mode),
                     value=args.clip_grad, mode=args.clip_mode)
             optimizer.step()
@@ -756,7 +754,7 @@ def train_one_epoch(
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
                 losses_m.update(reduced_loss.item(), input.size(0))
             if args.local_rank == 0:
@@ -801,10 +799,10 @@ def train_one_epoch(
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
-    batch_time_m = AverageMeter()
-    losses_m = AverageMeter()
-    top1_m = AverageMeter()
-    top5_m = AverageMeter()
+    batch_time_m = utils.AverageMeter()
+    losses_m = utils.AverageMeter()
+    top1_m = utils.AverageMeter()
+    top5_m = utils.AverageMeter()
 
     model.eval()
 
@@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
                 target = target[0:target.size(0):reduce_factor]
 
             loss = loss_fn(output, target)
-            acc1, acc5 = accuracy(output, target, topk=(1, 5))
+            acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
 
             if args.distributed:
-                reduced_loss = reduce_tensor(loss.data, args.world_size)
-                acc1 = reduce_tensor(acc1, args.world_size)
-                acc5 = reduce_tensor(acc5, args.world_size)
+                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                acc1 = utils.reduce_tensor(acc1, args.world_size)
+                acc5 = utils.reduce_tensor(acc5, args.world_size)
             else:
                 reduced_loss = loss.data
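
Note (not part of the patch): a minimal sketch of the namespace-import pattern the diff adopts. Referencing helpers through the `timm.utils` module, rather than binding each name with `from timm.utils import ...`, means a later local definition or wildcard import cannot silently shadow a helper such as `accuracy`. The helpers used below (`utils.AverageMeter`, `utils.accuracy`) are the same ones touched in the diff; the toy tensors and the printed summary are illustrative only.

    import torch

    from timm import utils  # single namespace import, as in the patched train.py

    top1_m = utils.AverageMeter()

    logits = torch.randn(8, 10)           # toy batch: 8 samples, 10 classes
    target = torch.randint(0, 10, (8,))   # toy labels

    # Helpers are resolved through the module at call time, so a stray local
    # `accuracy` or `AverageMeter` defined elsewhere cannot shadow them.
    acc1, acc5 = utils.accuracy(logits, target, topk=(1, 5))
    top1_m.update(acc1.item(), logits.size(0))
    print(f'top-1 over {top1_m.count} samples: {top1_m.avg:.2f}')

Because every call site names the module explicitly, the change is also easy to audit: each touched line in the diff simply gains a `utils.` prefix.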