From 9b299cbcf1078ede17755a35e659358578404c77 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 24 May 2022 10:01:26 -0700 Subject: [PATCH] get static data on a side stream --- train.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/train.py b/train.py index 7ab92d21..3171dda7 100755 --- a/train.py +++ b/train.py @@ -795,7 +795,12 @@ def train_one_epoch( if cg_stage is None: # The default non-CUDAGraph mode loss = _step(input, target) elif cg_stage == 'get_static_data': - loss = _step(input, target) + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + optimizer.zero_grad(set_to_none=True) + loss = _step(input, target) + torch.cuda.current_stream().wait_stream(s) return (input, target, loss) elif cg_stage == 'capture': s = torch.cuda.Stream() @@ -803,7 +808,7 @@ def train_one_epoch( with torch.cuda.stream(s): for _ in range(11): optimizer.zero_grad(set_to_none=True) - _step(cg_static_input, cg_static_target, cg_static_loss) + _step(input, target) torch.cuda.current_stream().wait_stream(s) g = torch.cuda.CUDAGraph() @@ -811,7 +816,7 @@ def train_one_epoch( with torch.cuda.graph(g): _step(cg_static_input, cg_static_target, cg_static_loss) return g - elif cg_stage == 'train': + elif cg_stage == 'replay': cg_static_input.copy_(input) cg_static_target.copy_(target) cuda_graph.replay()