update cuda graph mode for AMP and DDP

pull/1271/head
Xiao Wang 3 years ago
parent 9b299cbcf1
commit be54f860d0

@@ -17,7 +17,7 @@ from .clip_grad import dispatch_clip_grad
 class ApexScaler:
     state_dict_key = "amp"
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True):
         with amp.scale_loss(loss, optimizer) as scaled_loss:
             scaled_loss.backward(create_graph=create_graph)
         if clip_grad is not None:
@@ -32,6 +32,8 @@ class ApexScaler:
         if 'load_state_dict' in amp.__dict__:
             amp.load_state_dict(state_dict)
 
+    def step_and_update(self):
+        pass
 
 class NativeScaler:
     state_dict_key = "amp_scaler"
@@ -39,17 +41,25 @@ class NativeScaler:
     def __init__(self):
         self._scaler = torch.cuda.amp.GradScaler()
 
-    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
+    def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True):
         self._scaler.scale(loss).backward(create_graph=create_graph)
+        self._optimizer = optimizer
         if clip_grad is not None:
             assert parameters is not None
             self._scaler.unscale_(optimizer)  # unscale the gradients of optimizer's assigned params in-place
             dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
-        self._scaler.step(optimizer)
-        self._scaler.update()
+        if step_and_update:
+            self._scaler.step(optimizer)
+            self._scaler.update()
 
     def state_dict(self):
         return self._scaler.state_dict()
 
     def load_state_dict(self, state_dict):
         self._scaler.load_state_dict(state_dict)
+
+    def step_and_update(self):
+        # This step must be kept separate from "scaler.scale(loss)": scaler.step()
+        # syncs the GPU with the CPU, which is not allowed during cuda graph capture.
+        self._scaler.step(self._optimizer)
+        self._scaler.update()
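The split that the new step_and_update() enables can be illustrated with plain GradScaler calls. A minimal, self-contained sketch (the toy model and tensors are illustrative, not from this commit):

    import torch

    model = torch.nn.Linear(8, 8).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    x = torch.randn(4, 8, device='cuda')
    with torch.cuda.amp.autocast():
        loss = model(x).square().mean()

    # Graph-capturable half: scaling and backward are pure GPU work.
    scaler.scale(loss).backward()

    # Deferred half: step() reads the found_inf flag back to the CPU
    # (a GPU->CPU sync), which is illegal inside graph capture, so the
    # wrapper above runs these two calls only after graph replay.
    scaler.step(optimizer)
    scaler.update()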

@@ -364,6 +364,10 @@ def main():
         args.world_size = 1
         args.rank = 0  # global rank
     if args.distributed:
+        if args.cuda_graph:
+            os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0'
+        if 'LOCAL_RANK' in os.environ:
+            args.local_rank = int(os.environ['LOCAL_RANK'])
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
         torch.distributed.init_process_group(backend='nccl', init_method='env://')
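Note the ordering this hunk depends on: the environment variable only takes effect if it is set before the process group is created. A minimal sketch (the standalone flag is a stand-in for args.cuda_graph):

    import os
    import torch.distributed as dist

    cuda_graph = True  # stand-in for args.cuda_graph

    if cuda_graph:
        # NCCL's asynchronous error handling is incompatible with cuda
        # graph capture, so disable it before the group is created.
        os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0'
    dist.init_process_group(backend='nccl', init_method='env://')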
@@ -375,6 +379,9 @@ def main():
         _logger.info('Training with a single process on 1 GPUs.')
     assert args.rank >= 0
 
+    if args.cuda_graph:
+        _logger.info("Training with CUDA Graph support is enabled.")
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp:
@@ -507,7 +514,13 @@ def main():
         else:
             if args.local_rank == 0:
                 _logger.info("Using native Torch DistributedDataParallel.")
-            model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
+            # Wrapping the DDP model in a side stream is needed for cuda graph workloads.
+            # Since it does no harm in regular eager mode, we use it unconditionally and save an if-statement.
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
+            torch.cuda.current_stream().wait_stream(s)
         # NOTE: EMA model does not need to be wrapped by DDP
 
     # setup learning rate schedule and starting epoch
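For reference, the side-stream wrap in isolation: DDP construction enqueues CUDA work (e.g. the initial parameter broadcast), and PyTorch's CUDA Graphs guidance recommends issuing it on a side stream so the default (capture) stream stays clean. A self-contained single-process sketch; the address, port, and toy model are placeholders:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as NativeDDP

    dist.init_process_group(
        backend='nccl', init_method='tcp://127.0.0.1:29500', rank=0, world_size=1)
    torch.cuda.set_device(0)
    model = torch.nn.Linear(16, 4).cuda()

    # Build DDP on a side stream, then rejoin the default stream.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model = NativeDDP(model, device_ids=[0])
    torch.cuda.current_stream().wait_stream(s)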
@@ -750,7 +763,7 @@ def train_one_epoch(
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
 
-    def _step(input_, target_, loss_=None):
+    def _step(input_, target_, loss_=None, scaler_step_and_update=True):
         if args.channels_last:
             input = input_.contiguous(memory_format=torch.channels_last)
         else:
@@ -766,7 +779,8 @@ def train_one_epoch(
                 loss, optimizer,
                 clip_grad=args.clip_grad, clip_mode=args.clip_mode,
                 parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
-                create_graph=second_order)
+                create_graph=second_order,
+                step_and_update=scaler_step_and_update)
         else:
             loss.backward(create_graph=second_order)
             if args.clip_grad is not None:
@@ -814,13 +828,15 @@ def train_one_epoch(
             g = torch.cuda.CUDAGraph()
             optimizer.zero_grad(set_to_none=True)
             with torch.cuda.graph(g):
-                _step(cg_static_input, cg_static_target, cg_static_loss)
+                _step(cg_static_input, cg_static_target, cg_static_loss, False)
             return g
         elif cg_stage == 'replay':
             cg_static_input.copy_(input)
             cg_static_target.copy_(target)
             cuda_graph.replay()
             loss = cg_static_loss
+            if loss_scaler is not None:
+                loss_scaler.step_and_update()
 
         if not args.distributed:
             losses_m.update(loss.item(), input.size(0))
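End to end, the capture/replay staging above follows the usual whole-network capture recipe with the optimizer step kept out of the graph. A self-contained sketch under that assumption (the toy model, shapes, and replay count are illustrative; the timm driver code around cg_stage is not shown in this diff):

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(16, 4).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()

    static_input = torch.randn(8, 16, device='cuda')
    static_target = torch.randint(0, 4, (8,), device='cuda')

    def fwd_bwd():
        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(static_input), static_target)
        scaler.scale(loss).backward()
        return loss

    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            fwd_bwd()
    torch.cuda.current_stream().wait_stream(s)

    # Capture forward + backward only; grads land in static buffers
    # that each replay overwrites.
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_loss = fwd_bwd()

    # Replay loop: refill the static inputs, replay, then step eagerly,
    # the equivalent of loss_scaler.step_and_update() above.
    for _ in range(5):
        static_input.copy_(torch.randn(8, 16, device='cuda'))
        static_target.copy_(torch.randint(0, 4, (8,), device='cuda'))
        g.replay()
        scaler.step(optimizer)
        scaler.update()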
