@@ -671,11 +671,13 @@ def main():
                 loader_train.sampler.set_epoch(epoch)
 
             if not args.cuda_graph:
+                # The default eager mode without cuda graph.
                 train_metrics = train_one_epoch(
                     epoch, model, loader_train, optimizer, train_loss_fn, args,
                     lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                     amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
             else:
+                # Replay an already captured cuda graph from previous epochs.
                 if cuda_graph_ is not None:
                     train_metrics = train_one_epoch(  # cuda graph replay
                         epoch, model, loader_train, optimizer, train_loss_fn, args,
@@ -683,12 +685,13 @@ def main():
                         amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
                         cuda_graph=cuda_graph_, cg_stage='replay',
                         cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+                # Capture a cuda graph.
                 else:
-                    input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_data
+                    input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_shapes
                         epoch, model, loader_train, optimizer, train_loss_fn, args,
                         lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                         amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
-                        cg_stage='get_static_data')
+                        cg_stage='get_static_shapes')
                     cg_static_input = torch.empty_like(input_)
                     cg_static_target = torch.empty_like(target_)
                     cg_static_loss = torch.empty_like(loss_)
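
The cg_static_* tensors allocated above are the fixed-address buffers the captured graph reads from and writes to. In replay epochs each incoming batch is copied into them in place before the graph is replayed; a minimal sketch of that pattern (names are illustrative, with g standing for the captured torch.cuda.CUDAGraph, not the fork's exact train_one_epoch code):

    for input, target in loader_train:
        cg_static_input.copy_(input)      # refill the capture-time input buffer in place
        cg_static_target.copy_(target)    # refill the capture-time target buffer in place
        g.replay()                        # re-run the recorded forward/backward/optimizer step
        loss = cg_static_loss.item()      # the new loss value lands in the static loss buffer
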
@@ -808,7 +811,7 @@ def train_one_epoch(
 
         if cg_stage is None:  # The default non-CUDAGraph mode
             loss = _step(input, target)
-        elif cg_stage == 'get_static_data':
+        elif cg_stage == 'get_static_shapes':
             s = torch.cuda.Stream()
             s.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(s):
@@ -828,7 +831,7 @@ def train_one_epoch(
             g = torch.cuda.CUDAGraph()
             optimizer.zero_grad(set_to_none=True)
             with torch.cuda.graph(g):
-                _step(cg_static_input, cg_static_target, cg_static_loss, False)
+                _step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False)
             return g
         elif cg_stage == 'replay':
             cg_static_input.copy_(input)
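
The 'get_static_shapes' and capture stages in train_one_epoch follow the warmup-then-capture recipe from the PyTorch CUDA Graphs documentation: run a few eager iterations on a side stream so one-time lazy work (cuDNN autotuning, .grad allocation) stays out of the capture, reset gradients to None, then record one full training step into a torch.cuda.CUDAGraph. A self-contained sketch of that recipe, assuming a model, loss_fn, and three warmup iterations (illustrative, not the fork's exact _step code):

    import torch

    def capture_train_step(model, loss_fn, optimizer, static_input, static_target):
        # Warm up on a side stream so one-time initialization is not recorded.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                optimizer.zero_grad(set_to_none=True)
                loss = loss_fn(model(static_input), static_target)
                loss.backward()
                optimizer.step()
        torch.cuda.current_stream().wait_stream(s)

        # Grads are set to None before capture so backward() allocates .grad tensors
        # from the graph's private memory pool.
        optimizer.zero_grad(set_to_none=True)
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            static_loss = loss_fn(model(static_input), static_target)
            static_loss.backward()
            optimizer.step()  # some optimizers need capturable=True to be graph-safe
        return g, static_loss

After capture, an iteration only copies the next batch into static_input/static_target and calls g.replay(). Keeping the loss scaler's step and update outside the captured region (the scaler_step_and_update=False change above) fits the same recipe, since GradScaler.step performs a CPU-side inf/NaN check that cannot be captured.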