@@ -6,12 +6,13 @@ from datetime import datetime
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False

 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
@@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enable apex synchronized BatchNorm')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,31 +140,28 @@ def main():
     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+    assert args.rank >= 0

     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)

-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
+    torch.manual_seed(args.seed + args.rank)

     output_dir = ''
     if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
+        output_base = args.output if args.output else './output'
         exp_name = '-'.join([
             datetime.now().strftime("%Y%m%d-%H%M%S"),
             args.model,
@@ -191,6 +197,8 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()

     optimizer = create_optimizer(args, model)
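
Note on the sync-BN path added above: apex's convert_syncbn_model swaps every BatchNorm layer for a synchronized variant before the model is wrapped in DDP, so batch statistics are computed across processes during training. For setups without apex, core PyTorch offers a comparable conversion; the snippet below is only an illustrative alternative, not part of this change.

import torch
import torchvision

def convert_syncbn_native(model: torch.nn.Module) -> torch.nn.Module:
    # Replace all BatchNorm*d layers with torch.nn.SyncBatchNorm so batch
    # statistics are computed across all distributed processes (PyTorch >= 1.1).
    return torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

model = torchvision.models.resnet18()
model = convert_syncbn_native(model)
print(type(model.bn1))  # <class 'torch.nn.modules.batchnorm.SyncBatchNorm'>
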
@@ -205,8 +213,20 @@ def main():
         use_amp = False
         print('AMP disabled')

+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
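
ModelEma itself is defined in utils (pulled in by the `from utils import *` at the top) and does not appear in this diff; the real helper also handles resuming EMA weights from a checkpoint and reconciling the 'module.' prefix that DDP adds (the ema_has_module flag toggled above). As a rough sketch of the underlying technique only, assuming nothing about the repo's actual implementation, a weight-EMA helper looks roughly like this:

from copy import deepcopy

import torch


class SimpleModelEma:
    """Keep a shadow copy of a model whose weights are an exponential
    moving average of the live model's weights:
        ema = decay * ema + (1 - decay) * model
    """

    def __init__(self, model, decay=0.9998, device=''):
        self.ema = deepcopy(model).eval()  # detached copy holding the averaged weights
        self.decay = decay
        self.device = device  # e.g. 'cpu' to keep the EMA copy off the GPU
        if device:
            self.ema.to(device=device)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        # blend the current model weights into the shadow copy
        # (assumes model and self.ema have identical state_dict keys)
        with torch.no_grad():
            ema_state = self.ema.state_dict()
            model_state = model.state_dict()
            for k, ema_v in ema_state.items():
                model_v = model_state[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
                else:
                    ema_v.copy_(model_v)  # e.g. num_batches_tracked
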
@@ -273,6 +293,7 @@ def main():
     eval_metric = args.eval_metric
     saver = None
     if output_dir:
+        # only set if process is rank 0
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
     best_metric = None
@@ -284,10 +305,15 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)

-            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
+            eval_metrics = validate(
+                model, loader_eval, validate_loss_fn, args)
+
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics

             if lr_scheduler is not None:
                 lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +324,12 @@ def main():
             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                     epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -316,7 +339,7 @@ def main():
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):

     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +382,8 @@ def train_epoch(
         optimizer.step()

         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1

         batch_time_m.update(time.time() - end)
@@ -394,18 +419,11 @@ def train_epoch(
                         padding=0,
                         normalize=True)

-        if args.local_rank == 0 and (
-                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+        if saver is not None and args.recovery_interval and (
+                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
-            saver.save_recovery({
-                'epoch': save_epoch,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=save_epoch,
-                batch_idx=batch_idx)
+            saver.save_recovery(
+                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +433,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])


-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
                       'Loss {loss.val:.4f} ({loss.avg:.4f})  '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, last_idx,
+                          log_name, batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))
@@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
     return metrics


-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-

 if __name__ == '__main__':
     main()
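
For reference, the reduce_tensor helper removed in the last hunk averaged a metric tensor across distributed processes (all-reduce sum, then divide by the process count). If it is still needed it has to come from elsewhere, presumably the star-imported utils module; a self-contained version of what it computed:

import torch
import torch.distributed as dist

def reduce_tensor(tensor: torch.Tensor, n: int) -> torch.Tensor:
    # average `tensor` over all processes: sum via all_reduce, then divide by n
    # (n is typically args.world_size); requires an initialized process group
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt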