diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index 9e7bddf3..0336e22b 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -17,7 +17,7 @@ from .clip_grad import dispatch_clip_grad class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: @@ -32,6 +32,8 @@ class ApexScaler: if 'load_state_dict' in amp.__dict__: amp.load_state_dict(state_dict) + def step_and_update(self): + pass class NativeScaler: state_dict_key = "amp_scaler" @@ -39,17 +41,25 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True): self._scaler.scale(loss).backward(create_graph=create_graph) + self._optimizer = optimizer if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) - self._scaler.step(optimizer) - self._scaler.update() + if step_and_update: + self._scaler.step(optimizer) + self._scaler.update() def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) + + def step_and_update(self): + # need to separate this step from "scaler.scale(loss)" since scalar.step() syncs GPU with CPU, + # which is not allowed for cuda graph capture + self._scaler.step(self._optimizer) + self._scaler.update() diff --git a/train.py b/train.py index 3171dda7..84d4e629 100755 --- a/train.py +++ b/train.py @@ -364,6 +364,10 @@ def main(): args.world_size = 1 args.rank = 0 # global rank if args.distributed: + if args.cuda_graph: + os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0' + if 'LOCAL_RANK' in os.environ: + args.local_rank = int(os.environ['LOCAL_RANK']) args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') @@ -375,6 +379,9 @@ def main(): _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 + if args.cuda_graph: + _logger.info("Training with CUDA Graph support is enabled.") + # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: @@ -507,7 +514,13 @@ def main(): else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + # Wrap DDP model in a side stream is needed for cuda graph workload. + # Since it does no harm to regular eager mode, we'll just use this with one less if-statement. + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + torch.cuda.current_stream().wait_stream(s) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch @@ -750,7 +763,7 @@ def train_one_epoch( last_idx = len(loader) - 1 num_updates = epoch * len(loader) - def _step(input_, target_, loss_=None): + def _step(input_, target_, loss_=None, scaler_step_and_update=True): if args.channels_last: input = input_.contiguous(memory_format=torch.channels_last) else: @@ -766,7 +779,8 @@ def train_one_epoch( loss, optimizer, clip_grad=args.clip_grad, clip_mode=args.clip_mode, parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), - create_graph=second_order) + create_graph=second_order, + step_and_update=scaler_step_and_update) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: @@ -814,13 +828,15 @@ def train_one_epoch( g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - _step(cg_static_input, cg_static_target, cg_static_loss) + _step(cg_static_input, cg_static_target, cg_static_loss, False) return g elif cg_stage == 'replay': cg_static_input.copy_(input) cg_static_target.copy_(target) cuda_graph.replay() loss = cg_static_loss + if loss_scaler is not None: + loss_scaler.step_and_update() if not args.distributed: losses_m.update(loss.item(), input.size(0))