get static data on a side stream

pull/1271/head
Xiao Wang 3 years ago
parent cd926775fb
commit 9b299cbcf1

@ -795,7 +795,12 @@ def train_one_epoch(
if cg_stage is None: # The default non-CUDAGraph mode
loss = _step(input, target)
elif cg_stage == 'get_static_data':
loss = _step(input, target)
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
optimizer.zero_grad(set_to_none=True)
loss = _step(input, target)
torch.cuda.current_stream().wait_stream(s)
return (input, target, loss)
elif cg_stage == 'capture':
s = torch.cuda.Stream()
@ -803,7 +808,7 @@ def train_one_epoch(
with torch.cuda.stream(s):
for _ in range(11):
optimizer.zero_grad(set_to_none=True)
_step(cg_static_input, cg_static_target, cg_static_loss)
_step(input, target)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
@ -811,7 +816,7 @@ def train_one_epoch(
with torch.cuda.graph(g):
_step(cg_static_input, cg_static_target, cg_static_loss)
return g
elif cg_stage == 'train':
elif cg_stage == 'replay':
cg_static_input.copy_(input)
cg_static_target.copy_(target)
cuda_graph.replay()

Loading…
Cancel
Save