Distributed (multi-process) train, multi-gpu single process train, and NVIDIA AMP support

pull/1/head
Ross Wightman 6 years ago
parent 6f9a0c8ef2
commit 5180f94c7e

@@ -91,15 +91,13 @@ class RandomErasingTorch:
     def __init__(
             self,
             probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
-            per_pixel=False, rand_color=False,
-            device='cuda'):
+            per_pixel=False, rand_color=False):
         self.probability = probability
         self.sl = sl
         self.sh = sh
         self.min_aspect = min_aspect
         self.per_pixel = per_pixel  # per pixel random, bounded by [pl, ph]
         self.rand_color = rand_color  # per block random, bounded by [pl, ph]
-        self.device = device

     def __call__(self, batch):
         batch_size, chan, img_h, img_w = batch.size()

@@ -115,15 +113,15 @@ class RandomErasingTorch:
                 h = int(round(math.sqrt(target_area * aspect_ratio)))
                 w = int(round(math.sqrt(target_area / aspect_ratio)))
                 if self.rand_color:
-                    c = torch.empty((chan, 1, 1), dtype=batch.dtype, device=self.device).normal_()
+                    c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
                 elif not self.per_pixel:
-                    c = torch.zeros((chan, 1, 1), dtype=batch.dtype, device=self.device)
+                    c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
                 if w < img_w and h < img_h:
                     top = random.randint(0, img_h - h)
                     left = random.randint(0, img_w - w)
                     if self.per_pixel:
                         img[:, top:top + h, left:left + w] = torch.empty(
-                            (chan, h, w), dtype=batch.dtype, device=self.device).normal_()
+                            (chan, h, w), dtype=batch.dtype).cuda().normal_()
                     else:
                         img[:, top:top + h, left:left + w] = c
                     break
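Note: the `target_area` and `aspect_ratio` values consumed by the `h`/`w` computation in this hunk are sampled outside the lines shown. A minimal sketch of what that sampling typically looks like, using the `sl`/`sh` area bounds and `min_aspect` from `__init__` (an illustrative assumption, not lines from this commit):

import math
import random

def sample_erase_box(img_h, img_w, sl=0.02, sh=1/3, min_aspect=0.3):
    # draw an erase area as a fraction of the image, then an aspect ratio,
    # and convert them to a height/width the same way the hunk above does
    target_area = random.uniform(sl, sh) * img_h * img_w
    aspect_ratio = random.uniform(min_aspect, 1 / min_aspect)
    h = int(round(math.sqrt(target_area * aspect_ratio)))
    w = int(round(math.sqrt(target_area / aspect_ratio)))
    return h, w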

@@ -18,25 +18,19 @@ class PrefetchLoader:
     def __init__(self,
                  loader,
-                 fp16=False,
                  random_erasing=0.,
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD):
         self.loader = loader
-        self.fp16 = fp16
         self.random_erasing = random_erasing
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
         if random_erasing:
             self.random_erasing = RandomErasingTorch(
-                probability=random_erasing, per_pixel=True)
+                probability=random_erasing, per_pixel=False)
         else:
             self.random_erasing = None
-        if self.fp16:
-            self.mean = self.mean.half()
-            self.std = self.std.half()

     def __iter__(self):
         stream = torch.cuda.Stream()
         first = True

@@ -45,10 +39,7 @@ class PrefetchLoader:
             with torch.cuda.stream(stream):
                 next_input = next_input.cuda(non_blocking=True)
                 next_target = next_target.cuda(non_blocking=True)
-                if self.fp16:
-                    next_input = next_input.half()
-                else:
-                    next_input = next_input.float()
+                next_input = next_input.float()
                 next_input = next_input.sub_(self.mean).div_(self.std)
                 if self.random_erasing is not None:
                     next_input = self.random_erasing(next_input)

@@ -67,6 +58,10 @@ class PrefetchLoader:
     def __len__(self):
         return len(self.loader)

+    @property
+    def sampler(self):
+        return self.loader.sampler
+

 def create_loader(
         dataset,

@@ -78,6 +73,7 @@ def create_loader(
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
         num_workers=1,
+        distributed=False,
 ):

     if is_training:

@@ -95,11 +91,16 @@ def create_loader(
     dataset.transform = transform

+    sampler = None
+    if distributed:
+        sampler = tdata.distributed.DistributedSampler(dataset)
+
     loader = tdata.DataLoader(
         dataset,
         batch_size=batch_size,
-        shuffle=is_training,
+        shuffle=sampler is None and is_training,
         num_workers=num_workers,
+        sampler=sampler,
         collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
     )
     if use_prefetcher:

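The new sampler plumbing has two coupled requirements: `shuffle` must be off whenever an explicit sampler is passed (hence `shuffle=sampler is None and is_training`), and each process has to call `set_epoch` once per epoch so the `DistributedSampler` reshuffles its shard consistently across ranks; the exposed `sampler` property is what lets the training loop reach through the prefetcher to do that. A minimal sketch of the consuming side, assuming an already-initialized process group and a generic dataset (not the repo's code):

import torch.utils.data as tdata

def make_distributed_loader(dataset, batch_size, is_training=True, distributed=True):
    # same pattern as create_loader above: sampler and shuffle are mutually exclusive
    sampler = tdata.distributed.DistributedSampler(dataset) if distributed else None
    return tdata.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None and is_training,
        sampler=sampler)

# per epoch, before iterating:
#   loader.sampler.set_epoch(epoch)  # re-seeds the shard shuffle on every process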
@@ -0,0 +1,5 @@
+#!/bin/bash
+NUM_PROC=$1
+shift
+python -m torch.distributed.launch --nproc_per_node=$NUM_PROC dtrain.py "$@"

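With this helper in place, a typical launch would look something like `./distributed_train.sh 4 /path/to/imagenet --model resnet50 --amp` (the dataset path and flags here are only illustrative). `torch.distributed.launch` starts `NUM_PROC` copies of `dtrain.py`, exports `WORLD_SIZE` and `RANK` to each process, and passes `--local_rank` on the command line, which is exactly what the environment and argument handling added to `main()` below expects.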
@@ -1,19 +1,29 @@
 import argparse
 import time
 from collections import OrderedDict
 from datetime import datetime

+try:
+    from apex import amp
+    from apex.parallel import DistributedDataParallel as DDP
+    has_apex = True
+except ImportError:
+    has_apex = False
+
 from data import *
 from models import model_factory
 from utils import *
 from optim import Nadam, AdaBound
+from loss import LabelSmoothingCrossEntropy
 import scheduler

 import torch
-import torch.nn
+import torch.nn as nn
 import torch.nn.functional as F
 import torch.optim as optim
 import torch.utils.data as data
+import torch.distributed as dist
 import torchvision.utils

 torch.backends.cudnn.benchmark = True
@@ -45,6 +55,8 @@ parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
                     help='manual epoch number (useful on restarts)')
 parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
                     help='epoch interval to decay LR')
+parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
+                    help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
 parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',

@@ -53,10 +65,14 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
                     help='Dropout rate (default: 0.1)')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                     help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M',
+parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
                     help='weight decay (default: 0.0001)')
+parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
+                    help='label smoothing (default: 0.1)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -73,22 +89,51 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
+parser.add_argument('--amp', action='store_true', default=False,
+                    help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument("--local_rank", default=0, type=int)


 def main():
     args = parser.parse_args()

-    if args.output:
-        output_base = args.output
-    else:
-        output_base = './output'
-    exp_name = '-'.join([
-        datetime.now().strftime("%Y%m%d-%H%M%S"),
-        args.model,
-        str(args.img_size)])
-    output_dir = get_outdir(output_base, 'train', exp_name)
+    args.distributed = False
+    if 'WORLD_SIZE' in os.environ:
+        args.distributed = int(os.environ['WORLD_SIZE']) > 1
+        if args.distributed and args.num_gpu > 1:
+            print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
+            args.num_gpu = 1
+
+    args.device = 'cuda:0'
+    args.world_size = 1
+    r = -1
+    if args.distributed:
+        args.device = 'cuda:%d' % args.local_rank
+        torch.cuda.set_device(args.local_rank)
+        torch.distributed.init_process_group(backend='nccl',
+                                             init_method='env://')
+        args.world_size = torch.distributed.get_world_size()
+        r = torch.distributed.get_rank()
+
+    if args.distributed:
+        print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
+              % (args.world_size, r))
+    else:
+        print('Training with a single process with %d GPUs.' % args.num_gpu)
+
+    output_dir = ''
+    if args.local_rank == 0:
+        if args.output:
+            output_base = args.output
+        else:
+            output_base = './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(args.img_size)])
+        output_dir = get_outdir(output_base, 'train', exp_name)

     batch_size = args.batch_size
     torch.manual_seed(args.seed)
@@ -103,10 +148,11 @@ def main():
         batch_size=batch_size,
         is_training=True,
         use_prefetcher=True,
-        random_erasing=0.5,
+        random_erasing=0.3,
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
     )

     dataset_eval = Dataset(os.path.join(args.data, 'validation'))

@@ -120,6 +166,7 @@ def main():
         mean=data_mean,
         std=data_std,
         num_workers=args.workers,
+        distributed=args.distributed,
     )

     model = model_factory.create_model(
@@ -156,28 +203,53 @@ def main():
             print("=> no checkpoint found at '{}'".format(args.resume))
             return False

+    if args.smoothing:
+        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
+        validate_loss_fn = nn.CrossEntropyLoss().cuda()
+    else:
+        train_loss_fn = nn.CrossEntropyLoss().cuda()
+        validate_loss_fn = train_loss_fn
+
     if args.num_gpu > 1:
-        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
+        if args.amp:
+            print('Warning: AMP does not work well with nn.DataParallel, disabling. '
+                  'Use distributed mode for multi-GPU AMP.')
+            args.amp = False
+        model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model.cuda()

-    train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
-
     optimizer = create_optimizer(args, model.parameters())
     if optimizer_state is not None:
         optimizer.load_state_dict(optimizer_state)

+    if has_apex and args.amp:
+        model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
+        use_amp = True
+        print('AMP enabled')
+    else:
+        use_amp = False
+        print('AMP disabled')
+
+    if args.distributed:
+        model = DDP(model, delay_allreduce=True)
+
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
-    print(num_epochs)
+    if args.local_rank == 0:
+        print('Scheduled epochs: ', num_epochs)

-    saver = CheckpointSaver(checkpoint_dir=output_dir)
+    saver = None
+    if output_dir:
+        saver = CheckpointSaver(checkpoint_dir=output_dir)
     best_loss = None
     try:
         for epoch in range(start_epoch, num_epochs):
+            if args.distributed:
+                loader_train.sampler.set_epoch(epoch)
+
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)

             eval_metrics = validate(
                 model, loader_eval, validate_loss_fn, args)
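The `LabelSmoothingCrossEntropy` selected above comes from the repo's `loss` module, which is not part of this diff. A typical implementation of such a loss, shown only as a sketch of what `--smoothing` does (mix the true-class log-likelihood with a uniform term so targets are no longer hard one-hot labels), not necessarily the repo's exact code:

import torch.nn as nn
import torch.nn.functional as F

class LabelSmoothingCrossEntropy(nn.Module):
    # sketch of a standard label-smoothing cross entropy
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        # weighted mix of the true-class NLL and the uniform (smoothed) term
        return (self.confidence * nll_loss + self.smoothing * smooth_loss).mean()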
@@ -189,16 +261,17 @@ def main():
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
                 write_header=best_loss is None)

-            # save proper checkpoint with eval metric
-            best_loss = saver.save_checkpoint({
-                'epoch': epoch + 1,
-                'arch': args.model,
-                'state_dict': model.state_dict(),
-                'optimizer': optimizer.state_dict(),
-                'args': args,
-                },
-                epoch=epoch + 1,
-                metric=eval_metrics['eval_loss'])
+            if saver is not None:
+                # save proper checkpoint with eval metric
+                best_loss = saver.save_checkpoint({
+                    'epoch': epoch + 1,
+                    'arch': args.model,
+                    'state_dict': model.state_dict(),
+                    'optimizer': optimizer.state_dict(),
+                    'args': args,
+                    },
+                    epoch=epoch + 1,
+                    metric=eval_metrics['eval_loss'])

     except KeyboardInterrupt:
         pass
@@ -207,7 +280,7 @@ def main():

 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir=''):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False):

     batch_time_m = AverageMeter()
     data_time_m = AverageMeter()
@@ -225,10 +298,15 @@ def train_epoch(
         output = model(input)

         loss = loss_fn(output, target)
-        losses_m.update(loss.item(), input.size(0))
+        if not args.distributed:
+            losses_m.update(loss.item(), input.size(0))

         optimizer.zero_grad()
-        loss.backward()
+        if use_amp:
+            with amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
+        else:
+            loss.backward()
         optimizer.step()

         torch.cuda.synchronize()
@@ -239,30 +317,36 @@ def train_epoch(
             lrl = [param_group['lr'] for param_group in optimizer.param_groups]
             lr = sum(lrl) / len(lrl)

-            print('Train: {} [{}/{} ({:.0f}%)] '
-                  'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
-                  'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
-                  '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
-                  'LR: {lr:.4f} '
-                  'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
-                epoch,
-                batch_idx, len(loader),
-                100. * batch_idx / last_idx,
-                loss=losses_m,
-                batch_time=batch_time_m,
-                rate=input.size(0) / batch_time_m.val,
-                rate_avg=input.size(0) / batch_time_m.avg,
-                lr=lr,
-                data_time=data_time_m))
-
-            if args.save_images:
-                torchvision.utils.save_image(
-                    input,
-                    os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
-                    padding=0,
-                    normalize=True)
-
-        if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                losses_m.update(reduced_loss.item(), input.size(0))
+
+            if args.local_rank == 0:
+                print('Train: {} [{}/{} ({:.0f}%)] '
+                      'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
+                      'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
+                      '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
+                      'LR: {lr:.4f} '
+                      'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
+                    epoch,
+                    batch_idx, len(loader),
+                    100. * batch_idx / last_idx,
+                    loss=losses_m,
+                    batch_time=batch_time_m,
+                    rate=input.size(0) * args.world_size / batch_time_m.val,
+                    rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
+                    lr=lr,
+                    data_time=data_time_m))

+                if args.save_images and output_dir:
+                    torchvision.utils.save_image(
+                        input,
+                        os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
+                        padding=0,
+                        normalize=True)
+
+        if args.local_rank == 0 and (
+                saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
             save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
                 'epoch': save_epoch,
@@ -309,15 +393,22 @@ def validate(model, loader, loss_fn, args):
             loss = loss_fn(output, target)
             prec1, prec5 = accuracy(output, target, topk=(1, 5))

+            if args.distributed:
+                reduced_loss = reduce_tensor(loss.data, args.world_size)
+                prec1 = reduce_tensor(prec1, args.world_size)
+                prec5 = reduce_tensor(prec5, args.world_size)
+            else:
+                reduced_loss = loss.data
+
             torch.cuda.synchronize()

-            losses_m.update(loss.item(), input.size(0))
+            losses_m.update(reduced_loss.item(), input.size(0))
             prec1_m.update(prec1.item(), output.size(0))
             prec5_m.update(prec5.item(), output.size(0))

             batch_time_m.update(time.time() - end)
             end = time.time()
-            if last_batch or batch_idx % args.log_interval == 0:
+            if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
@@ -362,6 +453,7 @@ def create_optimizer(args, parameters):

 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
+    #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
             optimizer,

@@ -369,8 +461,8 @@ def create_scheduler(args, optimizer):
             t_mul=1.0,
             lr_min=1e-5,
             decay_rate=args.decay_rate,
-            warmup_lr_init=1e-4,
-            warmup_t=0,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )

@@ -381,8 +473,8 @@ def create_scheduler(args, optimizer):
             t_initial=num_epochs,
             t_mul=1.0,
             lr_min=1e-5,
-            warmup_lr_init=.001,
-            warmup_t=3,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
         )

@@ -392,9 +484,18 @@ def create_scheduler(args, optimizer):
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
         )
     return lr_scheduler, num_epochs


+def reduce_tensor(tensor, n):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.reduce_op.SUM)
+    rt /= n
+    return rt
+
+
 if __name__ == '__main__':
     main()

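The `reduce_tensor` helper added above is what keeps the rank-0 logging meaningful in distributed mode: every process contributes its local value, the sum is all-reduced, and dividing by the world size leaves the same global average on every rank. A small self-contained illustration of the arithmetic, assuming an already-initialized process group (`ReduceOp` is the newer spelling of the `reduce_op` alias used in the diff):

import torch
import torch.distributed as dist

def average_across_processes(value, world_size):
    # same arithmetic as reduce_tensor: SUM all-reduce, then divide by world size
    rt = value.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / world_size

# e.g. with world_size=4 and per-process losses of 0.9, 1.1, 1.0 and 1.2,
# every rank ends up with (0.9 + 1.1 + 1.0 + 1.2) / 4 = 1.05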
@@ -9,6 +9,7 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.parallel
+from collections import OrderedDict

 from models import create_model
 from data import Dataset, create_loader, get_model_meanstd

@@ -60,7 +61,14 @@ def main():
         print("=> loading checkpoint '{}'".format(args.checkpoint))
         checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            model.load_state_dict(checkpoint['state_dict'])
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict'].items():
+                if k.startswith('module'):
+                    name = k[7:]  # remove `module.`
+                else:
+                    name = k
+                new_state_dict[name] = v
+            model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
         print("=> loaded checkpoint '{}'".format(args.checkpoint))

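For context on the last change: checkpoints written from a model wrapped in `nn.DataParallel` or `DistributedDataParallel` carry a `module.` prefix on every parameter key, so loading them into an unwrapped model fails with a key mismatch. The loop in the diff strips that prefix; a more compact equivalent (a sketch, not the repo's code) would be:

from collections import OrderedDict

def strip_module_prefix(state_dict):
    # drop the 'module.' prefix added by DataParallel / DistributedDataParallel wrappers
    return OrderedDict(
        (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items())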