@@ -364,6 +364,10 @@ def main():
     args.world_size = 1
     args.rank = 0  # global rank
     if args.distributed:
+        if args.cuda_graph:
+            os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0'
+        if 'LOCAL_RANK' in os.environ:
+            args.local_rank = int(os.environ['LOCAL_RANK'])
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
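The four added lines make distributed init capture-safe: NCCL's asynchronous error handling runs a watchdog that polls CUDA events, which is disallowed while a stream is being captured, so it is switched off when CUDA Graphs are requested; `LOCAL_RANK` is also picked up from the environment for `torchrun`-style launchers. Below is a minimal standalone sketch of the same pattern, not the PR's exact code; `use_cuda_graph` is a hypothetical stand-in for `args.cuda_graph`.

```python
# Minimal sketch of capture-safe NCCL process-group setup (illustrative).
import os
import torch

def init_distributed(use_cuda_graph: bool) -> int:
    if use_cuda_graph:
        # Event polling is illegal during stream capture, so the watchdog
        # enabled by async error handling must be off before init.
        os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0'
    local_rank = int(os.environ.get('LOCAL_RANK', 0))  # set by torchrun
    torch.cuda.set_device(local_rank)
    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    return local_rank
```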
@@ -375,6 +379,9 @@ def main():
         _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0
 
+    if args.cuda_graph:
+        _logger.info("Training with CUDA Graph support is enabled.")
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp:
@@ -507,7 +514,13 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
+            # Wrapping the DDP model in a side stream is needed for CUDA Graph workloads.
+            # Since it does no harm in regular eager mode, we use it unconditionally and save an if-statement.
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
+            torch.cuda.current_stream().wait_stream(s)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
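The DDP constructor broadcasts parameters and buffers, i.e. it enqueues collectives on the current stream, and PyTorch's CUDA Graphs documentation recommends constructing DDP on a side stream so the default stream stays clean for later capture. A self-contained sketch of just that pattern, stripped of the timm specifics:

```python
# Side-stream DDP construction; `model` is assumed to already be on the
# target device. Mirrors the hunk above (illustrative argument names).
import torch
from torch.nn.parallel import DistributedDataParallel as NativeDDP

def wrap_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        # The constructor launches broadcast collectives; running them on a
        # side stream keeps the default stream capturable afterwards.
        model = NativeDDP(model, device_ids=[local_rank])
    torch.cuda.current_stream().wait_stream(s)
    return model
```

As the diff comment notes, the extra stream synchronization is harmless in plain eager mode, so the wrapper can be applied unconditionally.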
@@ -750,7 +763,7 @@ def train_one_epoch(
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
 
-    def _step(input_, target_, loss_=None):
+    def _step(input_, target_, loss_=None, scaler_step_and_update=True):
         if args.channels_last:
             input = input_.contiguous(memory_format=torch.channels_last)
         else:
@@ -766,7 +779,8 @@ def train_one_epoch(
                 loss, optimizer,
                 clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order)
+                create_graph=second_order,
+                step_and_update=scaler_step_and_update)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
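The new `scaler_step_and_update` flag exists because `GradScaler.step()` and `GradScaler.update()` involve CPU-side inf/nan checks and scale adjustment that generally cannot be baked into a graph; during capture the scaler only records the scaled backward, and the step/update runs eagerly after each replay (see the next hunk). A hedged sketch of a scaler wrapper with that split follows; timm's actual `NativeScaler` differs in its details.

```python
# Sketch of a loss scaler whose step/update can be deferred, in the spirit
# of the step_and_update flag above (illustrative, not timm's NativeScaler).
import torch

class DeferrableScaler:
    def __init__(self):
        self._scaler = torch.cuda.amp.GradScaler()
        self._pending_optimizer = None

    def __call__(self, loss, optimizer, create_graph=False, step_and_update=True):
        self._scaler.scale(loss).backward(create_graph=create_graph)
        if step_and_update:
            self.step_and_update(optimizer)
        else:
            # During graph capture only the backward is recorded; the
            # non-capturable step/update is left for the caller.
            self._pending_optimizer = optimizer

    def step_and_update(self, optimizer=None):
        optimizer = optimizer if optimizer is not None else self._pending_optimizer
        self._scaler.step(optimizer)   # skips the step if infs were found
        self._scaler.update()          # adjusts the scale for the next iter
```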
@@ -814,13 +828,15 @@ def train_one_epoch(
             g = torch.cuda.CUDAGraph()
             optimizer.zero_grad(set_to_none=True)
             with torch.cuda.graph(g):
-                _step(cg_static_input, cg_static_target, cg_static_loss)
+                _step(cg_static_input, cg_static_target, cg_static_loss, False)
             return g
         elif cg_stage == 'replay':
             cg_static_input.copy_(input)
             cg_static_target.copy_(target)
             cuda_graph.replay()
             loss = cg_static_loss
+            if loss_scaler is not None:
+                loss_scaler.step_and_update()
 
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))
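Putting the pieces together: at the capture stage one whole `_step` is recorded into a `torch.cuda.CUDAGraph` against static input/target/loss tensors (with the scaler step suppressed via the new `False` argument); at replay, each fresh batch is copied into those static tensors, the graph is relaunched, and the scaler step/update runs eagerly afterwards. A toy end-to-end version of the same flow, following the pattern in PyTorch's CUDA Graphs documentation; the model, shapes, and optimizer are illustrative, and plain SGD is used because its step is itself capturable:

```python
# Toy capture/replay training step; assumes fixed shapes throughout.
import torch

device = 'cuda'
model = torch.nn.Linear(16, 4).to(device)
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

static_in = torch.zeros(8, 16, device=device)
static_tgt = torch.zeros(8, dtype=torch.long, device=device)

# Warm up on a side stream before capture, as the docs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss_fn(model(static_in), static_tgt).backward()
        opt.step()
torch.cuda.current_stream().wait_stream(s)

# Capture one full training step into a graph.
g = torch.cuda.CUDAGraph()
opt.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_in), static_tgt)
    static_loss.backward()
    opt.step()

# Replay: refill the static tensors, then relaunch the recorded kernels.
for _ in range(5):
    static_in.copy_(torch.randn(8, 16, device=device))
    static_tgt.copy_(torch.randint(0, 4, (8,), device=device))
    g.replay()
    print(static_loss.item())
```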