@@ -1,19 +1,29 @@
 import argparse
 import time
 from collections import OrderedDict
 from datetime import datetime
 
+try:
+    from apex import amp
+    from apex.parallel import DistributedDataParallel as DDP
+    has_apex = True
+except ImportError:
+    has_apex = False
+
 from data import *
 from models import model_factory
 from utils import *
 from optim import Nadam, AdaBound
+from loss import LabelSmoothingCrossEntropy
 import scheduler
 
 import torch
-import torch.nn
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
+import torch.distributed as dist
 import torchvision.utils
 
 torch.backends.cudnn.benchmark = True
@@ -45,6 +55,8 @@ parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                     help='epoch interval to decay LR')
+parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+                    help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
@@ -53,10 +65,14 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.1)')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M',
+parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
+                    help='label smoothing (default: 0.1)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -73,22 +89,51 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
+parser.add_argument('--amp', action='store_true', default=False,
+                    help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument("--local_rank", default=0, type=int)
 
 
 def main():
     args = parser.parse_args()
 
-    if args.output:
-        output_base = args.output
+    args.distributed = False
+    if 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+        if args.distributed and args.num_gpu > 1:
+            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
+            args.num_gpu = 1
+
+    args.device = 'cuda:0'
+    args.world_size = 1
+    r = -1
+    if args.distributed:
+        args.device = 'cuda:%d' % args.local_rank
+        torch.cuda.set_device(args.local_rank)
+        torch.distributed.init_process_group(backend='nccl',
+                                             init_method='env://')
+        args.world_size = torch.distributed.get_world_size()
+        r = torch.distributed.get_rank()
+
+    if args.distributed:
+        print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
+              % (args.world_size, r))
     else:
-        output_base = './output'
-    exp_name = '-'.join([
-        datetime.now().strftime("%Y%m%d-%H%M%S"),
-        args.model,
-        str(args.img_size)])
-    output_dir = get_outdir(output_base, 'train', exp_name)
+        print('Training with a single process with %d GPUs.' % args.num_gpu)
+
+    output_dir = ''
+    if args.local_rank == 0:
+        if args.output:
+            output_base = args.output
+        else:
+            output_base = './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(args.img_size)])
+        output_dir = get_outdir(output_base, 'train', exp_name)
 
     batch_size = args.batch_size
     torch.manual_seed(args.seed)
@@ -103,10 +148,11 @@ def main():
         batch_size=batch_size,
         is_training=True,
         use_prefetcher=True,
-        random_erasing=0.5,
+        random_erasing=0.3,
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
     )
 
     dataset_eval = Dataset(os.path.join(args.data, 'validation'))
@@ -120,6 +166,7 @@ def main():
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
     )
 
     model = model_factory.create_model(
@@ -156,28 +203,53 @@ def main():
             print("=> no checkpoint found at '{}'".format(args.resume))
             return False
 
+    if args.smoothing:
+        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    else:
+        train_loss_fn = nn.CrossEntropyLoss().cuda()
+        validate_loss_fn = train_loss_fn
+
     if args.num_gpu > 1:
-        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
+        if args.amp:
+            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
+                  'Use distributed mode for multi-GPU AMP.')
+            args.amp = False
+        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model.cuda()
 
-    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
-
     optimizer = create_optimizer(args, model.parameters())
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)
 
+    if has_apex and args.amp:
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
+        use_amp = True
+        print('AMP enabled')
+    else:
+        use_amp = False
+        print('AMP disabled')
+
+    if args.distributed:
+        model = DDP(model, delay_allreduce=True)
+
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
-    print(num_epochs)
+    if args.local_rank == 0:
+        print('Scheduled epochs: ', num_epochs)
 
-    saver = CheckpointSaver(checkpoint_dir=output_dir)
-
+    saver = None
+    if output_dir:
+        saver = CheckpointSaver(checkpoint_dir=output_dir)
     best_loss = None
     try:
         for epoch in range(start_epoch, num_epochs):
+            if args.distributed:
+                loader_train.sampler.set_epoch(epoch)
+
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
 
             eval_metrics = validate(
                 model, loader_eval, validate_loss_fn, args)
@@ -189,16 +261,17 @@ def main():
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                 write_header=best_loss is None)
 
             # save proper checkpoint with eval metric
-            best_loss = saver.save_checkpoint({
-                'epoch': epoch + 1,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=epoch + 1,
-                metric=eval_metrics['eval_loss'])
+            if saver is not None:
+                best_loss = saver.save_checkpoint({
+                    'epoch': epoch + 1,
+                    'arch': args.model,
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'args': args,
+                    },
+                    epoch=epoch + 1,
+                    metric=eval_metrics['eval_loss'])
 
     except KeyboardInterrupt:
         pass
@@ -207,7 +280,7 @@ def main():
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir=''):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
 
     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
@@ -225,10 +298,15 @@ def train_epoch(
         output = model(input)
 
         loss = loss_fn(output, target)
-        losses_m.update(loss.item(), input.size(0))
+        if not args.distributed:
+            losses_m.update(loss.item(), input.size(0))
 
        optimizer.zero_grad()
-        loss.backward()
+        if use_amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
         optimizer.step()
 
        torch.cuda.synchronize()
@@ -239,30 +317,36 @@ def train_epoch(
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)
-            print('Train: {} [{}/{} ({:.0f}%)] '
-                  'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
-                  'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
-                  '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
-                  'LR: {lr:.4f} '
-                  'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                epoch,
-                batch_idx, len(loader),
-                100. * batch_idx / last_idx,
-                loss=losses_m,
-                batch_time=batch_time_m,
-                rate=input.size(0) / batch_time_m.val,
-                rate_avg=input.size(0) / batch_time_m.avg,
-                lr=lr,
-                data_time=data_time_m))
 
-            if args.save_images:
-                torchvision.utils.save_image(
-                    input,
-                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                    padding=0,
-                    normalize=True)
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                losses_m.update(reduced_loss.item(), input.size(0))
+
+            if args.local_rank == 0:
+                print('Train: {} [{}/{} ({:.0f}%)] '
+                      'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
+                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
+                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
+                      'LR: {lr:.4f} '
+                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                    epoch,
+                    batch_idx, len(loader),
+                    100. * batch_idx / last_idx,
+                    loss=losses_m,
+                    batch_time=batch_time_m,
+                    rate=input.size(0) * args.world_size / batch_time_m.val,
+                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
+                    lr=lr,
+                    data_time=data_time_m))
+                if args.save_images and output_dir:
+                    torchvision.utils.save_image(
+                        input,
+                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                        padding=0,
+                        normalize=True)
 
-        if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
+        if args.local_rank == 0 and (
+                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
                 'epoch': save_epoch,
@@ -309,15 +393,22 @@ def validate(model, loader, loss_fn, args):
             loss = loss_fn(output, target)
             prec1, prec5 = accuracy(output, target, topk=(1, 5))
 
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                prec1 = reduce_tensor(prec1, args.world_size)
+                prec5 = reduce_tensor(prec5, args.world_size)
+            else:
+                reduced_loss = loss.data
+
             torch.cuda.synchronize()
 
-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(reduced_loss.item(), input.size(0))
             prec1_m.update(prec1.item(), output.size(0))
             prec5_m.update(prec5.item(), output.size(0))
 
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if last_batch or batch_idx % args.log_interval == 0:
+            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
@@ -362,6 +453,7 @@ def create_optimizer(args, parameters):
 
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
+    #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
             optimizer,
@@ -369,8 +461,8 @@ def create_scheduler(args, optimizer):
             t_mul=1.0,
             lr_min=1e-5,
             decay_rate=args.decay_rate,
-            warmup_lr_init=1e-4,
-            warmup_t=0,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )
@@ -381,8 +473,8 @@ def create_scheduler(args, optimizer):
             t_initial=num_epochs,
             t_mul=1.0,
             lr_min=1e-5,
-            warmup_lr_init=.001,
-            warmup_t=3,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )
@@ -392,9 +484,18 @@ def create_scheduler(args, optimizer):
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
         )
     return lr_scheduler, num_epochs
 
 
+def reduce_tensor(tensor, n):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.reduce_op.SUM)
+    rt /= n
+    return rt
+
+
 if __name__ == '__main__':
     main()
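
For reference, a minimal self-contained sketch of the apex AMP step this patch wires into train_epoch: initialize the model/optimizer pair once with amp.initialize, then wrap backward in amp.scale_loss. The toy model, data, and opt_level below are illustrative only (the patch itself uses opt_level='O3' behind the --amp flag and the has_apex guard), and the snippet assumes NVIDIA apex and a CUDA device are available.

# Illustrative sketch, not part of the patch; assumes NVIDIA apex + CUDA are installed.
import torch
import torch.nn as nn
from apex import amp

model = nn.Linear(16, 2).cuda()                  # stand-in for model_factory.create_model(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss().cuda()

# One-time setup, as in main(); 'O1' chosen here only to keep the toy example stable.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

input = torch.randn(8, 16).cuda()
target = torch.randint(0, 2, (8,)).cuda()

# Per-batch pattern from train_epoch: scale the loss so fp16 gradients do not underflow.
output = model(input)
loss = loss_fn(output, target)
optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()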
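
Similarly, a small sketch of what the new reduce_tensor helper accomplishes during distributed logging and validation: each process contributes its local metric tensor, all_reduce sums them in place, and dividing by the world size gives every rank the same average. The gloo backend, CPU tensor, and file name below are only to keep the example simple (the patch uses nccl with one GPU per process), and newer PyTorch spells the reduction op dist.ReduceOp.SUM rather than the older dist.reduce_op.SUM used in the patch.

# Illustrative sketch, not part of the patch; launch one copy per process,
# e.g. torchrun --nproc_per_node=2 reduce_demo.py (file name is hypothetical).
import torch
import torch.distributed as dist

def reduce_tensor(tensor, n):
    # Same idea as the helper added by the patch: sum across processes, then average.
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt

dist.init_process_group(backend='gloo', init_method='env://')
world_size = dist.get_world_size()

local_loss = torch.tensor([float(dist.get_rank())])   # pretend per-process loss
avg_loss = reduce_tensor(local_loss, world_size)      # identical value on every rank
print('rank %d sees averaged loss %.3f' % (dist.get_rank(), avg_loss.item()))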