diff --git a/train.py b/train.py
index 89ade4a1..9831ac76 100755
--- a/train.py
+++ b/train.py
@@ -23,6 +23,8 @@
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
+import wandb
+
 import torch
 import torch.nn as nn
 import torchvision.utils
@@ -293,7 +295,8 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
-
+    wandb.init(project='efficientnet_v2', config=args)
+    wandb.run.name = args.model
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -572,14 +575,14 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
-
+            wandb.log(train_metrics)
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
 
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
-
+            wandb.log(eval_metrics)
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
@@ -711,7 +714,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
 
-    return OrderedDict([('loss', losses_m.avg)])
+    return OrderedDict([('train_loss', losses_m.avg)])
 
 
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
 
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
 
     return metrics