From cd926775fb16058a0d3f7bc8a50ff142117e5e8f Mon Sep 17 00:00:00 2001
From: Xiao Wang <24860335+xwang233@users.noreply.github.com>
Date: Tue, 24 May 2022 09:40:58 -0700
Subject: [PATCH] added aot_autograd; cuda graph still wip

---
 train.py | 119 +++++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 102 insertions(+), 17 deletions(-)

diff --git a/train.py b/train.py
index acdf93c3..7ab92d21 100755
--- a/train.py
+++ b/train.py
@@ -61,6 +61,12 @@ try:
 except ImportError:
     has_wandb = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -129,6 +135,10 @@ group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
+group.add_argument('--cuda-graph', default=False, action='store_true',
+                   help="Enable CUDA Graph support")
+group.add_argument('--aot-autograd', default=False, action='store_true',
+                   help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
@@ -445,6 +455,9 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    if args.aot_autograd:
+        assert has_functorch, "functorch is needed for --aot-autograd"
+        model = memory_efficient_fusion(model)
 
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 
@@ -636,14 +649,48 @@ def main():
             f.write(args_text)
 
     try:
+        cuda_graph_ = None
+        cg_static_input = None
+        cg_static_target = None
+        cg_static_loss = None
         for epoch in range(start_epoch, num_epochs):
             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                 loader_train.sampler.set_epoch(epoch)
 
-            train_metrics = train_one_epoch(
-                epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+            if not args.cuda_graph:
+                train_metrics = train_one_epoch(
+                    epoch, model, loader_train, optimizer, train_loss_fn, args,
+                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+            else:
+                if cuda_graph_ is not None:
+                    train_metrics = train_one_epoch(  # cuda graph replay
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=cuda_graph_, cg_stage='replay',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+                else:
+                    input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_data
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cg_stage='get_static_data')
+                    cg_static_input = torch.empty_like(input_)
+                    cg_static_target = torch.empty_like(target_)
+                    cg_static_loss = torch.empty_like(loss_)
+                    cuda_graph_ = train_one_epoch(  # cuda graph capture
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=None, cg_stage='capture',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+                    train_metrics = train_one_epoch(  # cuda graph replay
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=cuda_graph_, cg_stage='replay',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
@@ -682,7 +729,9 @@
 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
-        loss_scaler=None, model_ema=None, mixup_fn=None):
+        loss_scaler=None, model_ema=None, mixup_fn=None,
+        cuda_graph=None, cg_stage=None,
+        cg_static_input=None, cg_static_target=None, cg_static_loss=None):
 
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -700,22 +749,16 @@ def train_one_epoch(
     end = time.time()
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
-    for batch_idx, (input, target) in enumerate(loader):
-        last_batch = batch_idx == last_idx
-        data_time_m.update(time.time() - end)
-        if not args.prefetcher:
-            input, target = input.cuda(), target.cuda()
-            if mixup_fn is not None:
-                input, target = mixup_fn(input, target)
+
+    def _step(input_, target_, loss_=None):
         if args.channels_last:
-            input = input.contiguous(memory_format=torch.channels_last)
+            input = input_.contiguous(memory_format=torch.channels_last)
+        else:
+            input = input_
 
         with amp_autocast():
             output = model(input)
-            loss = loss_fn(output, target)
-
-        if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            loss = loss_fn(output, target_)
 
         optimizer.zero_grad()
         if loss_scaler is not None:
@@ -735,6 +778,48 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model)
 
+        if loss_ is not None:
+            loss_.copy_(loss)
+            return None
+        else:
+            return loss
+
+    for batch_idx, (input, target) in enumerate(loader):
+        last_batch = batch_idx == last_idx
+        data_time_m.update(time.time() - end)
+        if not args.prefetcher:
+            input, target = input.cuda(), target.cuda()
+            if mixup_fn is not None:
+                input, target = mixup_fn(input, target)
+
+        if cg_stage is None:  # The default non-CUDAGraph mode
+            loss = _step(input, target)
+        elif cg_stage == 'get_static_data':
+            loss = _step(input, target)
+            return (input, target, loss)
+        elif cg_stage == 'capture':
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(11):
+                    optimizer.zero_grad(set_to_none=True)
+                    _step(cg_static_input, cg_static_target, cg_static_loss)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.graph(g):
+                _step(cg_static_input, cg_static_target, cg_static_loss)
+            return g
+        elif cg_stage == 'replay':
+            cg_static_input.copy_(input)
+            cg_static_target.copy_(target)
+            cuda_graph.replay()
+            loss = cg_static_loss
+
+        if not args.distributed:
+            losses_m.update(loss.item(), input.size(0))
+
         torch.cuda.synchronize()
         num_updates += 1
         batch_time_m.update(time.time() - end)
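
Note on the CUDA Graph path: the capture/replay flow above follows the standard torch.cuda.CUDAGraph recipe. Warm up a few iterations on a side stream so allocator and autotuner state settles, capture one complete optimizer step against static input/target/loss tensors, then copy each incoming batch into the static tensors and replay the captured graph. Below is a minimal standalone sketch of that recipe; the model, optimizer, shapes, and warm-up count are illustrative placeholders rather than values from train.py, and the AMP/loss-scaler handling used above is omitted.

    import torch

    model = torch.nn.Linear(512, 1000).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()

    # Static tensors: the captured graph always reads from / writes to these buffers.
    static_input = torch.randn(64, 512, device='cuda')
    static_target = torch.randint(0, 1000, (64,), device='cuda')

    # Warm up on a side stream so workspace allocations happen before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            optimizer.zero_grad(set_to_none=True)
            loss_fn(model(static_input), static_target).backward()
            optimizer.step()
    torch.cuda.current_stream().wait_stream(s)

    # Capture one full training step (forward, backward, optimizer update).
    g = torch.cuda.CUDAGraph()
    optimizer.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):
        static_loss = loss_fn(model(static_input), static_target)
        static_loss.backward()
        optimizer.step()

    # Replay: refill the static tensors with each new batch, then replay the graph.
    for _ in range(10):
        static_input.copy_(torch.randn(64, 512, device='cuda'))
        static_target.copy_(torch.randint(0, 1000, (64,), device='cuda'))
        g.replay()
        print(static_loss.item())

The --aot-autograd path is simpler: functorch.compile.memory_efficient_fusion traces the model's forward and backward ahead of time and hands the traced graphs to the active fuser, which is presumably why the help text recommends pairing it with --fuser nvfuser and not with --torchscript. A hypothetical standalone use (the toy module below is not from train.py):

    import torch
    from functorch.compile import memory_efficient_fusion

    module = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.GELU()).cuda()
    fused = memory_efficient_fusion(module)  # wraps the module with AOT-compiled graphs

    x = torch.randn(32, 128, device='cuda', requires_grad=True)
    fused(x).sum().backward()  # forward and backward both run the pre-traced graphs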