@@ -452,6 +452,7 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
         model_ema=args.model_ema,
         model_ema_decay=args.model_ema_decay,
         resume_path=args.resume,
+        resume_opt=not args.no_resume_opt,
         use_syncbn=args.sync_bn,
     )
@@ -683,7 +684,7 @@ def after_train_step(
     Returns:
     """
-    end_step = step_idx == step_end_idx
+    last_step = step_idx == step_end_idx
     with torch.no_grad():
         output, target, loss = tensors
@@ -696,15 +697,15 @@ def after_train_step(
         state = replace(state, step_count_global=state.step_count_global + 1)
         cfg = state.train_cfg
-        if services.monitor is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
+        if services.monitor is not None and last_step or (step_idx + 1) % cfg.log_interval == 0:
             global_batch_size = dev_env.world_size * output.shape[0]
             loss_avg = loss_meter.compute()
             if services.monitor is not None:
                 lr_avg = state.updater.get_average_lr()
                 services.monitor.log_step(
                     'Train',
-                    step=step_idx,
-                    step_end=step_end_idx,
+                    step_idx=step_idx,
+                    step_end_idx=step_end_idx,
                     epoch=state.epoch,
                     loss=loss_avg.item(),
                     rate=tracker.get_avg_iter_rate(global_batch_size),
@@ -712,8 +713,8 @@
                 )
         if services.checkpoint is not None and cfg.recovery_interval and (
-                end_step or (step_idx + 1) % cfg.recovery_interval == 0):
-            services.checkpoint.save_recovery(state.epoch, batch_idx=step_idx)
+                last_step or (step_idx + 1) % cfg.recovery_interval == 0):
+            services.checkpoint.save_recovery(state)
         if state.lr_scheduler is not None:
             # FIXME perform scheduler update here or via updater after_step call?
@@ -770,8 +771,8 @@ def evaluate(
                 loss_avg = losses_m.compute()
                 logger.log_step(
                     'Eval',
-                    step=step_idx,
-                    step_end=end_idx,
+                    step_idx=step_idx,
+                    step_end_idx=end_idx,
                     loss=loss_avg.item(),
                     top1=top1.item(),
                     top5=top5.item(),