From 1726f27f8d55e8a00fa2849629511a6576cbadd8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 24 May 2022 09:40:46 -0700 Subject: [PATCH 1/5] add cuda graph and aot_autograd to benchmark.py --- benchmark.py | 75 +++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 13 deletions(-) diff --git a/benchmark.py b/benchmark.py index 422da45d..9e75f6cc 100755 --- a/benchmark.py +++ b/benchmark.py @@ -51,6 +51,11 @@ except ImportError as e: FlopCountAnalysis = None has_fvcore_profiling = False +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -99,6 +104,10 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--cuda-graph', default=False, action='store_true', + help="Enable CUDA Graph support") +parser.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. 
(It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)") # train optimizer parameters @@ -188,14 +197,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False class BenchmarkRunner: def __init__( - self, model_name, detail=False, device='cuda', torchscript=False, precision='float32', - fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs): + self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, + cuda_graph=False, precision='float32', fuser='', num_warm_iter=10, num_bench_iter=50, + use_train_size=False, **kwargs): self.model_name = model_name self.detail = detail self.device = device self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) self.channels_last = kwargs.pop('channels_last', False) self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress + self.cuda_graph = cuda_graph if fuser: set_jit_fuser(fuser) @@ -220,6 +231,9 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) self.scripted = True + if aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + self.model = memory_efficient_fusion(self.model) data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] @@ -248,11 +262,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner): self.model.eval() def run(self): - def _step(): + def _step(sync=True): t_step_start = self.time_fn() with self.amp_autocast(): output = self.model(self.example_inputs) - t_step_end = self.time_fn(True) + t_step_end = self.time_fn(sync) return t_step_end - t_step_start _logger.info( @@ -265,12 +279,28 @@ class InferenceBenchmarkRunner(BenchmarkRunner): for _ in range(self.num_warm_iter): _step() + if self.cuda_graph: + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + 
_step(sync=False) + torch.cuda.current_stream().wait_stream(s) + + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + _step(sync=False) + total_step = 0. num_samples = 0 - t_run_start = self.time_fn() + t_run_start = self.time_fn(True) for i in range(self.num_bench_iter): - delta_fwd = _step() - total_step += delta_fwd + if self.cuda_graph: + g.replay() + total_step = self.time_fn(True) - t_run_start + else: + delta_fwd = _step() + total_step += delta_fwd num_samples += self.batch_size num_steps = i + 1 if num_steps % self.log_freq == 0: @@ -332,7 +362,7 @@ class TrainBenchmarkRunner(BenchmarkRunner): (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes) def run(self): - def _step(detail=False): + def _step(detail=False, sync=True): self.optimizer.zero_grad() # can this be ignored? t_start = self.time_fn() t_fwd_end = t_start @@ -348,7 +378,7 @@ class TrainBenchmarkRunner(BenchmarkRunner): if detail: t_bwd_end = self.time_fn(True) self.optimizer.step() - t_end = self.time_fn(True) + t_end = self.time_fn(sync) if detail: delta_fwd = t_fwd_end - t_start delta_bwd = t_bwd_end - t_fwd_end @@ -367,7 +397,21 @@ class TrainBenchmarkRunner(BenchmarkRunner): for _ in range(self.num_warm_iter): _step() - t_run_start = self.time_fn() + if self.cuda_graph: + assert self.detail is False, "mode --detail is not supported with CUDA Graph in training benchmark" + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + _step(sync=False) + torch.cuda.current_stream().wait_stream(s) + + g = torch.cuda.CUDAGraph() + self.optimizer.zero_grad(set_to_none=True) + with torch.cuda.graph(g): + _step(sync=False) + + t_run_start = self.time_fn(True) if self.detail: total_fwd = 0. total_bwd = 0. @@ -405,9 +449,13 @@ class TrainBenchmarkRunner(BenchmarkRunner): total_step = 0. 
num_samples = 0 for i in range(self.num_bench_iter): - delta_step = _step(False) + if self.cuda_graph: + g.replay() + total_step = self.time_fn(True) - t_run_start + else: + delta_step = _step(False) + total_step += delta_step num_samples += self.batch_size - total_step += delta_step num_steps = (i + 1) if num_steps % self.log_freq == 0: _logger.info( @@ -506,7 +554,8 @@ def benchmark(args): args.precision = 'amp' _logger.info(f'Benchmarking in {args.precision} precision. ' f'{"NHWC" if args.channels_last else "NCHW"} layout. ' - f'torchscript {"enabled" if args.torchscript else "disabled"}') + f'torchscript {"enabled" if args.torchscript else "disabled"}. ' + f'cuda graph {"enabled" if args.cuda_graph else "disabled"}') bench_kwargs = vars(args).copy() bench_kwargs.pop('amp') From cd926775fb16058a0d3f7bc8a50ff142117e5e8f Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 24 May 2022 09:40:58 -0700 Subject: [PATCH 2/5] added aot_autograd; cuda graph still wip --- train.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 102 insertions(+), 17 deletions(-) diff --git a/train.py b/train.py index acdf93c3..7ab92d21 100755 --- a/train.py +++ b/train.py @@ -61,6 +61,12 @@ try: except ImportError: has_wandb = False +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('train') @@ -129,6 +135,10 @@ group.add_argument('--fuser', default='', type=str, help="Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')") group.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') +group.add_argument('--cuda-graph', default=False, action='store_true', + help="Enable CUDA Graph support") +group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)") # Optimizer parameters group = parser.add_argument_group('Optimizer parameters') @@ -445,6 +455,9 @@ def main(): assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model' assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model' model = torch.jit.script(model) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) @@ -636,14 +649,48 @@ def main(): f.write(args_text) try: + cuda_graph_ = None + cg_static_input = None + cg_static_target = None + cg_static_loss = None for epoch in range(start_epoch, num_epochs): if args.distributed and hasattr(loader_train.sampler, 'set_epoch'): loader_train.sampler.set_epoch(epoch) - train_metrics = train_one_epoch( - epoch, model, loader_train, optimizer, train_loss_fn, args, - lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, - amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + if not args.cuda_graph: + train_metrics = train_one_epoch( + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) + else: + if cuda_graph_ is not None: + train_metrics = train_one_epoch( # cuda graph replay + epoch, model, loader_train, optimizer, train_loss_fn, args, + 
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + cuda_graph=cuda_graph_, cg_stage='replay', + cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss) + else: + input_, target_, loss_ = train_one_epoch( # cuda graph get_static_data + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + cg_stage='get_static_data') + cg_static_input = torch.empty_like(input_) + cg_static_target = torch.empty_like(target_) + cg_static_loss = torch.empty_like(loss_) + cuda_graph_ = train_one_epoch( # cuda graph capture + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + cuda_graph=None, cg_stage='capture', + cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss) + train_metrics = train_one_epoch( # cuda graph replay + epoch, model, loader_train, optimizer, train_loss_fn, args, + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, + cuda_graph=cuda_graph_, cg_stage='replay', + cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss) if args.distributed and args.dist_bn in ('broadcast', 'reduce'): if args.local_rank == 0: @@ -682,7 +729,9 @@ def main(): def train_one_epoch( epoch, model, loader, optimizer, loss_fn, args, lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress, - loss_scaler=None, model_ema=None, mixup_fn=None): + loss_scaler=None, model_ema=None, mixup_fn=None, + cuda_graph=None, 
cg_stage=None, + cg_static_input=None, cg_static_target=None, cg_static_loss=None): if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: if args.prefetcher and loader.mixup_enabled: @@ -700,22 +749,16 @@ def train_one_epoch( end = time.time() last_idx = len(loader) - 1 num_updates = epoch * len(loader) - for batch_idx, (input, target) in enumerate(loader): - last_batch = batch_idx == last_idx - data_time_m.update(time.time() - end) - if not args.prefetcher: - input, target = input.cuda(), target.cuda() - if mixup_fn is not None: - input, target = mixup_fn(input, target) + + def _step(input_, target_, loss_=None): if args.channels_last: - input = input.contiguous(memory_format=torch.channels_last) + input = input_.contiguous(memory_format=torch.channels_last) + else: + input = input_ with amp_autocast(): output = model(input) - loss = loss_fn(output, target) - - if not args.distributed: - losses_m.update(loss.item(), input.size(0)) + loss = loss_fn(output, target_) optimizer.zero_grad() if loss_scaler is not None: @@ -735,6 +778,48 @@ def train_one_epoch( if model_ema is not None: model_ema.update(model) + if loss_ is not None: + loss_.copy_(loss) + return None + else: + return loss + + for batch_idx, (input, target) in enumerate(loader): + last_batch = batch_idx == last_idx + data_time_m.update(time.time() - end) + if not args.prefetcher: + input, target = input.cuda(), target.cuda() + if mixup_fn is not None: + input, target = mixup_fn(input, target) + + if cg_stage is None: # The default non-CUDAGraph mode + loss = _step(input, target) + elif cg_stage == 'get_static_data': + loss = _step(input, target) + return (input, target, loss) + elif cg_stage == 'capture': + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(11): + optimizer.zero_grad(set_to_none=True) + _step(cg_static_input, cg_static_target, cg_static_loss) + torch.cuda.current_stream().wait_stream(s) + + g = torch.cuda.CUDAGraph() + 
optimizer.zero_grad(set_to_none=True) + with torch.cuda.graph(g): + _step(cg_static_input, cg_static_target, cg_static_loss) + return g + elif cg_stage == 'train': + cg_static_input.copy_(input) + cg_static_target.copy_(target) + cuda_graph.replay() + loss = cg_static_loss + + if not args.distributed: + losses_m.update(loss.item(), input.size(0)) + torch.cuda.synchronize() num_updates += 1 batch_time_m.update(time.time() - end) From 9b299cbcf1078ede17755a35e659358578404c77 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 24 May 2022 10:01:26 -0700 Subject: [PATCH 3/5] get static data on a side stream --- train.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 7ab92d21..3171dda7 100755 --- a/train.py +++ b/train.py @@ -795,7 +795,12 @@ def train_one_epoch( if cg_stage is None: # The default non-CUDAGraph mode loss = _step(input, target) elif cg_stage == 'get_static_data': - loss = _step(input, target) + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + optimizer.zero_grad(set_to_none=True) + loss = _step(input, target) + torch.cuda.current_stream().wait_stream(s) return (input, target, loss) elif cg_stage == 'capture': s = torch.cuda.Stream() @@ -803,7 +808,7 @@ def train_one_epoch( with torch.cuda.stream(s): for _ in range(11): optimizer.zero_grad(set_to_none=True) - _step(cg_static_input, cg_static_target, cg_static_loss) + _step(input, target) torch.cuda.current_stream().wait_stream(s) g = torch.cuda.CUDAGraph() @@ -811,7 +816,7 @@ def train_one_epoch( with torch.cuda.graph(g): _step(cg_static_input, cg_static_target, cg_static_loss) return g - elif cg_stage == 'train': + elif cg_stage == 'replay': cg_static_input.copy_(input) cg_static_target.copy_(target) cuda_graph.replay() From be54f860d0a60b6d2be359835bb4501d6b291391 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> 
Date: Tue, 24 May 2022 10:26:26 -0700 Subject: [PATCH 4/5] update cuda graph mode for AMP and DDP --- timm/utils/cuda.py | 18 ++++++++++++++---- train.py | 24 ++++++++++++++++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/timm/utils/cuda.py b/timm/utils/cuda.py index 9e7bddf3..0336e22b 100644 --- a/timm/utils/cuda.py +++ b/timm/utils/cuda.py @@ -17,7 +17,7 @@ from .clip_grad import dispatch_clip_grad class ApexScaler: state_dict_key = "amp" - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True): with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward(create_graph=create_graph) if clip_grad is not None: @@ -32,6 +32,8 @@ class ApexScaler: if 'load_state_dict' in amp.__dict__: amp.load_state_dict(state_dict) + def step_and_update(self): + pass class NativeScaler: state_dict_key = "amp_scaler" @@ -39,17 +41,25 @@ class NativeScaler: def __init__(self): self._scaler = torch.cuda.amp.GradScaler() - def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False): + def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True): self._scaler.scale(loss).backward(create_graph=create_graph) + self._optimizer = optimizer if clip_grad is not None: assert parameters is not None self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place dispatch_clip_grad(parameters, clip_grad, mode=clip_mode) - self._scaler.step(optimizer) - self._scaler.update() + if step_and_update: + self._scaler.step(optimizer) + self._scaler.update() def state_dict(self): return self._scaler.state_dict() def load_state_dict(self, state_dict): self._scaler.load_state_dict(state_dict) + + def step_and_update(self): + # need to separate 
this step from "scaler.scale(loss)" since scaler.step() syncs GPU with CPU, + # which is not allowed for cuda graph capture + self._scaler.step(self._optimizer) + self._scaler.update() diff --git a/train.py b/train.py index 3171dda7..84d4e629 100755 --- a/train.py +++ b/train.py @@ -364,6 +364,10 @@ def main(): args.world_size = 1 args.rank = 0 # global rank if args.distributed: + if args.cuda_graph: + os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0' + if 'LOCAL_RANK' in os.environ: + args.local_rank = int(os.environ['LOCAL_RANK']) args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') @@ -375,6 +379,9 @@ def main(): _logger.info('Training with a single process on 1 GPUs.') assert args.rank >= 0 + if args.cuda_graph: + _logger.info("Training with CUDA Graph support is enabled.") + # resolve AMP arguments based on PyTorch / Apex availability use_amp = None if args.amp: @@ -507,7 +514,13 @@ def main(): else: if args.local_rank == 0: _logger.info("Using native Torch DistributedDataParallel.") - model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + # Wrapping the DDP model in a side stream is needed for cuda graph workloads. + # Since it does no harm to regular eager mode, we'll just use this with one less if-statement.
+ s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb) + torch.cuda.current_stream().wait_stream(s) # NOTE: EMA model does not need to be wrapped by DDP # setup learning rate schedule and starting epoch @@ -750,7 +763,7 @@ def train_one_epoch( last_idx = len(loader) - 1 num_updates = epoch * len(loader) - def _step(input_, target_, loss_=None): + def _step(input_, target_, loss_=None, scaler_step_and_update=True): if args.channels_last: input = input_.contiguous(memory_format=torch.channels_last) else: @@ -766,7 +779,8 @@ def train_one_epoch( loss, optimizer, clip_grad=args.clip_grad, clip_mode=args.clip_mode, parameters=model_parameters(model, exclude_head='agc' in args.clip_mode), - create_graph=second_order) + create_graph=second_order, + step_and_update=scaler_step_and_update) else: loss.backward(create_graph=second_order) if args.clip_grad is not None: @@ -814,13 +828,15 @@ def train_one_epoch( g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - _step(cg_static_input, cg_static_target, cg_static_loss) + _step(cg_static_input, cg_static_target, cg_static_loss, False) return g elif cg_stage == 'replay': cg_static_input.copy_(input) cg_static_target.copy_(target) cuda_graph.replay() loss = cg_static_loss + if loss_scaler is not None: + loss_scaler.step_and_update() if not args.distributed: losses_m.update(loss.item(), input.size(0)) From 01010b5aa73a570fc4e5803a4f5eb7421dd58ed8 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 24 May 2022 10:36:39 -0700 Subject: [PATCH 5/5] add some comments --- train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 84d4e629..3f58a095 100755 --- a/train.py +++ b/train.py @@ -671,11 +671,13 @@ def main(): loader_train.sampler.set_epoch(epoch) if not 
args.cuda_graph: + # The default eager mode without cuda graph. train_metrics = train_one_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn) else: + # Replay an already captured cuda graph from previous epochs. if cuda_graph_ is not None: train_metrics = train_one_epoch( # cuda graph replay epoch, model, loader_train, optimizer, train_loss_fn, args, @@ -683,12 +685,13 @@ def main(): amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, cuda_graph=cuda_graph_, cg_stage='replay', cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss) + # Capture a cuda graph. else: - input_, target_, loss_ = train_one_epoch( # cuda graph get_static_data + input_, target_, loss_ = train_one_epoch( # cuda graph get_static_shapes epoch, model, loader_train, optimizer, train_loss_fn, args, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn, - cg_stage='get_static_data') + cg_stage='get_static_shapes') cg_static_input = torch.empty_like(input_) cg_static_target = torch.empty_like(target_) cg_static_loss = torch.empty_like(loss_) @@ -808,7 +811,7 @@ def train_one_epoch( if cg_stage is None: # The default non-CUDAGraph mode loss = _step(input, target) - elif cg_stage == 'get_static_data': + elif cg_stage == 'get_static_shapes': s = torch.cuda.Stream() s.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(s): @@ -828,7 +831,7 @@ def train_one_epoch( g = torch.cuda.CUDAGraph() optimizer.zero_grad(set_to_none=True) with torch.cuda.graph(g): - _step(cg_static_input, cg_static_target, cg_static_loss, False) + _step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False) return g elif cg_stage == 'replay': 
cg_static_input.copy_(input)