@@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
 group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
+group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
+                   help='The number of iterations to accumulate gradients (default: 1)')
 group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
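For reference, the new flag composes multiplicatively with -b/--batch-size: each optimizer step now sees roughly batch_size * iters_to_accum samples per rank (the trailing group can be smaller when the epoch length is not a multiple of --iters-to-accum). A minimal sketch, not part of the patch, with illustrative names:

    # Illustrative only: how --iters-to-accum composes with -b and the number of ranks.
    def effective_batch_size(batch_size: int, iters_to_accum: int, world_size: int = 1) -> int:
        # Per optimizer step, each rank contributes batch_size * iters_to_accum samples.
        return batch_size * iters_to_accum * world_size

    # e.g. -b 128 --iters-to-accum 4 on 8 GPUs behaves like a 4096-sample step
    assert effective_batch_size(128, 4, 8) == 4096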
@@ -399,6 +401,9 @@ def main():
     if args.amp_dtype == 'bfloat16':
         amp_dtype = torch.bfloat16
 
+    # iters_to_accum must be greater than zero
+    assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
+
     utils.random_seed(args.seed, args.rank)
 
     if args.fuser:
@@ -851,11 +856,23 @@ def train_one_epoch(
     model.train()
 
     end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
+    num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
+    last_idx = len(loader) - 1
+    last_iters_to_accum = len(loader) % args.iters_to_accum
+    last_idx_to_accum = len(loader) - last_iters_to_accum
     num_updates = epoch * num_batches_per_epoch
+    optimizer.zero_grad()
+    num_step_samples = 0
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
+        iters_to_accum = args.iters_to_accum
+        if batch_idx >= last_idx_to_accum:
+            iters_to_accum = last_iters_to_accum
+        need_step = False
+        if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
+            need_step = True
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
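The bookkeeping above splits the len(loader) batches into groups of --iters-to-accum with a possibly smaller trailing group, and redefines num_batches_per_epoch as the number of optimizer steps rather than loader batches. The standalone sketch below mirrors that logic so the grouping is easy to check in isolation; the function name and return shape are mine, not timm's:

    # Illustrative sketch of the bookkeeping above (names are mine, not timm's).
    def accumulation_schedule(num_batches: int, iters_to_accum: int):
        """Yield (batch_idx, group_size, need_step) mirroring the loop logic in the hunk."""
        last_group = num_batches % iters_to_accum            # size of the trailing partial group (0 if none)
        first_idx_of_last_group = num_batches - last_group   # batches at/after this index use the smaller divisor
        last_idx = num_batches - 1
        for batch_idx in range(num_batches):
            group_size = iters_to_accum
            if last_group and batch_idx >= first_idx_of_last_group:
                group_size = last_group
            need_step = (batch_idx + 1) % iters_to_accum == 0 or batch_idx == last_idx
            yield batch_idx, group_size, need_step

    # 10 batches with iters_to_accum=4 -> steps after batches 3, 7 and 9; the last group holds 2 batches
    steps = [b for b, _, s in accumulation_schedule(10, 4) if s]
    assert steps == [3, 7, 9]

The group size is what the loss is divided by in the next hunk, so the trailing partial group still averages its gradients correctly.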
@@ -864,21 +881,31 @@
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
 
+        def _forward():
             with amp_autocast():
                 output = model(input)
                 loss = loss_fn(output, target)
+            loss /= iters_to_accum
+            return loss
+
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                loss = _forward()
+        else:
+            loss = _forward()
 
         if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(loss.item() * iters_to_accum, input.size(0))
 
-        optimizer.zero_grad()
+        def _backward():
             if loss_scaler is not None:
                 loss_scaler(
                     loss, optimizer,
                     clip_grad=args.clip_grad,
                     clip_mode=args.clip_mode,
                     parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
+                    create_graph=second_order,
+                    need_step=need_step
                 )
             else:
                 loss.backward(create_graph=second_order)
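Two things happen in this hunk: the per-batch loss is divided by the current group size so the accumulated gradient matches one large-batch step (and multiplied back for logging), and DDP's gradient all-reduce is skipped via model.no_sync() on iterations that only accumulate; need_step is also forwarded to the loss scaler so scaler-driven updates follow the same cadence. The sketch below shows the same accumulate-with-no_sync pattern in plain PyTorch, without timm's AMP/scaler helpers; for brevity it always divides by iters_to_accum instead of handling the smaller trailing group:

    # Standalone sketch of the accumulate-with-no_sync pattern; model/loader names are illustrative.
    import contextlib

    def train_epoch(model, loader, optimizer, loss_fn, device, iters_to_accum=4):
        model.train()
        optimizer.zero_grad()
        last_idx = len(loader) - 1
        for batch_idx, (input, target) in enumerate(loader):
            input, target = input.to(device), target.to(device)
            need_step = (batch_idx + 1) % iters_to_accum == 0 or batch_idx == last_idx
            # Skip DDP gradient all-reduce while we are only accumulating locally.
            sync_ctx = contextlib.nullcontext()
            if not need_step and hasattr(model, "no_sync"):
                sync_ctx = model.no_sync()
            with sync_ctx:
                loss = loss_fn(model(input), target) / iters_to_accum  # keep gradient scale ~ one big batch
                loss.backward()
            if need_step:
                optimizer.step()
                optimizer.zero_grad()

Because the stepping iteration's backward runs outside no_sync(), DDP reduces the full accumulated gradient exactly once per optimizer step.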
@@ -888,22 +915,31 @@
                         value=args.clip_grad,
                         mode=args.clip_mode
                     )
+                if need_step:
                     optimizer.step()
 
+        num_step_samples += input.size(0)
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                _backward()
+        else:
+            _backward()
+
+        if need_step:
+            optimizer.zero_grad()
 
         if model_ema is not None:
             model_ema.update(model)
 
         torch.cuda.synchronize()
         num_updates += 1
         batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
+        if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
 
             if args.distributed:
                 reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
+                losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
 
             if utils.is_primary(args):
                 _logger.info(
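Note that the logging condition now counts accumulation groups instead of loader batches, so every batch inside a logging group emits a line. A quick check with illustrative values:

    # Which loader batches trigger a log line under the new condition (illustrative values).
    iters_to_accum, log_interval = 4, 2
    logged = [b for b in range(16) if (b // iters_to_accum) % log_interval == 0]
    assert logged == [0, 1, 2, 3, 8, 9, 10, 11]  # every other accumulation group, all batches within it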
@@ -918,8 +954,8 @@
                         100. * batch_idx / last_idx,
                         loss=losses_m,
                         batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
+                        rate=num_step_samples * args.world_size / batch_time_m.val,
+                        rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
                         lr=lr,
                         data_time=data_time_m)
                 )
@@ -933,12 +969,12 @@
                     )
 
         if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+                (batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
             saver.save_recovery(epoch, batch_idx=batch_idx)
 
         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
 
+        num_step_samples = 0
         end = time.time()
         # end for