@@ -29,7 +29,7 @@ import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
 from timm.data import Dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, resume_checkpoint, convert_splitbn_model
+from timm.models import create_model, resume_checkpoint, load_checkpoint, convert_splitbn_model
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
 from timm.optim import create_optimizer
@@ -230,8 +230,6 @@ parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
-parser.add_argument('--num-gpu', type=int, default=1,
-                    help='Number of GPUS to use')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -255,6 +253,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)
 parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
+parser.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model to torchscript for inference')
 
 
 def _parse_args():
@@ -282,28 +282,36 @@ def main():
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
         args.distributed = int(os.environ['WORLD_SIZE']) > 1
-        if args.distributed and args.num_gpu > 1:
-            _logger.warning(
-                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
-            args.num_gpu = 1
 
     args.device = 'cuda:0'
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
-        args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
         args.rank = torch.distributed.get_rank()
-    assert args.rank >= 0
-
-    if args.distributed:
         _logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
                      % (args.rank, args.world_size))
     else:
-        _logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
+        _logger.info('Training with a single process on 1 GPU.')
+    assert args.rank >= 0
 
+    # resolve AMP arguments based on PyTorch / Apex availability
+    use_amp = None
+    if args.amp:
+        # for backwards compat, `--amp` arg tries apex before native amp
+        if has_apex:
+            args.apex_amp = True
+        elif has_native_amp:
+            args.native_amp = True
+    if args.apex_amp and has_apex:
+        use_amp = 'apex'
+    elif args.native_amp and has_native_amp:
+        use_amp = 'native'
+    elif args.apex_amp or args.native_amp:
+        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
+                        "Install NVIDIA apex or upgrade to PyTorch 1.6")
+
     torch.manual_seed(args.seed + args.rank)
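The hunk above moves AMP resolution ahead of model creation and collapses it into a small precedence rule: the legacy `--amp` flag picks Apex if it is importable and native AMP otherwise, while the explicit `--apex-amp` / `--native-amp` flags only take effect when their backend is actually available. A self-contained sketch of that precedence (plain booleans stand in for train.py's real `has_apex` / `has_native_amp` import checks):

# Hypothetical helper mirroring the precedence above; not part of train.py.
def resolve_amp(amp, apex_amp, native_amp, has_apex, has_native_amp):
    if amp:  # legacy flag: prefer apex, fall back to native
        apex_amp, native_amp = has_apex, not has_apex and has_native_amp
    if apex_amp and has_apex:
        return 'apex'
    if native_amp and has_native_amp:
        return 'native'
    return None  # float32 fallback, with a warning in the real script

assert resolve_amp(True, False, False, True, True) == 'apex'
assert resolve_amp(True, False, False, False, True) == 'native'
assert resolve_amp(False, True, False, False, True) is None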
@@ -319,6 +327,7 @@ def main():
         bn_tf=args.bn_tf,
         bn_momentum=args.bn_momentum,
         bn_eps=args.bn_eps,
+        scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint)
 
     if args.local_rank == 0:
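`scriptable=args.torchscript` asks `create_model` to build the network from jit-safe layer variants, which is what lets the later `torch.jit.script(model)` call succeed. A hedged usage sketch; the model name is illustrative and this assumes a timm version that accepts the `scriptable` kwarg:

import torch
from timm.models import create_model

# 'resnet50' is used purely for illustration
model = create_model('resnet50', pretrained=False, scriptable=True)
model.eval()
scripted = torch.jit.script(model)  # succeeds because jit-safe ops were selected
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])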
@@ -327,44 +336,43 @@
     data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
 
     # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
     if args.aug_splits > 0:
         assert args.aug_splits > 1, 'A split of 1 makes no sense'
         num_aug_splits = args.aug_splits
 
     # enable split bn (separate bn stats per batch-portion)
     if args.split_bn:
         assert num_aug_splits > 1 or args.resplit
         model = convert_splitbn_model(model, max(num_aug_splits, 2))
 
-    use_amp = None
-    if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
-    if args.apex_amp and has_apex:
-        use_amp = 'apex'
-    elif args.native_amp and has_native_amp:
-        use_amp = 'native'
-    elif args.apex_amp or args.native_amp:
-        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
-                        "Install NVIDIA apex or upgrade to PyTorch 1.6")
+    # move model to GPU, enable channels last layout if set
+    model.cuda()
+    if args.channels_last:
+        model = model.to(memory_format=torch.channels_last)
 
-    if args.num_gpu > 1:
-        if use_amp == 'apex':
-            _logger.warning(
-                'Apex AMP does not work well with nn.DataParallel, disabling. Use DDP or Torch AMP.')
-            use_amp = None
-        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
-        assert not args.channels_last, "Channels last not supported with DP, use DDP."
-    else:
-        model.cuda()
-        if args.channels_last:
-            model = model.to(memory_format=torch.channels_last)
+    # setup synchronized BatchNorm for distributed training
+    if args.distributed and args.sync_bn:
+        assert not args.split_bn
+        if has_apex and use_amp != 'native':
+            # Apex SyncBN preferred unless native amp is activated
+            model = convert_syncbn_model(model)
+        else:
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        if args.local_rank == 0:
+            _logger.info(
+                'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
+                'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
 
+    if args.torchscript:
+        assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
+        assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
+        model = torch.jit.script(model)
+
     optimizer = create_optimizer(args, model)
 
+    # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing
     loss_scaler = None
     if use_amp == 'apex':
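With `use_amp` resolved, the script keeps a callable `amp_autocast` and an optional `loss_scaler` so the train loop stays backend-agnostic: `suppress`/`None` for float32, an autocast context plus a scaler for native AMP. A minimal sketch of what the native branch amounts to in plain PyTorch (timm wraps the scaler in its own `NativeScaler` helper; this uses `torch.cuda.amp.GradScaler` directly and falls back to float32 without a GPU):

import torch
from contextlib import suppress

use_amp = 'native' if torch.cuda.is_available() else None
amp_autocast = torch.cuda.amp.autocast if use_amp == 'native' else suppress
loss_scaler = torch.cuda.amp.GradScaler() if use_amp == 'native' else None

model = torch.nn.Linear(8, 2).cuda() if use_amp else torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device='cuda' if use_amp else 'cpu')
target = torch.randint(0, 2, (4,), device=x.device)

with amp_autocast():
    loss = torch.nn.functional.cross_entropy(model(x), target)
if loss_scaler is not None:
    loss_scaler.scale(loss).backward()  # scale loss to avoid fp16 underflow
    loss_scaler.step(optimizer)         # unscales grads, then optimizer.step()
    loss_scaler.update()
else:
    loss.backward()
    optimizer.step()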
@@ -390,30 +398,17 @@ def main():
         loss_scaler=None if args.no_resume_opt else loss_scaler,
         log_info=args.local_rank == 0)
 
+    # setup exponential moving average of model weights, SWA could be used here too
     model_ema = None
     if args.model_ema:
         # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
-        model_ema = ModelEma(
-            model,
-            decay=args.model_ema_decay,
-            device='cpu' if args.model_ema_force_cpu else '',
-            resume=args.resume)
+        model_ema = ModelEmaV2(
+            model, decay=args.model_ema_decay, device='cpu' if args.model_ema_force_cpu else None)
+        if args.resume:
+            load_checkpoint(model_ema.module, args.resume, use_ema=True)
 
+    # setup distributed training
     if args.distributed:
-        if args.sync_bn:
-            assert not args.split_bn
-            try:
-                if has_apex and use_amp != 'native':
-                    # Apex SyncBN preferred unless native amp is activated
-                    model = convert_syncbn_model(model)
-                else:
-                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-                if args.local_rank == 0:
-                    _logger.info(
-                        'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
-                        'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
-            except Exception as e:
-                _logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
         if has_apex and use_amp != 'native':
             # Apex DDP preferred unless native amp is activated
             if args.local_rank == 0:
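The EMA rework above replaces `ModelEma` (which re-loaded its state from the resume checkpoint itself) with `ModelEmaV2` plus an explicit `load_checkpoint(..., use_ema=True)` call, and exposes the averaged weights as `.module`. The update performed each step is a plain exponential moving average over parameters and buffers; a self-contained sketch of that rule (not timm's actual class):

import copy
import torch

class SimpleEma:
    """Keep an exponential moving average of a model's parameters/buffers."""
    def __init__(self, model, decay=0.9998):
        self.module = copy.deepcopy(model).eval()  # EMA copy, never trained
        self.decay = decay
        for p in self.module.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        for ema_v, model_v in zip(self.module.state_dict().values(),
                                  model.state_dict().values()):
            if ema_v.dtype.is_floating_point:
                ema_v.mul_(self.decay).add_(model_v, alpha=1. - self.decay)
            else:
                ema_v.copy_(model_v)  # e.g. integer buffers are copied as-is

net = torch.nn.Linear(4, 4)
ema = SimpleEma(net)
net.weight.data += 1.
ema.update(net)  # EMA weights move a small step toward the live weights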
@@ -425,6 +420,7 @@ def main():
             model = NativeDDP(model, device_ids=[args.local_rank])  # can use device str in Torch >= 1.1
         # NOTE: EMA model does not need to be wrapped by DDP
 
+    # setup learning rate schedule and starting epoch
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     start_epoch = 0
     if args.start_epoch is not None:
@@ -438,12 +434,22 @@ def main():
     if args.local_rank == 0:
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
+    # create the train and eval datasets
     train_dir = os.path.join(args.data, 'train')
     if not os.path.exists(train_dir):
         _logger.error('Training folder does not exist at: {}'.format(train_dir))
         exit(1)
     dataset_train = Dataset(train_dir)
 
+    eval_dir = os.path.join(args.data, 'val')
+    if not os.path.isdir(eval_dir):
+        eval_dir = os.path.join(args.data, 'validation')
+        if not os.path.isdir(eval_dir):
+            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
+            exit(1)
+    dataset_eval = Dataset(eval_dir)
+
+    # setup mixup / cutmix
     collate_fn = None
     mixup_fn = None
     mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
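`mixup_active` above gates timm's `Mixup` / `FastCollateMixup` helpers. The transform they implement blends each sample with another one from the batch using a weight drawn from a Beta distribution, and blends the one-hot targets with the same weight. A bare-bones batch-mode sketch (timm's version adds cutmix, per-element modes, label smoothing, and a fast-collate path):

import numpy as np
import torch

def simple_mixup(x, target, num_classes, alpha=0.2):
    # one mixing coefficient for the whole batch, as in batch-mode mixup
    lam = float(np.random.beta(alpha, alpha))
    x_mixed = lam * x + (1. - lam) * x.flip(0)  # pair each sample with the reversed batch
    y = torch.nn.functional.one_hot(target, num_classes).float()
    y_mixed = lam * y + (1. - lam) * y.flip(0)
    return x_mixed, y_mixed

x = torch.randn(8, 3, 32, 32)
target = torch.randint(0, 10, (8,))
x_mixed, y_mixed = simple_mixup(x, target, num_classes=10)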
@@ -458,9 +464,11 @@ def main():
     else:
         mixup_fn = Mixup(**mixup_args)
 
+    # wrap dataset in AugMix helper
     if num_aug_splits > 1:
         dataset_train = AugMixDataset(dataset_train, num_splits=num_aug_splits)
 
+    # create data loaders w/ augmentation pipeline
     train_interpolation = args.train_interpolation
     if args.no_aug or not train_interpolation:
         train_interpolation = data_config['interpolation']
@@ -492,14 +500,6 @@ def main():
         use_multi_epochs_loader=args.use_multi_epochs_loader
     )
 
-    eval_dir = os.path.join(args.data, 'val')
-    if not os.path.isdir(eval_dir):
-        eval_dir = os.path.join(args.data, 'validation')
-        if not os.path.isdir(eval_dir):
-            _logger.error('Validation folder does not exist at: {}'.format(eval_dir))
-            exit(1)
-    dataset_eval = Dataset(eval_dir)
-
     loader_eval = create_loader(
         dataset_eval,
         input_size=data_config['input_size'],
@@ -515,6 +515,7 @@ def main():
         pin_memory=args.pin_mem,
     )
 
+    # setup loss function
     if args.jsd:
         assert num_aug_splits > 1  # JSD only valid with aug splits set
         train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing).cuda()
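The JSD branch above pairs the AugMix splits with a Jensen-Shannon consistency loss that also applies label smoothing; the surrounding (unchanged) code otherwise selects `LabelSmoothingCrossEntropy` or plain cross-entropy. For reference, a self-contained sketch of the label-smoothing part alone (equivalent in spirit, not timm's exact implementation):

import torch
import torch.nn.functional as F

def smoothed_cross_entropy(logits, target, smoothing=0.1):
    # (1 - eps) weight on the true-class NLL, eps spread uniformly over classes
    logp = F.log_softmax(logits, dim=-1)
    nll = -logp.gather(-1, target.unsqueeze(1)).squeeze(1)
    smooth = -logp.mean(dim=-1)
    return ((1. - smoothing) * nll + smoothing * smooth).mean()

logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))
loss = smoothed_cross_entropy(logits, target)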
@@ -527,6 +528,7 @@ def main():
         train_loss_fn = nn.CrossEntropyLoss().cuda()
     validate_loss_fn = nn.CrossEntropyLoss().cuda()
 
+    # setup checkpoint saver and eval metric tracking
     eval_metric = args.eval_metric
     best_metric = None
     best_epoch = None
@@ -568,7 +570,7 @@ def main():
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
             ema_eval_metrics = validate(
-                model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
             eval_metrics = ema_eval_metrics
 
         if lr_scheduler is not None:
@@ -638,11 +640,11 @@ def train_epoch(
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
             optimizer.step()
 
-        torch.cuda.synchronize()
         if model_ema is not None:
             model_ema.update(model)
-        num_updates += 1
 
+        torch.cuda.synchronize()
+        num_updates += 1
         batch_time_m.update(time.time() - end)
         if last_batch or batch_idx % args.log_interval == 0:
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]