add cuda graph and aot_autograd to benchmark.py

pull/1271/head
Xiao Wang 3 years ago
parent fd360ac951
commit 1726f27f8d

@@ -51,6 +51,11 @@ except ImportError as e:
     FlopCountAnalysis = None
     has_fvcore_profiling = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
 
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
@@ -99,6 +104,10 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--cuda-graph', default=False, action='store_true',
+                    help="Enable CUDA Graph support")
+parser.add_argument('--aot-autograd', default=False, action='store_true',
+                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
 
 # train optimizer parameters
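For reference, the new flags would be exercised roughly like this (model, batch size, and bench mode below are illustrative values, not part of this change):

# Illustrative invocations; argument values are arbitrary examples.
python benchmark.py --model resnet50 --batch-size 64 --bench inference --cuda-graph
python benchmark.py --model resnet50 --batch-size 64 --bench train --aot-autograd --fuser nvfuser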
@@ -188,14 +197,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
-            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False,
+            cuda_graph=False, precision='float32', fuser='', num_warm_iter=10, num_bench_iter=50,
+            use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device
         self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
         self.channels_last = kwargs.pop('channels_last', False)
         self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
+        self.cuda_graph = cuda_graph
 
         if fuser:
             set_jit_fuser(fuser)
@@ -220,6 +231,9 @@ class BenchmarkRunner:
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
+        if aot_autograd:
+            assert has_functorch, "functorch is needed for --aot-autograd"
+            self.model = memory_efficient_fusion(self.model)
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
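memory_efficient_fusion AOT-compiles the wrapped module's forward and backward graphs so a fuser (nvFuser in particular) can optimize them, which is why the help text recommends `--fuser nvfuser` and why scripting the model a second time with `--torchscript` would conflict. A minimal standalone sketch, assuming functorch is installed and a CUDA device is available; the model and shapes are illustrative:

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU()).cuda()
fused = memory_efficient_fusion(model)  # AOT-compiled fwd/bwd, same call signature

out = fused(torch.randn(8, 64, device='cuda'))
out.sum().backward()  # backward runs through the AOT-compiled graph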
@@ -248,11 +262,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
         self.model.eval()
 
     def run(self):
-        def _step():
+        def _step(sync=True):
             t_step_start = self.time_fn()
             with self.amp_autocast():
                 output = self.model(self.example_inputs)
-            t_step_end = self.time_fn(True)
+            t_step_end = self.time_fn(sync)
             return t_step_end - t_step_start
 
         _logger.info(
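The new `sync` parameter matters because `torch.cuda.synchronize()` is illegal inside CUDA Graph capture, so `_step(sync=False)` lets the same step function run during warmup and capture without synchronizing. Elsewhere in this file, `time_fn` resolves to a helper along these lines (shown for context only, not part of this diff):

import time
import torch

def cuda_timestamp(sync=False, device=None):
    if sync:
        torch.cuda.synchronize(device=device)  # not allowed during graph capture
    return time.perf_counter()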
@@ -265,12 +279,28 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
         for _ in range(self.num_warm_iter):
             _step()
 
+        if self.cuda_graph:
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    _step(sync=False)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                _step(sync=False)
+
         total_step = 0.
         num_samples = 0
-        t_run_start = self.time_fn()
+        t_run_start = self.time_fn(True)
         for i in range(self.num_bench_iter):
-            delta_fwd = _step()
-            total_step += delta_fwd
+            if self.cuda_graph:
+                g.replay()
+                total_step = self.time_fn(True) - t_run_start
+            else:
+                delta_fwd = _step()
+                total_step += delta_fwd
             num_samples += self.batch_size
             num_steps = i + 1
             if num_steps % self.log_freq == 0:
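The capture protocol above follows the pattern from the PyTorch CUDA Graphs documentation: a few warmup iterations run on a side stream so their allocations do not end up in the capture, then a single step is recorded into a torch.cuda.CUDAGraph and re-launched with replay(). Replay reuses the captured memory addresses, which is why the benchmark's self.example_inputs tensor must stay static across iterations. A self-contained sketch of the same pattern; the model and shapes are illustrative:

import torch

model = torch.nn.Linear(64, 64).cuda().eval()
static_input = torch.randn(8, 64, device='cuda')

# warm up on a side stream so warmup allocations stay out of the capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# capture one forward pass into a graph
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# feed new data in place, then re-launch the recorded kernels
static_input.copy_(torch.randn(8, 64, device='cuda'))
g.replay()  # static_output now holds results for the new input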
@@ -332,7 +362,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
 
     def run(self):
-        def _step(detail=False):
+        def _step(detail=False, sync=True):
             self.optimizer.zero_grad()  # can this be ignored?
             t_start = self.time_fn()
             t_fwd_end = t_start
@@ -348,7 +378,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             if detail:
                 t_bwd_end = self.time_fn(True)
             self.optimizer.step()
-            t_end = self.time_fn(True)
+            t_end = self.time_fn(sync)
             if detail:
                 delta_fwd = t_fwd_end - t_start
                 delta_bwd = t_bwd_end - t_fwd_end
@@ -367,7 +397,21 @@ class TrainBenchmarkRunner(BenchmarkRunner):
         for _ in range(self.num_warm_iter):
             _step()
 
-        t_run_start = self.time_fn()
+        if self.cuda_graph:
+            assert self.detail is False, "mode --detail is not supported with CUDA Graph in training benchmark"
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(3):
+                    _step(sync=False)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            self.optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.graph(g):
+                _step(sync=False)
+
+        t_run_start = self.time_fn(True)
         if self.detail:
             total_fwd = 0.
             total_bwd = 0.
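The `self.optimizer.zero_grad(set_to_none=True)` before capture follows the PyTorch docs' whole-network-capture recipe: with grads set to None, the `backward()` inside the capture allocates the `.grad` tensors from the graph's private memory pool, so every `replay()` reuses them at fixed addresses. A self-contained sketch of capturing a full training step; the loss and optimizer are illustrative stand-ins:

import torch

model = torch.nn.Linear(64, 64).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
static_input = torch.randn(8, 64, device='cuda')
static_target = torch.randn(8, 64, device='cuda')

# side-stream warmup, same dance as the inference path above
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)  # grads get (re)allocated inside the capture
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

static_input.copy_(torch.randn(8, 64, device='cuda'))
g.replay()  # one full fwd/bwd/optimizer step per replay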
@@ -405,9 +449,13 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             total_step = 0.
             num_samples = 0
             for i in range(self.num_bench_iter):
-                delta_step = _step(False)
+                if self.cuda_graph:
+                    g.replay()
+                    total_step = self.time_fn(True) - t_run_start
+                else:
+                    delta_step = _step(False)
+                    total_step += delta_step
                 num_samples += self.batch_size
-                total_step += delta_step
                 num_steps = (i + 1)
                 if num_steps % self.log_freq == 0:
                     _logger.info(
@@ -506,7 +554,8 @@ def benchmark(args):
         args.precision = 'amp'
     _logger.info(f'Benchmarking in {args.precision} precision. '
                  f'{"NHWC" if args.channels_last else "NCHW"} layout. '
-                 f'torchscript {"enabled" if args.torchscript else "disabled"}')
+                 f'torchscript {"enabled" if args.torchscript else "disabled"}. '
+                 f'cuda graph {"enabled" if args.cuda_graph else "disabled"}')
 
     bench_kwargs = vars(args).copy()
     bench_kwargs.pop('amp')
