|
|
@ -40,6 +40,7 @@ import torch.nn as nn
|
|
|
|
import torchvision.utils
|
|
|
|
import torchvision.utils
|
|
|
|
|
|
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# The first arg parser parses out only the --config argument, this argument is used to
|
|
|
|
# The first arg parser parses out only the --config argument, this argument is used to
|
|
|
@ -232,7 +233,7 @@ def main():
|
|
|
|
if 'WORLD_SIZE' in os.environ:
|
|
|
|
if 'WORLD_SIZE' in os.environ:
|
|
|
|
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
|
|
|
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
|
|
|
if args.distributed and args.num_gpu > 1:
|
|
|
|
if args.distributed and args.num_gpu > 1:
|
|
|
|
logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
|
|
|
logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
|
|
|
args.num_gpu = 1
|
|
|
|
args.num_gpu = 1
|
|
|
|
|
|
|
|
|
|
|
|
args.device = 'cuda:0'
|
|
|
|
args.device = 'cuda:0'
|
|
|
@ -248,10 +249,10 @@ def main():
|
|
|
|
assert args.rank >= 0
|
|
|
|
assert args.rank >= 0
|
|
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
if args.distributed:
|
|
|
|
logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
|
|
|
logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
|
|
|
% (args.rank, args.world_size))
|
|
|
|
% (args.rank, args.world_size))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
|
|
|
logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
|
|
|
|
|
|
|
|
|
|
|
torch.manual_seed(args.seed + args.rank)
|
|
|
|
torch.manual_seed(args.seed + args.rank)
|
|
|
|
|
|
|
|
|
|
|
@ -270,7 +271,7 @@ def main():
|
|
|
|
checkpoint_path=args.initial_checkpoint)
|
|
|
|
checkpoint_path=args.initial_checkpoint)
|
|
|
|
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('Model %s created, param count: %d' %
|
|
|
|
logger.info('Model %s created, param count: %d' %
|
|
|
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
|
|
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
|
|
|
|
|
|
|
|
|
|
|
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
|
|
|
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
|
|
@ -286,7 +287,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
if args.num_gpu > 1:
|
|
|
|
if args.num_gpu > 1:
|
|
|
|
if args.amp:
|
|
|
|
if args.amp:
|
|
|
|
logging.warning(
|
|
|
|
logger.warning(
|
|
|
|
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
|
|
|
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
|
|
|
args.amp = False
|
|
|
|
args.amp = False
|
|
|
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
@ -300,7 +301,7 @@ def main():
|
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
|
|
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
|
|
|
use_amp = True
|
|
|
|
use_amp = True
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('NVIDIA APEX {}. AMP {}.'.format(
|
|
|
|
logger.info('NVIDIA APEX {}. AMP {}.'.format(
|
|
|
|
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
|
|
|
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
|
|
|
|
|
|
|
|
|
|
|
# optionally resume from a checkpoint
|
|
|
|
# optionally resume from a checkpoint
|
|
|
@ -311,11 +312,11 @@ def main():
|
|
|
|
if resume_state and not args.no_resume_opt:
|
|
|
|
if resume_state and not args.no_resume_opt:
|
|
|
|
if 'optimizer' in resume_state:
|
|
|
|
if 'optimizer' in resume_state:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('Restoring Optimizer state from checkpoint')
|
|
|
|
logger.info('Restoring Optimizer state from checkpoint')
|
|
|
|
optimizer.load_state_dict(resume_state['optimizer'])
|
|
|
|
optimizer.load_state_dict(resume_state['optimizer'])
|
|
|
|
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
|
|
|
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
|
|
|
logger.info('Restoring NVIDIA AMP state from checkpoint')
|
|
|
|
amp.load_state_dict(resume_state['amp'])
|
|
|
|
amp.load_state_dict(resume_state['amp'])
|
|
|
|
del resume_state
|
|
|
|
del resume_state
|
|
|
|
|
|
|
|
|
|
|
@ -337,16 +338,16 @@ def main():
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
|
|
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info(
|
|
|
|
logger.info(
|
|
|
|
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
|
|
|
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
|
|
|
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
|
|
|
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
|
|
|
except Exception as e:
|
|
|
|
except Exception as e:
|
|
|
|
logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
|
|
|
logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
|
|
|
if has_apex:
|
|
|
|
if has_apex:
|
|
|
|
model = DDP(model, delay_allreduce=True)
|
|
|
|
model = DDP(model, delay_allreduce=True)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
|
|
|
|
logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
|
|
|
|
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
|
|
|
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
|
|
|
# NOTE: EMA model does not need to be wrapped by DDP
|
|
|
|
# NOTE: EMA model does not need to be wrapped by DDP
|
|
|
|
|
|
|
|
|
|
|
@ -361,11 +362,11 @@ def main():
|
|
|
|
lr_scheduler.step(start_epoch)
|
|
|
|
lr_scheduler.step(start_epoch)
|
|
|
|
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info('Scheduled epochs: {}'.format(num_epochs))
|
|
|
|
logger.info('Scheduled epochs: {}'.format(num_epochs))
|
|
|
|
|
|
|
|
|
|
|
|
train_dir = os.path.join(args.data, 'train')
|
|
|
|
train_dir = os.path.join(args.data, 'train')
|
|
|
|
if not os.path.exists(train_dir):
|
|
|
|
if not os.path.exists(train_dir):
|
|
|
|
logging.error('Training folder does not exist at: {}'.format(train_dir))
|
|
|
|
logger.error('Training folder does not exist at: {}'.format(train_dir))
|
|
|
|
exit(1)
|
|
|
|
exit(1)
|
|
|
|
dataset_train = Dataset(train_dir)
|
|
|
|
dataset_train = Dataset(train_dir)
|
|
|
|
|
|
|
|
|
|
|
@ -404,7 +405,7 @@ def main():
|
|
|
|
if not os.path.isdir(eval_dir):
|
|
|
|
if not os.path.isdir(eval_dir):
|
|
|
|
eval_dir = os.path.join(args.data, 'validation')
|
|
|
|
eval_dir = os.path.join(args.data, 'validation')
|
|
|
|
if not os.path.isdir(eval_dir):
|
|
|
|
if not os.path.isdir(eval_dir):
|
|
|
|
logging.error('Validation folder does not exist at: {}'.format(eval_dir))
|
|
|
|
logger.error('Validation folder does not exist at: {}'.format(eval_dir))
|
|
|
|
exit(1)
|
|
|
|
exit(1)
|
|
|
|
dataset_eval = Dataset(eval_dir)
|
|
|
|
dataset_eval = Dataset(eval_dir)
|
|
|
|
|
|
|
|
|
|
|
@ -468,7 +469,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info("Distributing BatchNorm running means and vars")
|
|
|
|
logger.info("Distributing BatchNorm running means and vars")
|
|
|
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
|
|
|
|
|
|
|
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
@ -499,7 +500,7 @@ def main():
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
pass
|
|
|
|
pass
|
|
|
|
if best_metric is not None:
|
|
|
|
if best_metric is not None:
|
|
|
|
logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
|
|
|
logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(
|
|
|
|
def train_epoch(
|
|
|
@ -559,7 +560,7 @@ def train_epoch(
|
|
|
|
losses_m.update(reduced_loss.item(), input.size(0))
|
|
|
|
losses_m.update(reduced_loss.item(), input.size(0))
|
|
|
|
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
logging.info(
|
|
|
|
logger.info(
|
|
|
|
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
|
|
|
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
|
|
|
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
|
|
|
|
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
|
|
|
|
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
|
|
|
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
|
|
@ -647,7 +648,7 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
|
|
|
end = time.time()
|
|
|
|
end = time.time()
|
|
|
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
|
|
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
|
|
|
log_name = 'Test' + log_suffix
|
|
|
|
log_name = 'Test' + log_suffix
|
|
|
|
logging.info(
|
|
|
|
logger.info(
|
|
|
|
'{0}: [{1:>4d}/{2}] '
|
|
|
|
'{0}: [{1:>4d}/{2}] '
|
|
|
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
|
|
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
|
|
|