@@ -29,7 +29,7 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
-parser.add_argument('--num-gpu', type=int, default=1,
-                    help='Number of GPUS to use')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
+parser.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model to torchscript for inference')
 
 
 def _parse_args():
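Usage note (hypothetical dataset path and model choice): the new flag composes with the usual train.py invocation, e.g.

    python train.py /data/imagenet --model resnet50 --torchscript --amp

Scripting is applied inside main() before any DDP wrapping; the assertions added in a later hunk spell out the combinations (Apex AMP, SyncBatchNorm) that it rules out.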
@@ -282,28 +282,36 @@ def main():
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-        if args.distributed and args.num_gpu > 1:
-            _logger.warning(
-                'Using more than one GPU per process in distributed mode is not allowed.Setting num_gpu to 1.')
-            args.num_gpu = 1
     args.device = 'cuda:0'
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
-        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
-    assert args.rank >= 0
-    if args.distributed:
         _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                      % (args.rank, args.world_size))
     else:
-        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
+        _logger.info('Training with a single process on 1 GPU.')
+    assert args.rank >= 0
+
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    if args.amp:
+        # for backwards compat, `--amp` arg tries apex before native amp
+        if has_apex:
+            args.apex_amp = True
+        elif has_native_amp:
+            args.native_amp = True
+    if args.apex_amp and has_apex:
+        use_amp = 'apex'
+    elif args.native_amp and has_native_amp:
+        use_amp = 'native'
+    elif args.apex_amp or args.native_amp:
+        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
+                        "Install NVIDIA apex or upgrade to PyTorch 1.6")
 
     torch.manual_seed(args.seed + args.rank)
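For context, the use_amp value resolved above is consumed further down in main() when the loss scaler and autocast context are created. A minimal sketch of that downstream setup, assuming timm's ApexScaler/NativeScaler helpers from timm.utils and the amp_autocast = suppress default visible in a later hunk:

    if use_amp == 'apex':
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()    # wraps apex's scaled_loss backward pass
    elif use_amp == 'native':
        amp_autocast = torch.cuda.amp.autocast   # mixed-precision forward context
        loss_scaler = NativeScaler()             # wraps torch.cuda.amp.GradScaler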
@@ -327,44 +335,44 @@ def main():
     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
 
+    # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
     if args.aug_splits > 0:
         assert args.aug_splits > 1, 'A split of 1 makes no sense'
         num_aug_splits = args.aug_splits
 
+    # enable split bn (separate bn stats per batch-portion)
     if args.split_bn:
         assert num_aug_splits > 1 or args.resplit
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
-    use_amp = None
-    if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
-    if args.apex_amp and has_apex:
-        use_amp = 'apex'
-    elif args.native_amp and has_native_amp:
-        use_amp = 'native'
-    elif args.apex_amp or args.native_amp:
-        _logger.warning("Neither APEX or native Torch AMP is available, using float32. "
-                        "Install NVIDA apex or upgrade to PyTorch 1.6")
-
-    if args.num_gpu > 1:
-        if use_amp == 'apex':
-            _logger.warning(
-                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
-            use_amp = None
-        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
-        assert not args.channels_last, "Channels last not supported with DP, use DDP."
-    else:
-        model.cuda()
-        if args.channels_last:
-            model = model.to(memory_format=torch.channels_last)
+    # move model to GPU, enable channels last layout if set
+    model.cuda()
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
+
+    # setup synchronized BatchNorm for distributed training
+    if args.distributed and args.sync_bn:
+        assert not args.split_bn
+        if has_apex and use_amp != 'native':
+            # Apex SyncBN preferred unless native amp is activated
+            model = convert_syncbn_model(model)
+        else:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        if args.local_rank == 0:
+            _logger.info(
+                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
+                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
+
+    if args.torchscript:
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
+        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
+        # FIXME I ran into a bug w/ AMP + torchscript + Linear layers
+        model = torch.jit.script(model)
 
     optimizer = create_optimizer(args, model)
 
+    # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
     if use_amp == 'apex':
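Aside on ordering: torch.jit.script has to see the bare module, so the conversion sits before optimizer creation and the DDP wrap; as of PyTorch 1.6 a DDP-wrapped module is not itself scriptable. A rough sketch (local_rank stands in for args.local_rank):

    model = torch.jit.script(model)                     # ok: plain nn.Module
    model = NativeDDP(model, device_ids=[local_rank])   # DDP around the scripted model
    # torch.jit.script(NativeDDP(...)) would fail to compile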
@@ -390,30 +398,17 @@ def main():
             loss_scaler=None if args.no_resume_opt else loss_scaler,
             log_info=args.local_rank == 0)
 
+    # setup exponential moving average of model weights, SWA could be used here too
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume=args.resume)
+        model_ema = ModelEmaV2(
+            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        if args.resume:
+            load_checkpoint(model_ema.module, args.resume, use_ema=True)
 
+    # setup distributed training
     if args.distributed:
-        if args.sync_bn:
-            assert not args.split_bn
-            try:
-                if has_apex and use_amp != 'native':
-                    # Apex SyncBN preferred unless native amp is activated
-                    model = convert_syncbn_model(model)
-                else:
-                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-                if args.local_rank == 0:
-                    _logger.info(
-                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
-                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
-            except Exception as e:
-                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
         if has_apex and use_amp != 'native':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:
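For context, ModelEmaV2 keeps a decayed copy of the weights and, unlike the old ModelEma, no longer handles its own checkpoint resume, hence the explicit load_checkpoint(model_ema.module, args.resume, use_ema=True) above. The update it applies on each call is roughly:

    # inside ModelEmaV2.update(model), approximately:
    with torch.no_grad():
        for ema_v, model_v in zip(ema_model.state_dict().values(), model.state_dict().values()):
            ema_v.copy_(ema_v * decay + (1. - decay) * model_v)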
@@ -425,6 +420,7 @@ def main():
             model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    # setup learning rate schedule and starting epoch
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:
@@ -438,12 +434,22 @@ def main():
     if args.local_rank == 0:
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
+    # create the train and eval datasets
     train_dir = os.path.join(args.data, 'train')
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
     dataset_train = Dataset(train_dir)
 
+    eval_dir = os.path.join(args.data, 'val')
+    if not os.path.isdir(eval_dir):
+        eval_dir = os.path.join(args.data, 'validation')
+        if not os.path.isdir(eval_dir):
+            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
+            exit(1)
+    dataset_eval = Dataset(eval_dir)
+
+    # setup mixup / cutmix
     collate_fn = None
     mixup_fn = None
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
@@ -458,9 +464,11 @@ def main():
         else:
             mixup_fn = Mixup(**mixup_args)
 
+    # wrap dataset in AugMix helper
     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
 
+    # create data loaders w/ augmentation pipeline
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
         train_interpolation = data_config['interpolation']
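For context, when the GPU prefetcher is active the mixing happens inside FastCollateMixup at collate time; otherwise the Mixup callable created above is applied per batch in the train loop, roughly:

    input, target = input.cuda(), target.cuda()
    if mixup_fn is not None:
        input, target = mixup_fn(input, target)   # target becomes soft/mixed labels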
@@ -492,14 +500,6 @@ def main():
         use_multi_epochs_loader=args.use_multi_epochs_loader
     )
 
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = Dataset(eval_dir)
-
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],
@@ -515,6 +515,7 @@ def main():
         pin_memory=args.pin_mem,
     )
 
+    # setup loss function
     if args.jsd:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
@@ -527,6 +528,7 @@ def main():
         train_loss_fn = nn.CrossEntropyLoss().cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()
 
+    # setup checkpoint saver and eval metric tracking
    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
@@ -638,11 +640,11 @@ def train_epoch(
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
             optimizer.step()
 
+        torch.cuda.synchronize()
         if model_ema is not None:
             model_ema.update(model)
-        num_updates += 1
-        torch.cuda.synchronize()
+        num_updates += 1
 
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
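For context, hoisting torch.cuda.synchronize() ahead of the EMA update means batch_time_m measures completed GPU work rather than kernel launches. CUDA calls return as soon as work is queued, so a minimal illustration of the difference:

    start = time.time()
    out = model(x)              # returns once kernels are queued, not finished
    torch.cuda.synchronize()    # block until the GPU has actually finished
    elapsed = time.time() - start   # now a meaningful batch time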
lrl = [ param_group [ ' lr ' ] for param_group in optimizer . param_groups ]