|
|
|
@ -795,7 +795,12 @@ def train_one_epoch(
|
|
|
|
|
if cg_stage is None: # The default non-CUDAGraph mode
|
|
|
|
|
loss = _step(input, target)
|
|
|
|
|
elif cg_stage == 'get_static_data':
|
|
|
|
|
loss = _step(input, target)
|
|
|
|
|
s = torch.cuda.Stream()
|
|
|
|
|
s.wait_stream(torch.cuda.current_stream())
|
|
|
|
|
with torch.cuda.stream(s):
|
|
|
|
|
optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
loss = _step(input, target)
|
|
|
|
|
torch.cuda.current_stream().wait_stream(s)
|
|
|
|
|
return (input, target, loss)
|
|
|
|
|
elif cg_stage == 'capture':
|
|
|
|
|
s = torch.cuda.Stream()
|
|
|
|
@ -803,7 +808,7 @@ def train_one_epoch(
|
|
|
|
|
with torch.cuda.stream(s):
|
|
|
|
|
for _ in range(11):
|
|
|
|
|
optimizer.zero_grad(set_to_none=True)
|
|
|
|
|
_step(cg_static_input, cg_static_target, cg_static_loss)
|
|
|
|
|
_step(input, target)
|
|
|
|
|
torch.cuda.current_stream().wait_stream(s)
|
|
|
|
|
|
|
|
|
|
g = torch.cuda.CUDAGraph()
|
|
|
|
@ -811,7 +816,7 @@ def train_one_epoch(
|
|
|
|
|
with torch.cuda.graph(g):
|
|
|
|
|
_step(cg_static_input, cg_static_target, cg_static_loss)
|
|
|
|
|
return g
|
|
|
|
|
elif cg_stage == 'train':
|
|
|
|
|
elif cg_stage == 'replay':
|
|
|
|
|
cg_static_input.copy_(input)
|
|
|
|
|
cg_static_target.copy_(target)
|
|
|
|
|
cuda_graph.replay()
|
|
|
|
|