Distributed (multi-process) train, multi-gpu single process train, and NVIDIA AMP support

pull/1/head
Ross Wightman 5 years ago
parent 6f9a0c8ef2
commit 5180f94c7e

@ -91,15 +91,13 @@ class RandomErasingTorch:
def __init__(
self,
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
per_pixel=False, rand_color=False,
device='cuda'):
per_pixel=False, rand_color=False):
self.probability = probability
self.sl = sl
self.sh = sh
self.min_aspect = min_aspect
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
self.rand_color = rand_color # per block random, bounded by [pl, ph]
self.device = device
def __call__(self, batch):
batch_size, chan, img_h, img_w = batch.size()
@ -115,15 +113,15 @@ class RandomErasingTorch:
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if self.rand_color:
c = torch.empty((chan, 1, 1), dtype=batch.dtype, device=self.device).normal_()
c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
elif not self.per_pixel:
c = torch.zeros((chan, 1, 1), dtype=batch.dtype, device=self.device)
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
if self.per_pixel:
img[:, top:top + h, left:left + w] = torch.empty(
(chan, h, w), dtype=batch.dtype, device=self.device).normal_()
(chan, h, w), dtype=batch.dtype).cuda().normal_()
else:
img[:, top:top + h, left:left + w] = c
break

@ -18,25 +18,19 @@ class PrefetchLoader:
def __init__(self,
loader,
fp16=False,
random_erasing=0.,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
self.loader = loader
self.fp16 = fp16
self.random_erasing = random_erasing
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
if random_erasing:
self.random_erasing = RandomErasingTorch(
probability=random_erasing, per_pixel=True)
probability=random_erasing, per_pixel=False)
else:
self.random_erasing = None
if self.fp16:
self.mean = self.mean.half()
self.std = self.std.half()
def __iter__(self):
stream = torch.cuda.Stream()
first = True
@ -45,10 +39,7 @@ class PrefetchLoader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half()
else:
next_input = next_input.float()
next_input = next_input.float()
next_input = next_input.sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
@ -67,6 +58,10 @@ class PrefetchLoader:
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
def create_loader(
dataset,
@ -78,6 +73,7 @@ def create_loader(
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
):
if is_training:
@ -95,11 +91,16 @@ def create_loader(
dataset.transform = transform
sampler = None
if distributed:
sampler = tdata.distributed.DistributedSampler(dataset)
loader = tdata.DataLoader(
dataset,
batch_size=batch_size,
shuffle=is_training,
shuffle=sampler is None and is_training,
num_workers=num_workers,
sampler=sampler,
collate_fn=fast_collate if use_prefetcher else tdata.dataloader.default_collate,
)
if use_prefetcher:

@ -0,0 +1,5 @@
#!/bin/bash
NUM_PROC=$1
shift
python -m torch.distributed.launch --nproc_per_node=$NUM_PROC dtrain.py "$@"

@ -1,19 +1,29 @@
import argparse
import time
from collections import OrderedDict
from datetime import datetime
try:
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
has_apex = True
except ImportError:
has_apex = False
from data import *
from models import model_factory
from utils import *
from optim import Nadam, AdaBound
from loss import LabelSmoothingCrossEntropy
import scheduler
import torch
import torch.nn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist
import torchvision.utils
torch.backends.cudnn.benchmark = True
@ -45,6 +55,8 @@ parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-epochs', type=int, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
@ -53,10 +65,14 @@ parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
help='Dropout rate (default: 0.1)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
help='warmup learning rate (default: 0.0001)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.0005, metavar='M',
parser.add_argument('--weight-decay', type=float, default=0.0001, metavar='M',
help='weight decay (default: 0.0001)')
parser.add_argument('--smoothing', type=float, default=0.1, metavar='M',
help='label smoothing (default: 0.1)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@ -73,22 +89,51 @@ parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument("--local_rank", default=0, type=int)
def main():
args = parser.parse_args()
if args.output:
output_base = args.output
args.distributed = False
if 'WORLD_SIZE' in os.environ:
args.distributed = int(os.environ['WORLD_SIZE']) > 1
if args.distributed and args.num_gpu > 1:
print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
args.num_gpu = 1
args.device = 'cuda:0'
args.world_size = 1
r = -1
if args.distributed:
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
args.world_size = torch.distributed.get_world_size()
r = torch.distributed.get_rank()
if args.distributed:
print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
% (args.world_size, r))
else:
output_base = './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
str(args.img_size)])
output_dir = get_outdir(output_base, 'train', exp_name)
print('Training with a single process with %d GPUs.' % args.num_gpu)
output_dir = ''
if args.local_rank == 0:
if args.output:
output_base = args.output
else:
output_base = './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
str(args.img_size)])
output_dir = get_outdir(output_base, 'train', exp_name)
batch_size = args.batch_size
torch.manual_seed(args.seed)
@ -103,10 +148,11 @@ def main():
batch_size=batch_size,
is_training=True,
use_prefetcher=True,
random_erasing=0.5,
random_erasing=0.3,
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
dataset_eval = Dataset(os.path.join(args.data, 'validation'))
@ -120,6 +166,7 @@ def main():
mean=data_mean,
std=data_std,
num_workers=args.workers,
distributed=args.distributed,
)
model = model_factory.create_model(
@ -156,28 +203,53 @@ def main():
print("=> no checkpoint found at '{}'".format(args.resume))
return False
if args.smoothing:
train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
validate_loss_fn = nn.CrossEntropyLoss().cuda()
else:
train_loss_fn = nn.CrossEntropyLoss().cuda()
validate_loss_fn = train_loss_fn
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
if args.amp:
print('Warning: AMP does not work well with nn.DataParallel, disabling. '
'Use distributed mode for multi-GPU AMP.')
args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
model.cuda()
train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda()
optimizer = create_optimizer(args, model.parameters())
if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state)
if has_apex and args.amp:
model, optimizer = amp.initialize(model, optimizer, opt_level='O3')
use_amp = True
print('AMP enabled')
else:
use_amp = False
print('AMP disabled')
if args.distributed:
model = DDP(model, delay_allreduce=True)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
print(num_epochs)
if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs)
saver = CheckpointSaver(checkpoint_dir=output_dir)
saver = None
if output_dir:
saver = CheckpointSaver(checkpoint_dir=output_dir)
best_loss = None
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed:
loader_train.sampler.set_epoch(epoch)
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir)
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
eval_metrics = validate(
model, loader_eval, validate_loss_fn, args)
@ -189,16 +261,17 @@ def main():
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_loss is None)
# save proper checkpoint with eval metric
best_loss = saver.save_checkpoint({
'epoch': epoch + 1,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
epoch=epoch + 1,
metric=eval_metrics['eval_loss'])
if saver is not None:
# save proper checkpoint with eval metric
best_loss = saver.save_checkpoint({
'epoch': epoch + 1,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
epoch=epoch + 1,
metric=eval_metrics['eval_loss'])
except KeyboardInterrupt:
pass
@ -207,7 +280,7 @@ def main():
def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=''):
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
batch_time_m = AverageMeter()
data_time_m = AverageMeter()
@ -225,10 +298,15 @@ def train_epoch(
output = model(input)
loss = loss_fn(output, target)
losses_m.update(loss.item(), input.size(0))
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
optimizer.zero_grad()
loss.backward()
if use_amp:
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
optimizer.step()
torch.cuda.synchronize()
@ -239,30 +317,36 @@ def train_epoch(
lrl = [param_group['lr'] for param_group in optimizer.param_groups]
lr = sum(lrl) / len(lrl)
print('Train: {} [{}/{} ({:.0f}%)] '
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
'LR: {lr:.4f} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) / batch_time_m.val,
rate_avg=input.size(0) / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0:
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
losses_m.update(reduced_loss.item(), input.size(0))
if args.local_rank == 0:
print('Train: {} [{}/{} ({:.0f}%)] '
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
'LR: {lr:.4f} '
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx, len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) * args.world_size / batch_time_m.val,
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
lr=lr,
data_time=data_time_m))
if args.save_images and output_dir:
torchvision.utils.save_image(
input,
os.path.join(output_dir, 'train-batch-%d.jpg' % batch_idx),
padding=0,
normalize=True)
if args.local_rank == 0 and (
saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery({
'epoch': save_epoch,
@ -309,15 +393,22 @@ def validate(model, loader, loss_fn, args):
loss = loss_fn(output, target)
prec1, prec5 = accuracy(output, target, topk=(1, 5))
if args.distributed:
reduced_loss = reduce_tensor(loss.data, args.world_size)
prec1 = reduce_tensor(prec1, args.world_size)
prec5 = reduce_tensor(prec5, args.world_size)
else:
reduced_loss = loss.data
torch.cuda.synchronize()
losses_m.update(loss.item(), input.size(0))
losses_m.update(reduced_loss.item(), input.size(0))
prec1_m.update(prec1.item(), output.size(0))
prec5_m.update(prec5.item(), output.size(0))
batch_time_m.update(time.time() - end)
end = time.time()
if last_batch or batch_idx % args.log_interval == 0:
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
@ -362,6 +453,7 @@ def create_optimizer(args, parameters):
def create_scheduler(args, optimizer):
num_epochs = args.epochs
#FIXME expose cycle parms of the scheduler config to arguments
if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler(
optimizer,
@ -369,8 +461,8 @@ def create_scheduler(args, optimizer):
t_mul=1.0,
lr_min=1e-5,
decay_rate=args.decay_rate,
warmup_lr_init=1e-4,
warmup_t=0,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
@ -381,8 +473,8 @@ def create_scheduler(args, optimizer):
t_initial=num_epochs,
t_mul=1.0,
lr_min=1e-5,
warmup_lr_init=.001,
warmup_t=3,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
cycle_limit=1,
t_in_epochs=True,
)
@ -392,9 +484,18 @@ def create_scheduler(args, optimizer):
optimizer,
decay_t=args.decay_epochs,
decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs,
)
return lr_scheduler, num_epochs
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.reduce_op.SUM)
rt /= n
return rt
if __name__ == '__main__':
main()

@ -9,6 +9,7 @@ import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from models import create_model
from data import Dataset, create_loader, get_model_meanstd
@ -60,7 +61,14 @@ def main():
print("=> loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.checkpoint)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict'])
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.checkpoint))

Loading…
Cancel
Save