use utils namespace instead of function/classnames

This fixes buggy behavior introduced by
https://github.com/rwightman/pytorch-image-models/pull/1266.

Related to https://github.com/rwightman/pytorch-image-models/pull/1273.
pull/1322/head
Jakub Kaczmarzyk 2 years ago committed by Ross Wightman
parent db64393c0d
commit 9e12530433

@ -31,9 +31,7 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\
get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\
dispatch_clip_grad, reduce_tensor
from timm import utils
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
LabelSmoothingCrossEntropy
from timm.optim import create_optimizer_v2, optimizer_kwargs
@ -346,7 +344,7 @@ def _parse_args():
def main():
setup_default_logging()
utils.setup_default_logging()
args, args_text = _parse_args()
if args.log_wandb:
@ -391,10 +389,10 @@ def main():
_logger.warning("Neither APEX or native Torch AMP is available, using float32. "
"Install NVIDA apex or upgrade to PyTorch 1.6")
random_seed(args.seed, args.rank)
utils.random_seed(args.seed, args.rank)
if args.fuser:
set_jit_fuser(args.fuser)
utils.set_jit_fuser(args.fuser)
model = create_model(
args.model,
@ -492,7 +490,7 @@ def main():
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but before DDP wrapper
model_ema = ModelEmaV2(
model_ema = utils.ModelEmaV2(
model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
if args.resume:
load_checkpoint(model_ema.module, args.resume, use_ema=True)
@ -640,9 +638,9 @@ def main():
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(
saver = utils.CheckpointSaver(
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
checkpoint_dir=output_dir, recovery_dir=output_dir, decreasing=decreasing, max_history=args.checkpoint_hist)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
@ -661,13 +659,13 @@ def main():
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
_logger.info("Distributing BatchNorm running means and vars")
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
utils.distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
if model_ema is not None and not args.model_ema_force_cpu:
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
utils.distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
ema_eval_metrics = validate(
model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
@ -677,7 +675,7 @@ def main():
lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
if output_dir is not None:
update_summary(
utils.update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_metric is None, log_wandb=args.log_wandb and has_wandb)
@ -704,9 +702,9 @@ def train_one_epoch(
mixup_fn.mixup_enabled = False
second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
losses_m = AverageMeter()
batch_time_m = utils.AverageMeter()
data_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
model.train()
@ -740,7 +738,7 @@ def train_one_epoch(
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
dispatch_clip_grad(
utils.dispatch_clip_grad(
model_parameters(model, exclude_head='agc' in args.clip_mode),
value=args.clip_grad, mode=args.clip_mode)
optimizer.step()
@ -756,7 +754,7 @@ def train_one_epoch(
lr = sum(lrl) / len(lrl)
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
@ -801,10 +799,10 @@ def train_one_epoch(
def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
top1_m = AverageMeter()
top5_m = AverageMeter()
batch_time_m = utils.AverageMeter()
losses_m = utils.AverageMeter()
top1_m = utils.AverageMeter()
top5_m = utils.AverageMeter()
model.eval()
@ -831,12 +829,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
target = target[0:target.size(0):reduce_factor]
loss = loss_fn(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
acc1 = reduce_tensor(acc1, args.world_size)
acc5 = reduce_tensor(acc5, args.world_size)
reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
acc1 = utils.reduce_tensor(acc1, args.world_size)
acc5 = utils.reduce_tensor(acc5, args.world_size)
else:
reduced_loss = loss.data

Loading…
Cancel
Save