@@ -36,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
     convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
 from timm.optim import create_optimizer_v2, optimizer_kwargs
-from timm.scheduler import create_scheduler
+from timm.scheduler import create_scheduler_v2, scheduler_kwargs
 from timm.utils import ApexScaler, NativeScaler
 try:
@@ -163,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None,
 # Learning rate schedule parameters
 group = parser.add_argument_group('Learning rate schedule parameters')
-group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
+group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
                    help='LR scheduler (default: "step"')
-group.add_argument('--lr', type=float, default=0.05, metavar='LR',
-                   help='learning rate (default: 0.05)')
+group.add_argument('--sched-on-updates', action='store_true', default=False,
+                   help='Apply LR scheduler step on update instead of epoch end.')
+group.add_argument('--lr', type=float, default=None, metavar='LR',
+                   help='learning rate, overrides lr-base if set (default: None)')
+group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
+                   help='base learning rate: lr = lr_base * global_batch_size / base_size')
+group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
+                   help='base learning rate batch size (divisor, default: 256).')
+group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
+                   help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
 group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
 group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
@@ -181,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
                    help='learning rate cycle limit, cycles enabled if > 1')
 group.add_argument('--lr-k-decay', type=float, default=1.0,
                    help='learning rate k-decay for cosine/poly (default: 1.0)')
-group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                   help='warmup learning rate (default: 0.0001)')
-group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR',
-                   help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
+group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
+                   help='warmup learning rate (default: 1e-5)')
+group.add_argument('--min-lr', type=float, default=0, metavar='LR',
+                   help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
 group.add_argument('--epochs', type=int, default=300, metavar='N',
                    help='number of epochs to train (default: 300)')
 group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
                    help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
 group.add_argument('--start-epoch', default=None, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
-group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
+group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
                    help='list of decay epoch indices for multistep lr. must be increasing')
-group.add_argument('--decay-epochs', type=float, default=100, metavar='N',
+group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
                    help='epoch interval to decay LR')
-group.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
                    help='epochs to warmup LR, if scheduler supports')
-group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
+group.add_argument('--warmup-prefix', action='store_true', default=False,
+                   help='Exclude warmup period from decay schedule.'),
+group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
                    help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
 group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                    help='patience epochs for Plateau LR scheduler (default: 10')
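For illustration, a hypothetical command line exercising the new schedule options might look like the following; the dataset path and model name are placeholders, and only the schedule-related flags are taken from this diff:

python train.py /path/to/imagenet --model resnet50 --batch-size 256 \
    --sched cosine --sched-on-updates --lr-base 0.1 --lr-base-size 256 \
    --warmup-epochs 5 --warmup-prefix --epochs 300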
@@ -469,6 +479,20 @@ def main():
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
+    if args.lr is None:
+        global_batch_size = args.batch_size * args.world_size
+        batch_ratio = global_batch_size / args.lr_base_size
+        if not args.lr_base_scale:
+            on = args.opt.lower()
+            args.lr_base_scale = 'sqrt' if any([o in on for o in ('ada', 'lamb')]) else 'linear'
+        if args.lr_base_scale == 'sqrt':
+            batch_ratio = batch_ratio ** 0.5
+        args.lr = args.lr_base * batch_ratio
+        if utils.is_primary(args):
+            _logger.info(
+                f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
+                f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
 optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 # setup automatic mixed-precision (AMP) loss scaling and op casting
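The block above derives --lr from --lr-base whenever --lr is left unset. A minimal standalone sketch of that scaling rule, assuming the same linear/sqrt convention (local names only, not timm code):

# Standalone sketch of the base-LR scaling rule in the hunk above.
def resolve_lr(lr_base=0.1, lr_base_size=256, batch_size=128, world_size=4, scale='linear'):
    global_batch_size = batch_size * world_size        # 512 with these example values
    batch_ratio = global_batch_size / lr_base_size     # 2.0
    if scale == 'sqrt':                                # picked by default for 'ada*' / 'lamb' optimizers
        batch_ratio = batch_ratio ** 0.5
    return lr_base * batch_ratio

print(resolve_lr())                # linear scaling -> 0.2
print(resolve_lr(scale='sqrt'))    # sqrt scaling   -> ~0.1414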
@@ -523,20 +547,6 @@ def main():
             model = NativeDDP(model, device_ids=[device], broadcast_buffers=not args.no_ddp_bb)
         # NOTE: EMA model does not need to be wrapped by DDP
-    # setup learning rate schedule and starting epoch
-    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
-    start_epoch = 0
-    if args.start_epoch is not None:
-        # a specified start_epoch will always override the resume epoch
-        start_epoch = args.start_epoch
-    elif resume_epoch is not None:
-        start_epoch = resume_epoch
-    if lr_scheduler is not None and start_epoch > 0:
-        lr_scheduler.step(start_epoch)
-    if utils.is_primary(args):
-        _logger.info('Scheduled epochs: {}'.format(num_epochs))
     # create the train and eval datasets
     dataset_train = create_dataset(
         args.dataset,
@@ -691,6 +701,29 @@ def main():
         with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
             f.write(args_text)
+    # setup learning rate schedule and starting epoch
+    updates_per_epoch = len(loader_train)
+    lr_scheduler, num_epochs = create_scheduler_v2(
+        optimizer,
+        **scheduler_kwargs(args),
+        updates_per_epoch=updates_per_epoch,
+    )
+    start_epoch = 0
+    if args.start_epoch is not None:
+        # a specified start_epoch will always override the resume epoch
+        start_epoch = args.start_epoch
+    elif resume_epoch is not None:
+        start_epoch = resume_epoch
+    if lr_scheduler is not None and start_epoch > 0:
+        if args.sched_on_updates:
+            lr_scheduler.step_update(start_epoch * updates_per_epoch)
+        else:
+            lr_scheduler.step(start_epoch)
+    if utils.is_primary(args):
+        _logger.info(
+            f'Scheduled epochs: {num_epochs}. LR stepped per {"epoch" if lr_scheduler.t_in_epochs else "update"}.')
     try:
         for epoch in range(start_epoch, num_epochs):
             if hasattr(dataset_train, 'set_epoch'):
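Scheduler creation now happens after the train loader exists, so updates_per_epoch is known, and a resumed run fast-forwards the schedule in either epochs or updates. A standalone sketch of that bookkeeping, using plain Python stand-ins rather than the timm scheduler object:

# How the two stepping modes count time (illustrative numbers).
updates_per_epoch = 1251                       # stands in for len(loader_train)
start_epoch = 40                               # e.g. a resumed run

# Resume offsets, mirroring the branch above:
print('epoch-stepped resume arg :', start_epoch)                      # lr_scheduler.step(40)
print('update-stepped resume arg:', start_epoch * updates_per_epoch)  # lr_scheduler.step_update(50040)

# During training, num_updates advances once per batch, matching
# num_updates = epoch * num_batches_per_epoch in train_one_epoch below.
num_updates = start_epoch * updates_per_epoch
for epoch in range(start_epoch, start_epoch + 2):
    for batch_idx in range(updates_per_epoch):
        num_updates += 1                       # an update-stepped schedule would advance here
    # an epoch-stepped schedule advances once here, with epoch + 1

print(num_updates == (start_epoch + 2) * updates_per_epoch)           # True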
@@ -741,16 +774,14 @@ def main():
                 )
                 eval_metrics = ema_eval_metrics
-            if lr_scheduler is not None:
-                # step LR for next epoch
-                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
             if output_dir is not None:
+                lrs = [param_group['lr'] for param_group in optimizer.param_groups]
                 utils.update_summary(
                     epoch,
                     train_metrics,
                     eval_metrics,
-                    os.path.join(output_dir, 'summary.csv'),
+                    filename=os.path.join(output_dir, 'summary.csv'),
+                    lr=sum(lrs) / len(lrs),
                     write_header=best_metric is None,
                     log_wandb=args.log_wandb and has_wandb,
                 )
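The summary row now records the mean LR across optimizer param groups, which matters once groups carry different LRs (e.g. with --layer-decay). A small self-contained example of that averaging, with made-up values:

# Averaging LR over param groups, as in the lrs / lr=sum(lrs) / len(lrs) lines above.
import torch

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD([
    {'params': [model.weight]},               # uses the default lr below (0.1)
    {'params': [model.bias], 'lr': 0.01},     # per-group override
], lr=0.1)

lrs = [param_group['lr'] for param_group in optimizer.param_groups]
print(sum(lrs) / len(lrs))                    # ~0.055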
@@ -760,8 +791,13 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(epoch, metric=save_metric)
+            if lr_scheduler is not None:
+                # step LR for next epoch
+                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
     except KeyboardInterrupt:
         pass
     if best_metric is not None:
         _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
@@ -796,8 +832,9 @@ def train_one_epoch(
     model.train()
     end = time.time()
-    last_idx = len(loader) - 1
-    num_updates = epoch * len(loader)
+    num_batches_per_epoch = len(loader)
+    last_idx = num_batches_per_epoch - 1
+    num_updates = epoch * num_batches_per_epoch
     for batch_idx, (input, target) in enumerate(loader):
         last_batch = batch_idx == last_idx
         data_time_m.update(time.time() - end)