@@ -21,6 +21,7 @@ import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
from functools import partial
import torch
import torch . nn as nn
@@ -35,7 +36,7 @@ from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntrop
from timm . models import create_model , safe_model_name , resume_checkpoint , load_checkpoint , \
convert_splitbn_model , convert_sync_batchnorm , model_parameters , set_fast_norm
from timm . optim import create_optimizer_v2 , optimizer_kwargs
from timm . scheduler import create_scheduler
from timm . scheduler import create_scheduler_v2 , scheduler_kwargs
from timm . utils import ApexScaler , NativeScaler
try :
@@ -66,7 +67,6 @@ except ImportError as e:
has_functorch = False
torch . backends . cudnn . benchmark = True
_logger = logging . getLogger ( ' train ' )
# The first arg parser parses out only the --config argument, this argument is used to
@@ -111,7 +111,9 @@ group.add_argument('--num-classes', type=int, default=None, metavar='N',
group . add_argument ( ' --gp ' , default = None , type = str , metavar = ' POOL ' ,
help = ' Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None. ' )
group . add_argument ( ' --img-size ' , type = int , default = None , metavar = ' N ' ,
help = ' Image patch size (default: None => model default) ' )
help = ' Image size (default: None => model default) ' )
group . add_argument ( ' --in-chans ' , type = int , default = None , metavar = ' N ' ,
help = ' Image input channels (default: None => 3) ' )
group . add_argument ( ' --input-size ' , default = None , nargs = 3 , type = int ,
metavar = ' N N N ' , help = ' Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty ' )
group . add_argument ( ' --crop-pct ' , default = None , type = float ,
@@ -161,10 +163,18 @@ group.add_argument('--layer-decay', type=float, default=None,
# Learning rate schedule parameters
group = parser . add_argument_group ( ' Learning rate schedule parameters ' )
group . add_argument ( ' --sched ' , default = ' cosine ' , type = str , metavar = ' SCHEDULER ' ,
group . add_argument ( ' --sched ' , type = str , default = ' cosine ' , metavar = ' SCHEDULER ' ,
help = ' LR scheduler (default: "cosine") ' )
group . add_argument ( ' --lr ' , type = float , default = 0.05 , metavar = ' LR ' ,
help = ' learning rate (default: 0.05) ' )
group . add_argument ( ' --sched-on-updates ' , action = ' store_true ' , default = False ,
help = ' Apply LR scheduler step on update instead of epoch end. ' )
group . add_argument ( ' --lr ' , type = float , default = None , metavar = ' LR ' ,
help = ' learning rate, overrides lr-base if set (default: None) ' )
group . add_argument ( ' --lr-base ' , type = float , default = 0.1 , metavar = ' LR ' ,
help = ' base learning rate: lr = lr_base * global_batch_size / base_size ' )
group . add_argument ( ' --lr-base-size ' , type = int , default = 256 , metavar = ' DIV ' ,
help = ' base learning rate batch size (divisor, default: 256). ' )
group . add_argument ( ' --lr-base-scale ' , type = str , default = ' ' , metavar = ' SCALE ' ,
help = ' base learning rate vs batch_size scaling ( " linear " , " sqrt " , based on opt if empty) ' )
group . add_argument ( ' --lr-noise ' , type = float , nargs = ' + ' , default = None , metavar = ' pct, pct ' ,
help = ' learning rate noise on/off epoch percentages ' )
group . add_argument ( ' --lr-noise-pct ' , type = float , default = 0.67 , metavar = ' PERCENT ' ,
@@ -179,23 +189,25 @@ group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help = ' learning rate cycle limit, cycles enabled if > 1 ' )
group . add_argument ( ' --lr-k-decay ' , type = float , default = 1.0 ,
help = ' learning rate k-decay for cosine/poly (default: 1.0) ' )
group . add_argument ( ' --warmup-lr ' , type = float , default = 0.0001 , metavar = ' LR ' ,
help = ' warmup learning rate (default: 0.0001) ' )
group . add_argument ( ' --min-lr ' , type = float , default = 1e-6 , metavar = ' LR ' ,
help = ' lower lr bound for cyclic schedulers that hit 0 ( 1e-5 )' )
group . add_argument ( ' --warmup-lr ' , type = float , default = 1e-5 , metavar = ' LR ' ,
help = ' warmup learning rate (default: 1e-5 )' )
group . add_argument ( ' --min-lr ' , type = float , default = 0 , metavar = ' LR ' ,
help = ' lower lr bound for cyclic schedulers that hit 0 ( default: 0 )' )
group . add_argument ( ' --epochs ' , type = int , default = 300 , metavar = ' N ' ,
help = ' number of epochs to train (default: 300) ' )
group . add_argument ( ' --epoch-repeats ' , type = float , default = 0. , metavar = ' N ' ,
help = ' epoch repeat multiplier (number of times to repeat dataset epoch per train epoch). ' )
group . add_argument ( ' --start-epoch ' , default = None , type = int , metavar = ' N ' ,
help = ' manual epoch number (useful on restarts) ' )
group . add_argument ( ' --decay-milestones ' , default = [ 30 , 60 ] , type = int , nargs = ' + ' , metavar = " MILESTONES " ,
group . add_argument ( ' --decay-milestones ' , default = [ 90 , 180 , 270 ] , type = int , nargs = ' + ' , metavar = " MILESTONES " ,
help = ' list of decay epoch indices for multistep lr. must be increasing ' )
group . add_argument ( ' --decay-epochs ' , type = float , default = 100 , metavar = ' N ' ,
group . add_argument ( ' --decay-epochs ' , type = float , default = 90 , metavar = ' N ' ,
help = ' epoch interval to decay LR ' )
group . add_argument ( ' --warmup-epochs ' , type = int , default = 3 , metavar = ' N ' ,
group . add_argument ( ' --warmup-epochs ' , type = int , default = 5 , metavar = ' N ' ,
help = ' epochs to warmup LR, if scheduler supports ' )
group . add_argument ( ' --cooldown-epochs ' , type = int , default = 10 , metavar = ' N ' ,
group . add_argument ( ' --warmup-prefix ' , action = ' store_true ' , default = False ,
help = ' Exclude warmup period from decay schedule. ' )
group . add_argument ( ' --cooldown-epochs ' , type = int , default = 0 , metavar = ' N ' ,
help = ' epochs to cooldown LR at min_lr, after cyclic schedule ends ' )
group . add_argument ( ' --patience-epochs ' , type = int , default = 10 , metavar = ' N ' ,
help = ' patience epochs for Plateau LR scheduler (default: 10) ' )
@@ -303,10 +315,10 @@ group.add_argument('--save-images', action='store_true', default=False,
help = ' save images of input batches every log interval for debugging ' )
group . add_argument ( ' --amp ' , action = ' store_true ' , default = False ,
help = ' use NVIDIA Apex AMP or Native AMP for mixed precision training ' )
group . add_argument ( ' --apex-amp ' , action = ' store_true ' , default = False ,
help = ' Use NVIDIA Apex AMP mixed precision ' )
group . add_argument ( ' --native-amp ' , action = ' store_true ' , default = False ,
help = ' Use Native Torch AMP mixed precision ' )
group . add_argument ( ' --amp-dtype ' , default = ' float16 ' , type = str ,
help = ' lower precision AMP dtype (default: float16) ' )
group . add_argument ( ' --amp-impl ' , default = ' native ' , type = str ,
help = ' AMP impl to use, " native " or " apex " (default: native) ' )
group . add_argument ( ' --no-ddp-bb ' , action = ' store_true ' , default = False ,
help = ' Force broadcast buffers for native DDP to off. ' )
group . add_argument ( ' --pin-mem ' , action = ' store_true ' , default = False ,
@@ -349,49 +361,42 @@ def main():
utils . setup_default_logging ( )
args , args_text = _parse_args ( )
if torch . cuda . is_available ( ) :
torch . backends . cuda . matmul . allow_tf32 = True
torch . backends . cudnn . benchmark = True
args . prefetcher = not args . no_prefetcher
args . distributed = False
if ' WORLD_SIZE ' in os . environ :
args . distributed = int ( os . environ [ ' WORLD_SIZE ' ] ) > 1
args . device = ' cuda:0 '
args . world_size = 1
args . rank = 0 # global rank
device = utils . init_distributed_device ( args )
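# NOTE: utils.init_distributed_device() fills in args.distributed, args.world_size, args.rank and
# args.device for this process and returns the torch.device to use (one device per process)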
if args . distributed :
if ' LOCAL_RANK ' in os . environ :
args . local_rank = int ( os . getenv ( ' LOCAL_RANK ' ) )
args . device = ' cuda: %d ' % args . local_rank
torch . cuda . set_device ( args . local_rank )
torch . distributed . init_process_group ( backend = ' nccl ' , init_method = ' env:// ' )
args . world_size = torch . distributed . get_world_size ( )
args . rank = torch . distributed . get_rank ( )
_logger . info ( ' Training in distributed mode with multiple processes, 1 GPU per process. Process %d , total %d . '
% ( args . rank , args . world_size ) )
_logger . info (
' Training in distributed mode with multiple processes, 1 device per process. '
f ' Process { args . rank } , total { args . world_size } , device { args . device } . ' )
else :
_logger . info ( ' Training with a single process on 1 GPUs .' )
_logger . info ( f ' Training with a single process on 1 device ( { args . device } ). ' )
assert args . rank >= 0
if args. rank == 0 and args . log_wandb :
if utils . is_primary ( args ) and args . log_wandb :
if has_wandb :
wandb . init ( project = args . experiment , config = args )
else :
_logger . warning ( " You ' ve requested to log metrics to wandb but package not found. "
" Metrics not being logged to wandb, try `pip install wandb` " )
_logger . warning (
" You ' ve requested to log metrics to wandb but package not found. "
" Metrics not being logged to wandb, try `pip install wandb` " )
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch . float16
if args . amp :
# `--amp` chooses native amp before apex (APEX ver not actively maintained)
if has_native_amp :
args . native_amp = True
elif has_apex :
args . apex_amp = True
if args . apex_amp and has_apex :
use_amp = ' apex '
elif args . native_amp and has_native_amp :
use_amp = ' native '
elif args . apex_amp or args . native_amp :
_logger . warning ( " Neither APEX or native Torch AMP is available, using float32. "
" Install NVIDA apex or upgrade to PyTorch 1.6 " )
if args . amp_impl == ' apex ' :
assert has_apex , ' AMP impl specified as APEX but APEX is not installed. '
use_amp = ' apex '
assert args . amp_dtype == ' float16 '
else :
assert has_native_amp , ' Please update PyTorch to a version with native AMP (or use APEX). '
use_amp = ' native '
assert args . amp_dtype in ( ' float16 ' , ' bfloat16 ' )
if args . amp_dtype == ' bfloat16 ' :
amp_dtype = torch . bfloat16
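# NOTE: float16 typically needs dynamic loss scaling to avoid gradient underflow; bfloat16 shares
# float32's exponent range and can usually train without it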
utils . random_seed ( args . seed , args . rank )
@@ -400,19 +405,26 @@ def main():
if args . fast_norm :
set_fast_norm ( )
in_chans = 3
if args . in_chans is not None :
in_chans = args . in_chans
elif args . input_size is not None :
in_chans = args . input_size [ 0 ]
model = create_model (
args . model ,
pretrained = args . pretrained ,
in_chans = in_chans ,
num_classes = args . num_classes ,
drop_rate = args . drop ,
drop_connect_rate = args . drop_connect , # DEPRECATED, use drop_path
drop_path_rate = args . drop_path ,
drop_block_rate = args . drop_block ,
global_pool = args . gp ,
bn_momentum = args . bn_momentum ,
bn_eps = args . bn_eps ,
scriptable = args . torchscript ,
checkpoint_path = args . initial_checkpoint )
checkpoint_path = args . initial_checkpoint ,
)
if args . num_classes is None :
assert hasattr ( model , ' num_classes ' ) , ' Model must have `num_classes` attr if not set on cmd line/config. '
args . num_classes = model . num_classes # FIXME handle model default vs config num_classes more elegantly
@@ -420,11 +432,11 @@ def main():
if args . grad_checkpointing :
model . set_grad_checkpointing ( enable = True )
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info (
f ' Model { safe_model_name ( args . model ) } created, param count: { sum ( [ m . numel ( ) for m in model . parameters ( ) ] ) } ' )
data_config = resolve_data_config ( vars ( args ) , model = model , verbose = args. local_rank == 0 )
data_config = resolve_data_config ( vars ( args ) , model = model , verbose = utils. is_primary ( args ) )
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
@@ -438,9 +450,9 @@ def main():
model = convert_splitbn_model ( model , max ( num_aug_splits , 2 ) )
# move model to GPU, enable channels last layout if set
model . cuda( )
model . to( device = device )
if args . channels_last :
model = model . to ( memory_format = torch . channels_last )
model . to ( memory_format = torch . channels_last )
# setup synchronized BatchNorm for distributed training
if args . distributed and args . sync_bn :
@@ -452,7 +464,7 @@ def main():
model = convert_syncbn_model ( model )
else :
model = convert_sync_batchnorm ( model )
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info (
' Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
' zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled. ' )
@@ -461,38 +473,56 @@ def main():
assert not use_amp == ' apex ' , ' Cannot use APEX AMP with torchscripted model '
assert not args . sync_bn , ' Cannot use SyncBatchNorm with torchscripted model '
model = torch . jit . script ( model )
if args . aot_autograd :
assert has_functorch , " functorch is needed for --aot-autograd "
model = memory_efficient_fusion ( model )
if args . lr is None :
global_batch_size = args . batch_size * args . world_size
batch_ratio = global_batch_size / args . lr_base_size
if not args . lr_base_scale :
on = args . opt . lower ( )
args . lr_base_scale = ' sqrt ' if any ( [ o in on for o in ( ' ada ' , ' lamb ' ) ] ) else ' linear '
if args . lr_base_scale == ' sqrt ' :
batch_ratio = batch_ratio ** 0.5
args . lr = args . lr_base * batch_ratio
if utils . is_primary ( args ) :
_logger . info (
f ' Learning rate ( { args . lr } ) calculated from base learning rate ( { args . lr_base } ) '
f ' and global batch size ( { global_batch_size } ) with { args . lr_base_scale } scaling. ' )
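# e.g. with the defaults (--lr-base 0.1, --lr-base-size 256) and a global batch size of 512:
# linear scaling gives lr = 0.1 * 512 / 256 = 0.2, sqrt scaling gives lr = 0.1 * sqrt(2) ~= 0.141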
optimizer = create_optimizer_v2 ( model , ** optimizer_kwargs ( cfg = args ) )
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
loss_scaler = None
if use_amp == ' apex ' :
assert device . type == ' cuda '
model , optimizer = amp . initialize ( model , optimizer , opt_level = ' O1 ' )
loss_scaler = ApexScaler ( )
if args . local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info ( ' Using NVIDIA APEX AMP. Training in mixed precision. ' )
elif use_amp == ' native ' :
amp_autocast = torch . cuda . amp . autocast
loss_scaler = NativeScaler ( )
if args . local_rank == 0 :
amp_autocast = partial ( torch . autocast , device_type = device . type , dtype = amp_dtype )
if device . type == ' cuda ' :
loss_scaler = NativeScaler ( )
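# loss scaling is only set up when training on CUDA; other device types rely on autocast alone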
if utils . is_primary ( args ) :
_logger . info ( ' Using native Torch AMP. Training in mixed precision. ' )
else :
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info ( ' AMP not enabled. Training in float32. ' )
# optionally resume from a checkpoint
resume_epoch = None
if args . resume :
resume_epoch = resume_checkpoint (
model , args . resume ,
model ,
args . resume ,
optimizer = None if args . no_resume_opt else optimizer ,
loss_scaler = None if args . no_resume_opt else loss_scaler ,
log_info = args . local_rank == 0 )
log_info = utils . is_primary ( args ) ,
)
# setup exponential moving average of model weights, SWA could be used here too
model_ema = None
@@ -507,41 +537,37 @@ def main():
if args . distributed :
if has_apex and use_amp == ' apex ' :
# Apex DDP preferred unless native amp is activated
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info ( " Using NVIDIA APEX DistributedDataParallel. " )
model = ApexDDP ( model , delay_allreduce = True )
else :
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info ( " Using native Torch DistributedDataParallel. " )
model = NativeDDP ( model , device_ids = [ args. local_rank ] , broadcast_buffers = not args . no_ddp_bb )
model = NativeDDP ( model , device_ids = [ device ] , broadcast_buffers = not args . no_ddp_bb )
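# passing the per-process device via device_ids keeps native DDP in single-device-per-process mode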
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
lr_scheduler , num_epochs = create_scheduler ( args , optimizer )
start_epoch = 0
if args . start_epoch is not None :
# a specified start_epoch will always override the resume epoch
start_epoch = args . start_epoch
elif resume_epoch is not None :
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0 :
lr_scheduler . step ( start_epoch )
if args . local_rank == 0 :
_logger . info ( ' Scheduled epochs: {} ' . format ( num_epochs ) )
# create the train and eval datasets
dataset_train = create_dataset (
args . dataset , root = args . data_dir , split = args . train_split , is_training = True ,
args . dataset ,
root = args . data_dir ,
split = args . train_split ,
is_training = True ,
class_map = args . class_map ,
download = args . dataset_download ,
batch_size = args . batch_size ,
repeats = args . epoch_repeats )
seed = args . seed ,
repeats = args . epoch_repeats ,
)
dataset_eval = create_dataset (
args . dataset , root = args . data_dir , split = args . val_split , is_training = False ,
args . dataset ,
root = args . data_dir ,
split = args . val_split ,
is_training = False ,
class_map = args . class_map ,
download = args . dataset_download ,
batch_size = args . batch_size )
batch_size = args . batch_size ,
)
# setup mixup / cutmix
collate_fn = None
@@ -549,9 +575,15 @@ def main():
mixup_active = args . mixup > 0 or args . cutmix > 0. or args . cutmix_minmax is not None
if mixup_active :
mixup_args = dict (
mixup_alpha = args . mixup , cutmix_alpha = args . cutmix , cutmix_minmax = args . cutmix_minmax ,
prob = args . mixup_prob , switch_prob = args . mixup_switch_prob , mode = args . mixup_mode ,
label_smoothing = args . smoothing , num_classes = args . num_classes )
mixup_alpha = args . mixup ,
cutmix_alpha = args . cutmix ,
cutmix_minmax = args . cutmix_minmax ,
prob = args . mixup_prob ,
switch_prob = args . mixup_switch_prob ,
mode = args . mixup_mode ,
label_smoothing = args . smoothing ,
num_classes = args . num_classes
)
if args . prefetcher :
assert not num_aug_splits # collate conflict (need to support deinterleaving in collate mixup)
collate_fn = FastCollateMixup ( ** mixup_args )
@@ -592,10 +624,15 @@ def main():
distributed = args . distributed ,
collate_fn = collate_fn ,
pin_memory = args . pin_mem ,
device = device ,
use_multi_epochs_loader = args . use_multi_epochs_loader ,
worker_seeding = args . worker_seeding ,
)
eval_workers = args . workers
if args . distributed and ( ' tfds ' in args . dataset or ' wds ' in args . dataset ) :
# FIXME reduces validation padding issues when using TFDS, WDS w/ workers and distributed training
eval_workers = min ( 2 , args . workers )
loader_eval = create_loader (
dataset_eval ,
input_size = data_config [ ' input_size ' ] ,
@@ -605,10 +642,11 @@ def main():
interpolation = data_config [ ' interpolation ' ] ,
mean = data_config [ ' mean ' ] ,
std = data_config [ ' std ' ] ,
num_workers = args. workers,
num_workers = eval_workers ,
distributed = args . distributed ,
crop_pct = data_config [ ' crop_pct ' ] ,
pin_memory = args . pin_mem ,
device = device ,
)
# setup loss function
@@ -628,8 +666,8 @@ def main():
train_loss_fn = LabelSmoothingCrossEntropy ( smoothing = args . smoothing )
else :
train_loss_fn = nn . CrossEntropyLoss ( )
train_loss_fn = train_loss_fn . cuda( )
validate_loss_fn = nn . CrossEntropyLoss ( ) . cuda( )
train_loss_fn = train_loss_fn . to( device = device )
validate_loss_fn = nn . CrossEntropyLoss ( ) . to( device = device )
# setup checkpoint saver and eval metric tracking
eval_metric = args . eval_metric
@@ -637,7 +675,7 @@ def main():
best_epoch = None
saver = None
output_dir = None
if args. rank == 0 :
if utils. is_primary ( args ) :
if args . experiment :
exp_name = args . experiment
else :
@@ -649,60 +687,136 @@ def main():
output_dir = utils . get_outdir ( args . output if args . output else ' ./output/train ' , exp_name )
decreasing = True if eval_metric == ' loss ' else False
saver = utils . CheckpointSaver (
model = model , optimizer = optimizer , args = args , model_ema = model_ema , amp_scaler = loss_scaler ,
checkpoint_dir = output_dir , recovery_dir = output_dir , decreasing = decreasing , max_history = args . checkpoint_hist )
model = model ,
optimizer = optimizer ,
args = args ,
model_ema = model_ema ,
amp_scaler = loss_scaler ,
checkpoint_dir = output_dir ,
recovery_dir = output_dir ,
decreasing = decreasing ,
max_history = args . checkpoint_hist
)
with open ( os . path . join ( output_dir , ' args.yaml ' ) , ' w ' ) as f :
f . write ( args_text )
# setup learning rate schedule and starting epoch
updates_per_epoch = len ( loader_train )
lr_scheduler , num_epochs = create_scheduler_v2 (
optimizer ,
** scheduler_kwargs ( args ) ,
updates_per_epoch = updates_per_epoch ,
)
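# updates_per_epoch lets the scheduler convert epoch-based warmup/decay settings into per-update
# steps when --sched-on-updates is set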
start_epoch = 0
if args . start_epoch is not None :
# a specified start_epoch will always override the resume epoch
start_epoch = args . start_epoch
elif resume_epoch is not None :
start_epoch = resume_epoch
if lr_scheduler is not None and start_epoch > 0 :
if args . sched_on_updates :
lr_scheduler . step_update ( start_epoch * updates_per_epoch )
else :
lr_scheduler . step ( start_epoch )
if utils . is_primary ( args ) :
_logger . info (
f ' Scheduled epochs: { num_epochs } . LR stepped per { " epoch " if lr_scheduler . t_in_epochs else " update " } . ' )
try :
for epoch in range ( start_epoch , num_epochs ) :
if args . distributed and hasattr ( loader_train . sampler , ' set_epoch ' ) :
if hasattr ( dataset_train , ' set_epoch ' ) :
dataset_train . set_epoch ( epoch )
elif args . distributed and hasattr ( loader_train . sampler , ' set_epoch ' ) :
loader_train . sampler . set_epoch ( epoch )
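# datasets that shuffle internally (e.g. tfds/wds wrappers) expose set_epoch on the dataset itself,
# otherwise the distributed sampler handles per-epoch shuffling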
train_metrics = train_one_epoch (
epoch , model , loader_train , optimizer , train_loss_fn , args ,
lr_scheduler = lr_scheduler , saver = saver , output_dir = output_dir ,
amp_autocast = amp_autocast , loss_scaler = loss_scaler , model_ema = model_ema , mixup_fn = mixup_fn )
epoch ,
model ,
loader_train ,
optimizer ,
train_loss_fn ,
args ,
lr_scheduler = lr_scheduler ,
saver = saver ,
output_dir = output_dir ,
amp_autocast = amp_autocast ,
loss_scaler = loss_scaler ,
model_ema = model_ema ,
mixup_fn = mixup_fn ,
)
if args . distributed and args . dist_bn in ( ' broadcast ' , ' reduce ' ) :
if args . local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info ( " Distributing BatchNorm running means and vars " )
utils . distribute_bn ( model , args . world_size , args . dist_bn == ' reduce ' )
eval_metrics = validate ( model , loader_eval , validate_loss_fn , args , amp_autocast = amp_autocast )
eval_metrics = validate (
model ,
loader_eval ,
validate_loss_fn ,
args ,
amp_autocast = amp_autocast ,
)
if model_ema is not None and not args . model_ema_force_cpu :
if args . distributed and args . dist_bn in ( ' broadcast ' , ' reduce ' ) :
utils . distribute_bn ( model_ema , args . world_size , args . dist_bn == ' reduce ' )
ema_eval_metrics = validate (
model_ema . module , loader_eval , validate_loss_fn , args , amp_autocast = amp_autocast , log_suffix = ' (EMA) ' )
model_ema . module ,
loader_eval ,
validate_loss_fn ,
args ,
amp_autocast = amp_autocast ,
log_suffix = ' (EMA) ' ,
)
eval_metrics = ema_eval_metrics
if lr_scheduler is not None :
# step LR for next epoch
lr_scheduler . step ( epoch + 1 , eval_metrics [ eval_metric ] )
if output_dir is not None :
lrs = [ param_group [ ' lr ' ] for param_group in optimizer . param_groups ]
utils . update_summary (
epoch , train_metrics , eval_metrics , os . path . join ( output_dir , ' summary.csv ' ) ,
write_header = best_metric is None , log_wandb = args . log_wandb and has_wandb )
epoch ,
train_metrics ,
eval_metrics ,
filename = os . path . join ( output_dir , ' summary.csv ' ) ,
lr = sum ( lrs ) / len ( lrs ) ,
write_header = best_metric is None ,
log_wandb = args . log_wandb and has_wandb ,
)
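# NOTE: the lr written to the summary is the mean across param groups (e.g. when --layer-decay creates per-layer groups)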
if saver is not None :
# save proper checkpoint with eval metric
save_metric = eval_metrics [ eval_metric ]
best_metric , best_epoch = saver . save_checkpoint ( epoch , metric = save_metric )
if lr_scheduler is not None :
# step LR for next epoch
lr_scheduler . step ( epoch + 1 , eval_metrics [ eval_metric ] )
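# the eval metric passed to step() is only consumed by plateau-style schedulers; others ignore it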
except KeyboardInterrupt :
pass
if best_metric is not None :
_logger . info ( ' *** Best metric: {0} (epoch {1} ) ' . format ( best_metric , best_epoch ) )
def train_one_epoch (
epoch , model , loader , optimizer , loss_fn , args ,
lr_scheduler = None , saver = None , output_dir = None , amp_autocast = suppress ,
loss_scaler = None , model_ema = None , mixup_fn = None ) :
epoch ,
model ,
loader ,
optimizer ,
loss_fn ,
args ,
device = torch . device ( ' cuda ' ) ,
lr_scheduler = None ,
saver = None ,
output_dir = None ,
amp_autocast = suppress ,
loss_scaler = None ,
model_ema = None ,
mixup_fn = None
) :
if args . mixup_off_epoch and epoch >= args . mixup_off_epoch :
if args . prefetcher and loader . mixup_enabled :
loader . mixup_enabled = False
@@ -717,13 +831,14 @@ def train_one_epoch(
model . train ( )
end = time . time ( )
last_idx = len ( loader ) - 1
num_updates = epoch * len ( loader )
num_batches_per_epoch = len ( loader )
last_idx = num_batches_per_epoch - 1
num_updates = epoch * num_batches_per_epoch
for batch_idx , ( input , target ) in enumerate ( loader ) :
last_batch = batch_idx == last_idx
data_time_m . update ( time . time ( ) - end )
if not args . prefetcher :
input , target = input . cuda( ) , target . cuda ( )
input , target = input . to( device ) , target . to ( device )
if mixup_fn is not None :
input , target = mixup_fn ( input , target )
if args . channels_last :
@@ -740,21 +855,26 @@ def train_one_epoch(
if loss_scaler is not None :
loss_scaler (
loss , optimizer ,
clip_grad = args . clip_grad , clip_mode = args . clip_mode ,
clip_grad = args . clip_grad ,
clip_mode = args . clip_mode ,
parameters = model_parameters ( model , exclude_head = ' agc ' in args . clip_mode ) ,
create_graph = second_order )
create_graph = second_order
)
else :
loss . backward ( create_graph = second_order )
if args . clip_grad is not None :
utils . dispatch_clip_grad (
model_parameters ( model , exclude_head = ' agc ' in args . clip_mode ) ,
value = args . clip_grad , mode = args . clip_mode )
value = args . clip_grad ,
mode = args . clip_mode
)
optimizer . step ( )
if model_ema is not None :
model_ema . update ( model )
torch . cuda . synchronize ( )
num_updates += 1
batch_time_m . update ( time . time ( ) - end )
if last_batch or batch_idx % args . log_interval == 0 :
@@ -765,7 +885,7 @@ def train_one_epoch(
reduced_loss = utils . reduce_tensor ( loss . data , args . world_size )
losses_m . update ( reduced_loss . item ( ) , input . size ( 0 ) )
if args. local_rank == 0 :
if utils. is_primary ( args ) :
_logger . info (
' Train: {} [ {:>4d} / {} ( {:>3.0f} % )] '
' Loss: {loss.val:#.4g} ( {loss.avg:#.3g} ) '
@@ -781,14 +901,16 @@ def train_one_epoch(
rate = input . size ( 0 ) * args . world_size / batch_time_m . val ,
rate_avg = input . size ( 0 ) * args . world_size / batch_time_m . avg ,
lr = lr ,
data_time = data_time_m ) )
data_time = data_time_m )
)
if args . save_images and output_dir :
torchvision . utils . save_image (
input ,
os . path . join ( output_dir , ' train-batch- %d .jpg ' % batch_idx ) ,
padding = 0 ,
normalize = True )
normalize = True
)
if saver is not None and args . recovery_interval and (
last_batch or ( batch_idx + 1 ) % args . recovery_interval == 0 ) :
@@ -806,7 +928,15 @@ def train_one_epoch(
return OrderedDict ( [ ( ' loss ' , losses_m . avg ) ] )
def validate ( model , loader , loss_fn , args , amp_autocast = suppress , log_suffix = ' ' ) :
def validate (
model ,
loader ,
loss_fn ,
args ,
device = torch . device ( ' cuda ' ) ,
amp_autocast = suppress ,
log_suffix = ' '
) :
batch_time_m = utils . AverageMeter ( )
losses_m = utils . AverageMeter ( )
top1_m = utils . AverageMeter ( )
@@ -820,8 +950,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
for batch_idx , ( input , target ) in enumerate ( loader ) :
last_batch = batch_idx == last_idx
if not args . prefetcher :
input = input . cuda( )
target = target . cuda( )
input = input . to( device )
target = target . to( device )
if args . channels_last :
input = input . contiguous ( memory_format = torch . channels_last )
@@ -846,7 +976,8 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
else :
reduced_loss = loss . data
torch . cuda . synchronize ( )
if device . type == ' cuda ' :
torch . cuda . synchronize ( )
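# synchronize so batch_time below reflects completed GPU work; skipped on non-CUDA devices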
losses_m . update ( reduced_loss . item ( ) , input . size ( 0 ) )
top1_m . update ( acc1 . item ( ) , output . size ( 0 ) )
@@ -854,7 +985,7 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
batch_time_m . update ( time . time ( ) - end )
end = time . time ( )
if args. local_rank == 0 and ( last_batch or batch_idx % args . log_interval == 0 ) :
if utils. is_primary ( args ) and ( last_batch or batch_idx % args . log_interval == 0 ) :
log_name = ' Test ' + log_suffix
_logger . info (
' {0} : [ {1:>4d} / {2} ] '
@@ -862,8 +993,12 @@ def validate(model, loader, loss_fn, args, amp_autocast=suppress, log_suffix='')
' Loss: {loss.val:>7.4f} ( {loss.avg:>6.4f} ) '
' Acc@1: {top1.val:>7.4f} ( {top1.avg:>7.4f} ) '
' Acc@5: {top5.val:>7.4f} ( {top5.avg:>7.4f} ) ' . format (
log_name , batch_idx , last_idx , batch_time = batch_time_m ,
loss = losses_m , top1 = top1_m , top5 = top5_m ) )
log_name , batch_idx , last_idx ,
batch_time = batch_time_m ,
loss = losses_m ,
top1 = top1_m ,
top5 = top5_m )
)
metrics = OrderedDict ( [ ( ' loss ' , losses_m . avg ) , ( ' top1 ' , top1_m . avg ) , ( ' top5 ' , top5_m . avg ) ] )