Add comments explaining the eager, CUDA graph replay, and CUDA graph capture training paths

pull/1271/head
Xiao Wang 3 years ago
parent be54f860d0
commit 01010b5aa7

@ -671,11 +671,13 @@ def main():
loader_train.sampler.set_epoch(epoch)
if not args.cuda_graph:
# Default eager mode: train without a CUDA graph.
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
else:
# A CUDA graph was already captured in a previous epoch: replay it.
if cuda_graph_ is not None:
train_metrics = train_one_epoch( # cuda graph replay
epoch, model, loader_train, optimizer, train_loss_fn, args,
@ -683,12 +685,13 @@ def main():
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cuda_graph=cuda_graph_, cg_stage='replay',
cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
# Otherwise (no CUDA graph captured yet): capture one.
else:
input_, target_, loss_ = train_one_epoch( # cuda graph get_static_data
input_, target_, loss_ = train_one_epoch( # cuda graph get_static_shapes
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cg_stage='get_static_data')
cg_stage='get_static_shapes')
cg_static_input = torch.empty_like(input_)
cg_static_target = torch.empty_like(target_)
cg_static_loss = torch.empty_like(loss_)
@ -808,7 +811,7 @@ def train_one_epoch(
if cg_stage is None: # The default non-CUDAGraph mode
loss = _step(input, target)
elif cg_stage == 'get_static_data':
elif cg_stage == 'get_static_shapes':
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
@ -828,7 +831,7 @@ def train_one_epoch(
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
_step(cg_static_input, cg_static_target, cg_static_loss, False)
_step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False)
return g
elif cg_stage == 'replay':
cg_static_input.copy_(input)

Loading…
Cancel
Save