diff --git a/train.py b/train.py
index 84d4e629..3f58a095 100755
--- a/train.py
+++ b/train.py
@@ -671,11 +671,13 @@ def main():
             loader_train.sampler.set_epoch(epoch)

         if not args.cuda_graph:
+            # The default eager mode without cuda graph.
             train_metrics = train_one_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
         else:
+            # Replay an already captured cuda graph from previous epochs.
             if cuda_graph_ is not None:
                 train_metrics = train_one_epoch(  # cuda graph replay
                     epoch, model, loader_train, optimizer, train_loss_fn, args,
@@ -683,12 +685,13 @@ def main():
                     amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
                     cuda_graph=cuda_graph_, cg_stage='replay', cg_static_input=cg_static_input,
                     cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+            # Capture a cuda graph.
             else:
-                input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_data
+                input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_shapes
                     epoch, model, loader_train, optimizer, train_loss_fn, args,
                     lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                     amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
-                    cg_stage='get_static_data')
+                    cg_stage='get_static_shapes')
                 cg_static_input = torch.empty_like(input_)
                 cg_static_target = torch.empty_like(target_)
                 cg_static_loss = torch.empty_like(loss_)
@@ -808,7 +811,7 @@ def train_one_epoch(
     if cg_stage is None:
         # The default non-CUDAGraph mode
         loss = _step(input, target)
-    elif cg_stage == 'get_static_data':
+    elif cg_stage == 'get_static_shapes':
         s = torch.cuda.Stream()
         s.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(s):
@@ -828,7 +831,7 @@ def train_one_epoch(
         g = torch.cuda.CUDAGraph()
         optimizer.zero_grad(set_to_none=True)
         with torch.cuda.graph(g):
-            _step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False)
+            _step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False)
         return g
     elif cg_stage == 'replay':
         cg_static_input.copy_(input)
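
For reference, below is a minimal self-contained sketch of the three-stage CUDA graph workflow this diff wires into train.py: warm up on a side stream (roughly what the 'get_static_shapes' stage does), capture one training step with torch.cuda.graph into static buffers (the torch.empty_like allocations above), then replay the graph after copying each new batch into those buffers. It uses only public torch.cuda APIs, following the pattern from the PyTorch CUDA graphs documentation; the model, loss_fn, optimizer, and static_* names are illustrative placeholders, not identifiers from train.py, and AMP/loss-scaler handling is omitted.

import torch

# Illustrative stand-ins for the real model/loss/optimizer in train.py.
model = torch.nn.Linear(16, 4).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Static tensors: the captured graph always reads and writes these same
# buffers, which is why the diff allocates cg_static_* with torch.empty_like.
static_input = torch.randn(8, 16, device='cuda')
static_target = torch.randn(8, 4, device='cuda')

# Stage 1 (cf. cg_stage='get_static_shapes'): warm up on a side stream so
# lazy allocations and kernel autotuning happen outside the capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        loss_fn(model(static_input), static_target).backward()
        optimizer.step()
torch.cuda.current_stream().wait_stream(s)

# Stage 2 (capture): zero_grad(set_to_none=True) first so the .grad tensors
# are allocated from the graph's private memory pool during capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = loss_fn(model(static_input), static_target)
    static_loss.backward()
    optimizer.step()

# Stage 3 (cf. cg_stage='replay'): copy each new batch into the static
# buffers, replay the captured work, and read the loss from static_loss.
for _ in range(5):
    static_input.copy_(torch.randn(8, 16))
    static_target.copy_(torch.randn(8, 4))
    g.replay()
    print(static_loss.item())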