Add wandb support

pull/550/head
Aman Arora 4 years ago
parent 779107b693
commit 8e6fb861e4

@@ -23,6 +23,8 @@ from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
+import wandb
 import torch
 import torch.nn as nn
 import torchvision.utils
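The new top-level import makes wandb a hard dependency of train.py. A minimal sketch of an optional-import guard that keeps the script usable when wandb is not installed (the has_wandb flag is an illustrative assumption, not part of this commit):

try:
    import wandb
    has_wandb = True   # wandb available: logging calls may run
except ImportError:
    wandb = None
    has_wandb = False  # wandb missing: callers should skip wandb.* calls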
@@ -293,7 +295,8 @@ def _parse_args():
 def main():
     setup_default_logging()
     args, args_text = _parse_args()
+    wandb.init(project='efficientnet_v2', config=args)
+    wandb.run.name = args.model
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
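The run is created right after argument parsing, before any distributed setup; wandb.init accepts the parsed argparse.Namespace directly as config, and wandb.run.name renames the run after the model. A sketch of the same pattern gated to the primary process of a distributed launch (the RANK check and the init_wandb helper are assumptions for illustration, not what the commit does):

import argparse
import os

import wandb

def init_wandb(args: argparse.Namespace) -> None:
    # In a multi-process torch.distributed launch, only rank 0 should
    # create a run; other ranks would otherwise open duplicate runs.
    if int(os.environ.get('RANK', '0')) == 0:
        wandb.init(project='efficientnet_v2', config=args)
        wandb.run.name = args.model  # name the run after the model being trained

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', default='resnet50')
    init_wandb(parser.parse_args())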
@@ -572,14 +575,14 @@ def main():
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+            wandb.log(train_metrics)
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     _logger.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast)
+            wandb.log(eval_metrics)
             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
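Each epoch now issues two wandb.log calls, one with the training metrics dict and one with the validation metrics dict, so every epoch produces two wandb steps. A sketch of an alternative that merges both dicts into a single call and pins the step to the epoch index (the explicit step argument and the train/eval key prefixes are assumptions, not part of this commit):

import wandb

def log_epoch(epoch, train_metrics, eval_metrics):
    # One wandb step per epoch: merge the two dicts and prefix the keys
    # so the train and eval curves stay on separate charts.
    payload = {f'train/{k}': v for k, v in train_metrics.items()}
    payload.update({f'eval/{k}': v for k, v in eval_metrics.items()})
    wandb.log(payload, step=epoch)

# Example with the keys this commit produces:
# log_epoch(0, {'train_loss': 0.91}, {'val_loss': 1.02, 'top1': 71.3, 'top5': 90.1})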
@@ -711,7 +714,7 @@ def train_one_epoch(
     if hasattr(optimizer, 'sync_lookahead'):
         optimizer.sync_lookahead()
-    return OrderedDict([('loss', losses_m.avg)])
+    return OrderedDict([('train_loss', losses_m.avg)])
 def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
@@ -773,7 +776,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix=''):
                         log_name, batch_idx, last_idx, batch_time=batch_time_m,
                         loss=losses_m, top1=top1_m, top5=top5_m))
-    metrics = OrderedDict([('loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
+    metrics = OrderedDict([('val_loss', losses_m.avg), ('top1', top1_m.avg), ('top5', top5_m.avg)])
     return metrics
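The last two hunks rename the returned 'loss' keys to 'train_loss' and 'val_loss'. Since both dicts are logged to the same run, a shared 'loss' key would send training and validation values to a single wandb chart, while distinct keys give each metric its own series. A minimal offline check of the renamed payloads (the values are placeholders):

import wandb

run = wandb.init(project='efficientnet_v2', mode='offline')  # offline run, no account needed

wandb.log({'train_loss': 0.91})                             # as returned by train_one_epoch
wandb.log({'val_loss': 1.02, 'top1': 71.3, 'top5': 90.1})   # as returned by validate

run.finish()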
