|
|
@ -6,12 +6,13 @@ from datetime import datetime
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
from apex import amp
|
|
|
|
from apex import amp
|
|
|
|
from apex.parallel import DistributedDataParallel as DDP
|
|
|
|
from apex.parallel import DistributedDataParallel as DDP
|
|
|
|
|
|
|
|
from apex.parallel import convert_syncbn_model
|
|
|
|
has_apex = True
|
|
|
|
has_apex = True
|
|
|
|
except ImportError:
|
|
|
|
except ImportError:
|
|
|
|
has_apex = False
|
|
|
|
has_apex = False
|
|
|
|
|
|
|
|
|
|
|
|
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
|
|
|
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
|
|
|
|
from models import create_model, resume_checkpoint
|
|
|
|
from models import create_model, resume_checkpoint, load_checkpoint
|
|
|
|
from utils import *
|
|
|
|
from utils import *
|
|
|
|
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
|
|
|
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
|
|
|
from optim import create_optimizer
|
|
|
|
from optim import create_optimizer
|
|
|
@ -41,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|
|
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
|
|
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
|
|
|
parser.add_argument('--pretrained', action='store_true', default=False,
|
|
|
|
parser.add_argument('--pretrained', action='store_true', default=False,
|
|
|
|
help='Start with pretrained version of specified network (if avail)')
|
|
|
|
help='Start with pretrained version of specified network (if avail)')
|
|
|
|
parser.add_argument('--img-size', type=int, default=224, metavar='N',
|
|
|
|
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
|
|
|
help='Image patch size (default: 224)')
|
|
|
|
help='Image patch size (default: None => model default)')
|
|
|
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
|
|
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
|
|
|
help='Override mean pixel value of dataset')
|
|
|
|
help='Override mean pixel value of dataset')
|
|
|
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
|
|
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
|
|
@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
|
|
|
|
help='BatchNorm momentum override (if not None)')
|
|
|
|
help='BatchNorm momentum override (if not None)')
|
|
|
|
parser.add_argument('--bn-eps', type=float, default=None,
|
|
|
|
parser.add_argument('--bn-eps', type=float, default=None,
|
|
|
|
help='BatchNorm epsilon override (if not None)')
|
|
|
|
help='BatchNorm epsilon override (if not None)')
|
|
|
|
|
|
|
|
parser.add_argument('--model-ema', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Enable tracking moving average of model weights')
|
|
|
|
|
|
|
|
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
|
|
|
|
|
|
|
|
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
|
|
|
|
|
|
|
|
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
|
|
|
|
|
|
|
|
help='decay factor for model weights moving average (default: 0.9998)')
|
|
|
|
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
|
|
|
parser.add_argument('--seed', type=int, default=42, metavar='S',
|
|
|
|
help='random seed (default: 42)')
|
|
|
|
help='random seed (default: 42)')
|
|
|
|
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
|
|
|
|
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
|
|
|
|
help='how many batches to wait before logging training status')
|
|
|
|
help='how many batches to wait before logging training status')
|
|
|
|
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
|
|
|
|
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
|
|
|
|
help='how many batches to wait before writing recovery checkpoint')
|
|
|
|
help='how many batches to wait before writing recovery checkpoint')
|
|
|
|
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
|
|
|
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
|
|
|
|
help='how many training processes to use (default: 1)')
|
|
|
|
help='how many training processes to use (default: 1)')
|
|
|
@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
|
|
|
|
help='save images of input bathes every log interval for debugging')
|
|
|
|
help='save images of input bathes every log interval for debugging')
|
|
|
|
parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
help='use NVIDIA amp for mixed precision training')
|
|
|
|
help='use NVIDIA amp for mixed precision training')
|
|
|
|
|
|
|
|
parser.add_argument('--sync-bn', action='store_true',
|
|
|
|
|
|
|
|
help='enabling apex sync BN.')
|
|
|
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
@ -131,36 +140,24 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
args.device = 'cuda:0'
|
|
|
|
args.device = 'cuda:0'
|
|
|
|
args.world_size = 1
|
|
|
|
args.world_size = 1
|
|
|
|
r = -1
|
|
|
|
args.rank = 0 # global rank
|
|
|
|
if args.distributed:
|
|
|
|
if args.distributed:
|
|
|
|
args.num_gpu = 1
|
|
|
|
args.num_gpu = 1
|
|
|
|
args.device = 'cuda:%d' % args.local_rank
|
|
|
|
args.device = 'cuda:%d' % args.local_rank
|
|
|
|
torch.cuda.set_device(args.local_rank)
|
|
|
|
torch.cuda.set_device(args.local_rank)
|
|
|
|
torch.distributed.init_process_group(backend='nccl',
|
|
|
|
torch.distributed.init_process_group(
|
|
|
|
init_method='env://')
|
|
|
|
backend='nccl', init_method='env://')
|
|
|
|
args.world_size = torch.distributed.get_world_size()
|
|
|
|
args.world_size = torch.distributed.get_world_size()
|
|
|
|
r = torch.distributed.get_rank()
|
|
|
|
args.rank = torch.distributed.get_rank()
|
|
|
|
|
|
|
|
assert args.rank >= 0
|
|
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
if args.distributed:
|
|
|
|
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
|
|
|
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
|
|
|
% (r, args.world_size))
|
|
|
|
% (args.rank, args.world_size))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
print('Training with a single process on %d GPUs.' % args.num_gpu)
|
|
|
|
print('Training with a single process on %d GPUs.' % args.num_gpu)
|
|
|
|
|
|
|
|
|
|
|
|
# FIXME seed handling for multi-process distributed?
|
|
|
|
torch.manual_seed(args.seed + args.rank)
|
|
|
|
torch.manual_seed(args.seed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output_dir = ''
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
|
|
|
if args.output:
|
|
|
|
|
|
|
|
output_base = args.output
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
output_base = './output'
|
|
|
|
|
|
|
|
exp_name = '-'.join([
|
|
|
|
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
|
|
|
|
args.model,
|
|
|
|
|
|
|
|
str(args.img_size)])
|
|
|
|
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model = create_model(
|
|
|
|
model = create_model(
|
|
|
|
args.model,
|
|
|
|
args.model,
|
|
|
@ -191,6 +188,8 @@ def main():
|
|
|
|
args.amp = False
|
|
|
|
args.amp = False
|
|
|
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
|
|
|
|
if args.distributed and args.sync_bn and has_apex:
|
|
|
|
|
|
|
|
model = convert_syncbn_model(model)
|
|
|
|
model.cuda()
|
|
|
|
model.cuda()
|
|
|
|
|
|
|
|
|
|
|
|
optimizer = create_optimizer(args, model)
|
|
|
|
optimizer = create_optimizer(args, model)
|
|
|
@ -205,8 +204,20 @@ def main():
|
|
|
|
use_amp = False
|
|
|
|
use_amp = False
|
|
|
|
print('AMP disabled')
|
|
|
|
print('AMP disabled')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_ema = None
|
|
|
|
|
|
|
|
if args.model_ema:
|
|
|
|
|
|
|
|
model_ema = ModelEma(
|
|
|
|
|
|
|
|
model,
|
|
|
|
|
|
|
|
decay=args.model_ema_decay,
|
|
|
|
|
|
|
|
device='cpu' if args.model_ema_force_cpu else '',
|
|
|
|
|
|
|
|
resume=args.resume)
|
|
|
|
|
|
|
|
|
|
|
|
if args.distributed:
|
|
|
|
if args.distributed:
|
|
|
|
model = DDP(model, delay_allreduce=True)
|
|
|
|
model = DDP(model, delay_allreduce=True)
|
|
|
|
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
|
|
|
|
# must also distribute EMA model to allow validation
|
|
|
|
|
|
|
|
model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
|
|
|
|
|
|
|
|
model_ema.ema_has_module = True
|
|
|
|
|
|
|
|
|
|
|
|
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
|
|
|
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
|
|
|
|
if start_epoch > 0:
|
|
|
|
if start_epoch > 0:
|
|
|
@ -271,12 +282,21 @@ def main():
|
|
|
|
validate_loss_fn = train_loss_fn
|
|
|
|
validate_loss_fn = train_loss_fn
|
|
|
|
|
|
|
|
|
|
|
|
eval_metric = args.eval_metric
|
|
|
|
eval_metric = args.eval_metric
|
|
|
|
|
|
|
|
best_metric = None
|
|
|
|
|
|
|
|
best_epoch = None
|
|
|
|
saver = None
|
|
|
|
saver = None
|
|
|
|
if output_dir:
|
|
|
|
output_dir = ''
|
|
|
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
|
|
|
output_base = args.output if args.output else './output'
|
|
|
|
|
|
|
|
exp_name = '-'.join([
|
|
|
|
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
|
|
|
|
args.model,
|
|
|
|
|
|
|
|
str(data_config['input_size'][-1])
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
|
|
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
|
|
|
best_metric = None
|
|
|
|
|
|
|
|
best_epoch = None
|
|
|
|
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
for epoch in range(start_epoch, num_epochs):
|
|
|
|
for epoch in range(start_epoch, num_epochs):
|
|
|
|
if args.distributed:
|
|
|
|
if args.distributed:
|
|
|
@ -284,10 +304,15 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
train_metrics = train_epoch(
|
|
|
|
train_metrics = train_epoch(
|
|
|
|
epoch, model, loader_train, optimizer, train_loss_fn, args,
|
|
|
|
epoch, model, loader_train, optimizer, train_loss_fn, args,
|
|
|
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
|
|
|
|
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
|
|
|
|
|
|
|
|
use_amp=use_amp, model_ema=model_ema)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
|
|
|
|
|
|
|
|
|
|
|
eval_metrics = validate(
|
|
|
|
if model_ema is not None and not args.model_ema_force_cpu:
|
|
|
|
model, loader_eval, validate_loss_fn, args)
|
|
|
|
ema_eval_metrics = validate(
|
|
|
|
|
|
|
|
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
|
|
|
|
|
|
|
|
eval_metrics = ema_eval_metrics
|
|
|
|
|
|
|
|
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
lr_scheduler.step(epoch, eval_metrics[eval_metric])
|
|
|
|
lr_scheduler.step(epoch, eval_metrics[eval_metric])
|
|
|
@ -298,15 +323,12 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
if saver is not None:
|
|
|
|
if saver is not None:
|
|
|
|
# save proper checkpoint with eval metric
|
|
|
|
# save proper checkpoint with eval metric
|
|
|
|
best_metric, best_epoch = saver.save_checkpoint({
|
|
|
|
save_metric = eval_metrics[eval_metric]
|
|
|
|
'epoch': epoch + 1,
|
|
|
|
best_metric, best_epoch = saver.save_checkpoint(
|
|
|
|
'arch': args.model,
|
|
|
|
model, optimizer, args,
|
|
|
|
'state_dict': model.state_dict(),
|
|
|
|
|
|
|
|
'optimizer': optimizer.state_dict(),
|
|
|
|
|
|
|
|
'args': args,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
epoch=epoch + 1,
|
|
|
|
epoch=epoch + 1,
|
|
|
|
metric=eval_metrics[eval_metric])
|
|
|
|
model_ema=model_ema,
|
|
|
|
|
|
|
|
metric=save_metric)
|
|
|
|
|
|
|
|
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
except KeyboardInterrupt:
|
|
|
|
pass
|
|
|
|
pass
|
|
|
@ -316,7 +338,7 @@ def main():
|
|
|
|
|
|
|
|
|
|
|
|
def train_epoch(
|
|
|
|
def train_epoch(
|
|
|
|
epoch, model, loader, optimizer, loss_fn, args,
|
|
|
|
epoch, model, loader, optimizer, loss_fn, args,
|
|
|
|
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
|
|
|
|
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
|
|
|
|
|
|
|
|
|
|
|
|
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
|
|
|
|
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
|
|
|
|
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
|
|
|
|
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
|
|
|
@ -359,6 +381,8 @@ def train_epoch(
|
|
|
|
optimizer.step()
|
|
|
|
optimizer.step()
|
|
|
|
|
|
|
|
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
torch.cuda.synchronize()
|
|
|
|
|
|
|
|
if model_ema is not None:
|
|
|
|
|
|
|
|
model_ema.update(model)
|
|
|
|
num_updates += 1
|
|
|
|
num_updates += 1
|
|
|
|
|
|
|
|
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
@ -394,18 +418,11 @@ def train_epoch(
|
|
|
|
padding=0,
|
|
|
|
padding=0,
|
|
|
|
normalize=True)
|
|
|
|
normalize=True)
|
|
|
|
|
|
|
|
|
|
|
|
if args.local_rank == 0 and (
|
|
|
|
if saver is not None and args.recovery_interval and (
|
|
|
|
saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
|
|
|
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
|
|
|
|
save_epoch = epoch + 1 if last_batch else epoch
|
|
|
|
save_epoch = epoch + 1 if last_batch else epoch
|
|
|
|
saver.save_recovery({
|
|
|
|
saver.save_recovery(
|
|
|
|
'epoch': save_epoch,
|
|
|
|
model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
|
|
|
|
'arch': args.model,
|
|
|
|
|
|
|
|
'state_dict': model.state_dict(),
|
|
|
|
|
|
|
|
'optimizer': optimizer.state_dict(),
|
|
|
|
|
|
|
|
'args': args,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
epoch=save_epoch,
|
|
|
|
|
|
|
|
batch_idx=batch_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
if lr_scheduler is not None:
|
|
|
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
|
|
|
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
|
|
|
@ -415,7 +432,7 @@ def train_epoch(
|
|
|
|
return OrderedDict([('loss', losses_m.avg)])
|
|
|
|
return OrderedDict([('loss', losses_m.avg)])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validate(model, loader, loss_fn, args):
|
|
|
|
def validate(model, loader, loss_fn, args, log_suffix=''):
|
|
|
|
batch_time_m = AverageMeter()
|
|
|
|
batch_time_m = AverageMeter()
|
|
|
|
losses_m = AverageMeter()
|
|
|
|
losses_m = AverageMeter()
|
|
|
|
prec1_m = AverageMeter()
|
|
|
|
prec1_m = AverageMeter()
|
|
|
@ -461,12 +478,13 @@ def validate(model, loader, loss_fn, args):
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
batch_time_m.update(time.time() - end)
|
|
|
|
end = time.time()
|
|
|
|
end = time.time()
|
|
|
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
|
|
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
|
|
|
print('Test: [{0}/{1}]\t'
|
|
|
|
log_name = 'Test' + log_suffix
|
|
|
|
|
|
|
|
print('{0}: [{1}/{2}]\t'
|
|
|
|
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
|
|
|
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
|
|
|
|
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
|
|
|
|
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
|
|
|
|
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
|
|
|
|
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
|
|
|
|
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
|
|
|
|
batch_idx, last_idx,
|
|
|
|
log_name, batch_idx, last_idx,
|
|
|
|
batch_time=batch_time_m, loss=losses_m,
|
|
|
|
batch_time=batch_time_m, loss=losses_m,
|
|
|
|
top1=prec1_m, top5=prec5_m))
|
|
|
|
top1=prec1_m, top5=prec5_m))
|
|
|
|
|
|
|
|
|
|
|
@ -475,12 +493,5 @@ def validate(model, loader, loss_fn, args):
|
|
|
|
return metrics
|
|
|
|
return metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reduce_tensor(tensor, n):
|
|
|
|
|
|
|
|
rt = tensor.clone()
|
|
|
|
|
|
|
|
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
|
|
|
|
|
|
|
|
rt /= n
|
|
|
|
|
|
|
|
return rt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
if __name__ == '__main__':
|
|
|
|
main()
|
|
|
|
main()
|
|
|
|