@@ -130,6 +130,8 @@ group.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
 group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
                    help='Input batch size for training (default: 128)')
+group.add_argument('--iters-to-accum', type=int, default=1, metavar='N',
+                   help='The number of iterations to accumulate gradients (default: 1)')
 group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
@@ -399,6 +401,9 @@ def main():
         if args.amp_dtype == 'bfloat16':
             amp_dtype = torch.bfloat16
 
+    # iters_to_accum must be greater than zero
+    assert args.iters_to_accum > 0, 'The argument "iters-to-accum" must be greater than zero.'
+
     utils.random_seed(args.seed, args.rank)
 
     if args.fuser:
@@ -851,11 +856,23 @@ def train_one_epoch(
     model.train()
 
     end = time.time()
-    num_batches_per_epoch = len(loader)
-    last_idx = num_batches_per_epoch - 1
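+    # with gradient accumulation, one optimizer update spans iters_to_accum loader iterations,
+    # so the number of updates per epoch is the ceiling of len(loader) / iters_to_accum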
+    num_batches_per_epoch = (len(loader) + args.iters_to_accum - 1) // args.iters_to_accum
+    last_idx = len(loader) - 1
+    last_iters_to_accum = len(loader) % args.iters_to_accum
+    last_idx_to_accum = len(loader) - last_iters_to_accum
     num_updates = epoch * num_batches_per_epoch
+    optimizer.zero_grad()
+    num_step_samples = 0
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
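+        # trailing batches that do not fill a whole group are averaged over the smaller group
+        # size; need_step marks the iterations on which the optimizer should actually step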
+        iters_to_accum = args.iters_to_accum
+        if batch_idx >= last_idx_to_accum:
+            iters_to_accum = last_iters_to_accum
+        need_step = False
+        if (batch_idx + 1) % args.iters_to_accum == 0 or last_batch:
+            need_step = True
         data_time_m.update(time.time() - end)
         if not args.prefetcher:
             input, target = input.to(device), target.to(device)
@@ -864,82 +881,101 @@ def train_one_epoch(
         if args.channels_last:
             input = input.contiguous(memory_format=torch.channels_last)
 
-        with amp_autocast():
-            output = model(input)
-            loss = loss_fn(output, target)
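+        # the loss is divided by iters_to_accum so the gradients accumulated over a group of
+        # micro-batches average to a single large-batch update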
+        def _forward():
+            with amp_autocast():
+                output = model(input)
+                loss = loss_fn(output, target)
+            loss /= iters_to_accum
+            return loss
 
-        if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
-
-        optimizer.zero_grad()
-        if loss_scaler is not None:
-            loss_scaler(
-                loss, optimizer,
-                clip_grad=args.clip_grad,
-                clip_mode=args.clip_mode,
-                parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order
-            )
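+        # on non-step iterations the forward runs under model.no_sync() so that DDP skips the
+        # gradient all-reduce for the corresponding backward pass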
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                loss = _forward()
         else:
-            loss.backward(create_graph=second_order)
-            if args.clip_grad is not None:
-                utils.dispatch_clip_grad(
-                    model_parameters(model, exclude_head='agc' in args.clip_mode),
-                    value=args.clip_grad,
-                    mode=args.clip_mode
-                )
-            optimizer.step()
-
-        if model_ema is not None:
-            model_ema.update(model)
-
-        torch.cuda.synchronize()
-        num_updates += 1
-        batch_time_m.update(time.time() - end)
-        if last_batch or batch_idx % args.log_interval == 0:
-            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
-            lr = sum(lrl) / len(lrl)
+            loss = _forward()
 
-            if args.distributed:
-                reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
-                losses_m.update(reduced_loss.item(), input.size(0))
-
-            if utils.is_primary(args):
-                _logger.info(
-                    'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
-                    'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
-                    'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
-                    '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
-                    'LR: {lr:.3e}  '
-                    'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                        epoch,
-                        batch_idx, len(loader),
-                        100. * batch_idx / last_idx,
-                        loss=losses_m,
-                        batch_time=batch_time_m,
-                        rate=input.size(0) * args.world_size / batch_time_m.val,
-                        rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
-                        lr=lr,
-                        data_time=data_time_m)
-                )
+        if not args.distributed:
+            losses_m.update(loss.item() * iters_to_accum, input.size(0))
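+        # backward for one micro-batch; need_step is forwarded to the loss scaler and, in the
+        # unscaled path, the optimizer only steps once the accumulation group is complete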
+        def _backward():
+            if loss_scaler is not None:
+                loss_scaler(
+                    loss, optimizer,
+                    clip_grad=args.clip_grad,
+                    clip_mode=args.clip_mode,
+                    parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
+                    create_graph=second_order,
+                    need_step=need_step
+                )
-
-                if args.save_images and output_dir:
-                    torchvision.utils.save_image(
-                        input,
-                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                        padding=0,
-                        normalize=True
-                    )
+            else:
+                loss.backward(create_graph=second_order)
+                if args.clip_grad is not None:
+                    utils.dispatch_clip_grad(
+                        model_parameters(model, exclude_head='agc' in args.clip_mode),
+                        value=args.clip_grad,
+                        mode=args.clip_mode
+                    )
+                if need_step:
+                    optimizer.step()
-
-        if saver is not None and args.recovery_interval and (
-                last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            saver.save_recovery(epoch, batch_idx=batch_idx)
-
-        if lr_scheduler is not None:
-            lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
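+        # count the samples seen since the last optimizer step (used for the throughput log)
+        # and run backward, again under no_sync while gradients are still being accumulated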
+        num_step_samples += input.size(0)
+        if need_step is not True and hasattr(model, "no_sync"):
+            with model.no_sync():
+                _backward()
+        else:
+            _backward()
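+        # the remaining bookkeeping only runs on step iterations: gradients are cleared for the
+        # next accumulation group, and EMA, timing and logging are updated per optimizer update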
+        if need_step:
+            optimizer.zero_grad()
+            if model_ema is not None:
+                model_ema.update(model)
 
-        end = time.time()
+            torch.cuda.synchronize()
+            num_updates += 1
+            batch_time_m.update(time.time() - end)
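+            # the logging cadence is measured in optimizer updates (batch_idx // iters_to_accum)
+            # rather than in raw loader iterations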
+            if (batch_idx // args.iters_to_accum) % args.log_interval == 0:
+                lrl = [param_group['lr'] for param_group in optimizer.param_groups]
+                lr = sum(lrl) / len(lrl)
+
+                if args.distributed:
+                    reduced_loss = utils.reduce_tensor(loss.data, args.world_size)
+                    losses_m.update(reduced_loss.item() * iters_to_accum, input.size(0))
+
+                if utils.is_primary(args):
+                    _logger.info(
+                        'Train: {} [{:>4d}/{} ({:>3.0f}%)]  '
+                        'Loss: {loss.val:#.4g} ({loss.avg:#.3g})  '
+                        'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s  '
+                        '({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s)  '
+                        'LR: {lr:.3e}  '
+                        'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                            epoch,
+                            batch_idx, len(loader),
+                            100. * batch_idx / last_idx,
+                            loss=losses_m,
+                            batch_time=batch_time_m,
+                            rate=num_step_samples * args.world_size / batch_time_m.val,
+                            rate_avg=num_step_samples * args.world_size / batch_time_m.avg,
+                            lr=lr,
+                            data_time=data_time_m)
+                    )
+
+                    if args.save_images and output_dir:
+                        torchvision.utils.save_image(
+                            input,
+                            os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                            padding=0,
+                            normalize=True
+                        )
+
+            if saver is not None and args.recovery_interval and (
+                    (batch_idx // args.iters_to_accum + 1) % args.recovery_interval == 0):
+                saver.save_recovery(epoch, batch_idx=batch_idx)
+
+            if lr_scheduler is not None:
+                lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
+
+            num_step_samples = 0
+            end = time.time()
         # end for
 
     if hasattr(optimizer, 'sync_lookahead'):