From 1726f27f8d55e8a00fa2849629511a6576cbadd8 Mon Sep 17 00:00:00 2001
From: Xiao Wang <24860335+xwang233@users.noreply.github.com>
Date: Tue, 24 May 2022 09:40:46 -0700
Subject: [PATCH] add cuda graph and aot_autograd to benchmark.py

---
 benchmark.py | 75 +++++++++++++++++++++++++++++++++++++++++++---------
 1 file changed, 62 insertions(+), 13 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index 422da45d..9e75f6cc 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -51,6 +51,11 @@ except ImportError as e:
     FlopCountAnalysis = None
     has_fvcore_profiling = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
 
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
@@ -99,6 +104,10 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--cuda-graph', default=False, action='store_true',
+                    help="Enable CUDA Graph support")
+parser.add_argument('--aot-autograd', default=False, action='store_true',
+                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
 
 
 # train optimizer parameters
@@ -188,14 +197,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
-            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False,
+            cuda_graph=False, precision='float32', fuser='', num_warm_iter=10, num_bench_iter=50,
+            use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device
         self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
+        self.cuda_graph = cuda_graph
 
         if fuser:
             set_jit_fuser(fuser)
@@ -220,6 +231,9 @@ class BenchmarkRunner:
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
+        if aot_autograd:
+            assert has_functorch, "functorch is needed for --aot-autograd"
+            self.model = memory_efficient_fusion(self.model)
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
@@ -248,11 +262,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
         self.model.eval()
 
     def run(self):
-        def _step():
+        def _step(sync=True):
             t_step_start = self.time_fn()
             with self.amp_autocast():
                 output = self.model(self.example_inputs)
-            t_step_end = self.time_fn(True)
+            t_step_end = self.time_fn(sync)
             return t_step_end - t_step_start
 
         _logger.info(
@@ -265,12 +279,28 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
         for _ in range(self.num_warm_iter):
             _step()
 
+        if self.cuda_graph:
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    _step(sync=False)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                _step(sync=False)
+
         total_step = 0.
         num_samples = 0
-        t_run_start = self.time_fn()
+        t_run_start = self.time_fn(True)
         for i in range(self.num_bench_iter):
-            delta_fwd = _step()
-            total_step += delta_fwd
+            if self.cuda_graph:
+                g.replay()
+                total_step = self.time_fn(True) - t_run_start
+            else:
+                delta_fwd = _step()
+                total_step += delta_fwd
             num_samples += self.batch_size
             num_steps = i + 1
             if num_steps % self.log_freq == 0:
@@ -332,7 +362,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
 
     def run(self):
-        def _step(detail=False):
+        def _step(detail=False, sync=True):
             self.optimizer.zero_grad()  # can this be ignored?
             t_start = self.time_fn()
             t_fwd_end = t_start
@@ -348,7 +378,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             if detail:
                 t_bwd_end = self.time_fn(True)
             self.optimizer.step()
-            t_end = self.time_fn(True)
+            t_end = self.time_fn(sync)
             if detail:
                 delta_fwd = t_fwd_end - t_start
                 delta_bwd = t_bwd_end - t_fwd_end
@@ -367,7 +397,21 @@ class TrainBenchmarkRunner(BenchmarkRunner):
         for _ in range(self.num_warm_iter):
            _step()
 
-        t_run_start = self.time_fn()
+        if self.cuda_graph:
+            assert self.detail is False, "mode --detail is not supported with CUDA Graph in training benchmark"
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    _step(sync=False)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            self.optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.graph(g):
+                _step(sync=False)
+
+        t_run_start = self.time_fn(True)
         if self.detail:
             total_fwd = 0.
             total_bwd = 0.
@@ -405,9 +449,13 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             total_step = 0.
             num_samples = 0
             for i in range(self.num_bench_iter):
-                delta_step = _step(False)
+                if self.cuda_graph:
+                    g.replay()
+                    total_step = self.time_fn(True) - t_run_start
+                else:
+                    delta_step = _step(False)
+                    total_step += delta_step
                 num_samples += self.batch_size
-                total_step += delta_step
                 num_steps = (i + 1)
                 if num_steps % self.log_freq == 0:
                     _logger.info(
@@ -506,7 +554,8 @@ def benchmark(args):
         args.precision = 'amp'
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
-                 f'torchscript {"enabled" if args.torchscript else "disabled"}')
+                 f'torchscript {"enabled" if args.torchscript else "disabled"}. '
+                 f'cuda graph {"enabled" if args.cuda_graph else "disabled"}')
 
     bench_kwargs = vars(args).copy()
     bench_kwargs.pop('amp')
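
Example invocations exercising the new flags (a sketch, not part of the patch; the model name is
illustrative, and `--model`, `--bench`, and `--amp` are assumed to be pre-existing benchmark.py
arguments, while `--cuda-graph`, `--aot-autograd`, and `--fuser nvfuser` come from this change):

    python benchmark.py --model resnet50 --bench inference --amp --cuda-graph
    python benchmark.py --model resnet50 --bench train --amp --aot-autograd --fuser nvfuser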