@@ -1,6 +1,7 @@
 import argparse
 import time
+import logging
 from datetime import datetime
 
 try:
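
Note: the lines added below log through the root logger via the module-level logging.info / logging.warning / logging.error helpers; the script configures that logger in setup_default_logging() at the top of main() (next hunk). As a rough sketch of what such a default console setup typically looks like (the helper's actual body is not part of this diff, so the format string here is illustrative):

    import logging

    # Illustrative root-logger setup; stands in for setup_default_logging().
    logging.basicConfig(
        level=logging.INFO,
        format='%(levelname)s: %(message)s')
    logging.info('logging configured')  # emits 'INFO: logging configured'
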
@@ -127,14 +128,14 @@ parser.add_argument("--local_rank", default=0, type=int)
 def main():
     setup_default_logging()
     args = parser.parse_args()
 
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
     if args.distributed and args.num_gpu > 1:
-        print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
+        logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
         args.num_gpu = 1
 
     args.device = 'cuda:0'
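
Note: distributed mode is inferred solely from the WORLD_SIZE environment variable, which launchers such as python -m torch.distributed.launch export for every worker process (the launcher also passes --local_rank, matching the parser argument above). A minimal standalone sketch of the same detection:

    import os

    # WORLD_SIZE is set by the distributed launcher; absent it, run single-process.
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    distributed = world_size > 1
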
@@ -144,17 +145,16 @@ def main():
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(
-            backend='nccl', init_method='env://')
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
     assert args.rank >= 0
 
     if args.distributed:
-        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (args.rank, args.world_size))
+        logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
+                     % (args.rank, args.world_size))
     else:
-        print('Training with a single process on %d GPUs.' % args.num_gpu)
+        logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
 
     torch.manual_seed(args.seed + args.rank)
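
Note: init_method='env://' makes torch.distributed read the rendezvous parameters (MASTER_ADDR, MASTER_PORT, RANK, WORLD_SIZE) from the environment rather than from a TCP or file URL. A sketch of driving that by hand for a single-process group, with illustrative values a launcher would normally export:

    import os
    import torch

    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    os.environ.setdefault('RANK', '0')
    os.environ.setdefault('WORLD_SIZE', '1')

    # 'nccl' requires a CUDA GPU per process; 'gloo' also works on CPU.
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    assert torch.distributed.get_rank() >= 0
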
@@ -169,8 +169,8 @@ def main():
         bn_eps=args.bn_eps,
         checkpoint_path=args.initial_checkpoint)
-    print('Model %s created, param count: %d' %
-          (args.model, sum([m.numel() for m in model.parameters()])))
+    logging.info('Model %s created, param count: %d' %
+                 (args.model, sum([m.numel() for m in model.parameters()])))
 
     data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
@@ -182,8 +182,8 @@ def main():
     if args.num_gpu > 1:
         if args.amp:
-            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
-                  'Use distributed mode for multi-GPU AMP.')
+            logging.warning(
+                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
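
Note: the warning is justified because apex AMP patches a single model/optimizer pair within one process, while nn.DataParallel replicates the module across GPUs inside that same process, so the two compose poorly. The alternative the message recommends is one process per GPU with DistributedDataParallel; a rough sketch of that combination, assuming apex is installed and the process group is already initialized:

    import torch
    import torch.nn as nn
    from apex import amp
    from apex.parallel import DistributedDataParallel as DDP

    model = nn.Linear(512, 10).cuda()  # placeholder model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    # Initialize AMP first, then wrap for gradient all-reduce across processes.
    model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
    model = DDP(model)
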
@@ -198,10 +198,10 @@ def main():
     if has_apex and args.amp:
         model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
         use_amp = True
-        print('AMP enabled')
+        logging.info('AMP enabled')
     else:
         use_amp = False
-        print('AMP disabled')
+        logging.info('AMP disabled')
 
     model_ema = None
     if args.model_ema:
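
Note: opt_level='O1' is apex's patch-based mixed precision, where whitelisted ops run in fp16 while parameters stay fp32. Once amp.initialize has wrapped the model and optimizer, the training loop must route the backward pass through amp.scale_loss so loss scaling is applied and unscaled correctly; a sketch of that loop body (loss and optimizer as in the surrounding training step):

    from apex import amp

    # Replaces a plain loss.backward() inside the training loop.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
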
@@ -222,11 +222,11 @@ def main():
     if start_epoch > 0:
         lr_scheduler.step(start_epoch)
     if args.local_rank == 0:
-        print('Scheduled epochs: ', num_epochs)
+        logging.info('Scheduled epochs: {}'.format(num_epochs))
 
     train_dir = os.path.join(args.data, 'train')
     if not os.path.exists(train_dir):
-        print('Error: training folder does not exist at: %s' % train_dir)
+        logging.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
     dataset_train = Dataset(train_dir)
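
Note: the converted calls format their messages eagerly via .format() and %. The logging module also supports deferred formatting, where extra arguments are interpolated only if the record is actually emitted, avoiding the formatting cost for suppressed levels:

    import logging

    train_dir = './data/train'  # illustrative path
    # '%s' is filled in lazily, only when the ERROR record is emitted.
    logging.error('Training folder does not exist at: %s', train_dir)
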
@@ -252,7 +252,7 @@ def main():
     eval_dir = os.path.join(args.data, 'validation')
     if not os.path.isdir(eval_dir):
-        print('Error: validation folder does not exist at: %s' % eval_dir)
+        logging.error('Validation folder does not exist at: {}'.format(eval_dir))
         exit(1)
     dataset_eval = Dataset(eval_dir)
@@ -332,7 +332,7 @@ def main():
     except KeyboardInterrupt:
         pass
     if best_metric is not None:
-        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
+        logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
 
 
 def train_epoch(
@@ -394,21 +394,22 @@ def train_epoch(
             losses_m.update(reduced_loss.item(), input.size(0))
 
         if args.local_rank == 0:
-            print('Train: {} [{}/{} ({:.0f}%)]  '
-                  'Loss: {loss.val:.6f} ({loss.avg:.4f})  '
-                  'Time: {batch_time.val:.3f}s, {rate:.3f}/s  '
-                  '({batch_time.avg:.3f}s, {rate_avg:.3f}/s)  '
-                  'LR: {lr:.4f}  '
-                  'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                epoch,
-                batch_idx, len(loader),
-                100. * batch_idx / last_idx,
-                loss=losses_m,
-                batch_time=batch_time_m,
-                rate=input.size(0) * args.world_size / batch_time_m.val,
-                rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                lr=lr,
-                data_time=data_time_m))
+            logging.info(
+                'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
+                'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f})  '
+                'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
+                '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
+                'LR: {lr:.3e}  '
+                'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                    epoch,
+                    batch_idx, len(loader),
+                    100. * batch_idx / last_idx,
+                    loss=losses_m,
+                    batch_time=batch_time_m,
+                    rate=input.size(0) * args.world_size / batch_time_m.val,
+                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
+                    lr=lr,
+                    data_time=data_time_m))
 
         if args.save_images and output_dir:
             torchvision.utils.save_image(
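
Note: the new format specs ({:>4d}, {:>3.0f}, {:>9.6f}, {:>7.2f}, and {lr:.3e} in place of {lr:.4f}) right-align each field to a fixed width, so successive log lines form readable columns, and render the learning rate in scientific notation so small values keep their precision. A quick demonstration of the effect:

    print('[{:>4d}/{}]  Loss: {:>9.6f}  LR: {:.3e}'.format(7, 1000, 0.123456, 0.0001))
    # -> [   7/1000]  Loss:  0.123456  LR: 1.000e-04
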
@@ -478,14 +479,15 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                 log_name = 'Test' + log_suffix
-                print('{0}: [{1}/{2}]\t'
-                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})  '
-                      'Loss {loss.val:.4f} ({loss.avg:.4f})  '
-                      'Prec@1 {top1.val:.4f} ({top1.avg:.4f})  '
-                      'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                    log_name, batch_idx, last_idx,
-                    batch_time=batch_time_m, loss=losses_m,
-                    top1=prec1_m, top5=prec5_m))
+                logging.info(
+                    '{0}: [{1:>4d}/{2}]  '
+                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f})  '
+                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f})  '
+                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f})  '
+                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
+                        log_name, batch_idx, last_idx,
+                        batch_time=batch_time_m, loss=losses_m,
+                        top1=prec1_m, top5=prec5_m))
 
     metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])