@@ -61,6 +61,12 @@ try:
 except ImportError:
     has_wandb = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -129,6 +135,10 @@ group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
+group.add_argument('--cuda-graph', default=False, action='store_true',
+                   help="Enable CUDA Graph support")
+group.add_argument('--aot-autograd', default=False, action='store_true',
+                   help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
 
 # Optimizer parameters
 group = parser.add_argument_group('Optimizer parameters')
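
The two new flags compose with the existing ones on the command line. A hypothetical invocation (the dataset path, --model, and --amp are illustrative, not part of this change) pairing AOT Autograd with the nvFuser backend as the help text recommends:

    python train.py /data/imagenet --model resnet50 --amp \
        --cuda-graph --aot-autograd --fuser nvfuser
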
@@ -445,6 +455,9 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    if args.aot_autograd:
+        assert has_functorch, "functorch is needed for --aot-autograd"
+        model = memory_efficient_fusion(model)
 
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 
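
memory_efficient_fusion wraps a module with AOT Autograd so the extracted forward and backward graphs can be handed to a fuser (nvFuser, when --fuser nvfuser is selected). A minimal standalone sketch of the same call, assuming functorch is installed (the toy model is illustrative):

    import torch
    from functorch.compile import memory_efficient_fusion

    model = torch.nn.Sequential(
        torch.nn.Linear(64, 64), torch.nn.ReLU(), torch.nn.Linear(64, 10)).cuda()
    fused = memory_efficient_fusion(model)          # AOT Autograd-wrapped module
    out = fused(torch.randn(8, 64, device='cuda'))  # drop-in for the original model
    out.sum().backward()
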
@@ -636,14 +649,48 @@ def main():
             f.write(args_text)
 
     try:
+        cuda_graph_ = None
+        cg_static_input = None
+        cg_static_target = None
+        cg_static_loss = None
         for epoch in range(start_epoch, num_epochs):
             if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
                 loader_train.sampler.set_epoch(epoch)
 
-            train_metrics = train_one_epoch(
-                epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
-                amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+            if not args.cuda_graph:
+                train_metrics = train_one_epoch(
+                    epoch, model, loader_train, optimizer, train_loss_fn, args,
+                    lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                    amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
+            else:
+                if cuda_graph_ is not None:
+                    train_metrics = train_one_epoch(  # cuda graph replay
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=cuda_graph_, cg_stage='replay',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+                else:
+                    input_, target_, loss_ = train_one_epoch(  # cuda graph get_static_data
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cg_stage='get_static_data')
+                    cg_static_input = torch.empty_like(input_)
+                    cg_static_target = torch.empty_like(target_)
+                    cg_static_loss = torch.empty_like(loss_)
+                    cuda_graph_ = train_one_epoch(  # cuda graph capture
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=None, cg_stage='capture',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
+                    train_metrics = train_one_epoch(  # cuda graph replay
+                        epoch, model, loader_train, optimizer, train_loss_fn, args,
+                        lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                        amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
+                        cuda_graph=cuda_graph_, cg_stage='replay',
+                        cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
 
             if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
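
The first epoch with --cuda-graph makes three passes through train_one_epoch: 'get_static_data' runs one real step to obtain tensors of the right shapes and dtypes, 'capture' records the graph into cuda_graph_, and 'replay' then trains the full epoch; every later epoch goes straight to 'replay'. Distilled to a sketch (run is a hypothetical stand-in for the train_one_epoch calls above):

    if cuda_graph_ is None:
        x, y, l = run('get_static_data')        # one real step fixes shapes/dtypes
        statics = [torch.empty_like(t) for t in (x, y, l)]
        cuda_graph_ = run('capture', *statics)  # warm up, then record one step
    train_metrics = run('replay', cuda_graph_, *statics)  # whole epoch via replay
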
@@ -682,7 +729,9 @@ def main():
 def train_one_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
         lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
-        loss_scaler=None, model_ema=None, mixup_fn=None):
+        loss_scaler=None, model_ema=None, mixup_fn=None,
+        cuda_graph=None, cg_stage=None,
+        cg_static_input=None, cg_static_target=None, cg_static_loss=None):
 
     if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
         if args.prefetcher and loader.mixup_enabled:
@@ -700,22 +749,16 @@ def train_one_epoch(
     end = time.time()
     last_idx = len(loader) - 1
     num_updates = epoch * len(loader)
-    for batch_idx, (input, target) in enumerate(loader):
-        last_batch = batch_idx == last_idx
-        data_time_m.update(time.time() - end)
-        if not args.prefetcher:
-            input, target = input.cuda(), target.cuda()
-            if mixup_fn is not None:
-                input, target = mixup_fn(input, target)
+
+    def _step(input_, target_, loss_=None):
         if args.channels_last:
-            input = input.contiguous(memory_format=torch.channels_last)
+            input = input_.contiguous(memory_format=torch.channels_last)
+        else:
+            input = input_
 
         with amp_autocast():
             output = model(input)
-            loss = loss_fn(output, target)
-
-        if not args.distributed:
-            losses_m.update(loss.item(), input.size(0))
+            loss = loss_fn(output, target_)
 
         optimizer.zero_grad()
         if loss_scaler is not None:
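
Hoisting the loop body into _step lets one piece of code serve three roles: the eager per-batch step, the warmup body, and the region recorded during capture. The optional loss_ argument exists because a replayed graph always reads and writes the same memory, so the loss must land in a preallocated buffer rather than a fresh tensor. The pattern in isolation (toy computation, illustrative only):

    import torch

    static_out = torch.empty((), device='cuda')  # fixed buffer, allocated once

    def step(x, out=None):
        loss = (x * x).sum()    # stand-in for forward/loss/backward
        if out is not None:
            out.copy_(loss)     # write in place so graph replay reuses the buffer
            return None
        return loss
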
@@ -735,6 +778,48 @@ def train_one_epoch(
         if model_ema is not None:
             model_ema.update(model)
 
+        if loss_ is not None:
+            loss_.copy_(loss)
+            return None
+        else:
+            return loss
+
+    for batch_idx, (input, target) in enumerate(loader):
+        last_batch = batch_idx == last_idx
+        data_time_m.update(time.time() - end)
+        if not args.prefetcher:
+            input, target = input.cuda(), target.cuda()
+            if mixup_fn is not None:
+                input, target = mixup_fn(input, target)
+
+        if cg_stage is None:  # the default non-CUDA-Graph mode
+            loss = _step(input, target)
+        elif cg_stage == 'get_static_data':
+            loss = _step(input, target)
+            return (input, target, loss)
+        elif cg_stage == 'capture':
+            s = torch.cuda.Stream()
+            s.wait_stream(torch.cuda.current_stream())
+            with torch.cuda.stream(s):
+                for _ in range(11):
+                    optimizer.zero_grad(set_to_none=True)
+                    _step(cg_static_input, cg_static_target, cg_static_loss)
+            torch.cuda.current_stream().wait_stream(s)
+
+            g = torch.cuda.CUDAGraph()
+            optimizer.zero_grad(set_to_none=True)
+            with torch.cuda.graph(g):
+                _step(cg_static_input, cg_static_target, cg_static_loss)
+            return g
+        elif cg_stage == 'replay':
+            cg_static_input.copy_(input)
+            cg_static_target.copy_(target)
+            cuda_graph.replay()
+            loss = cg_static_loss
+
+        if not args.distributed:
+            losses_m.update(loss.item(), input.size(0))
+
         torch.cuda.synchronize()
         num_updates += 1
         batch_time_m.update(time.time() - end)
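
The 'capture' branch follows the whole-network capture recipe from the PyTorch CUDA Graphs documentation: warm up for several iterations (11 here) on a side stream, synchronize back, record one optimizer step under torch.cuda.graph, and return the graph; 'replay' then refills the static input/target buffers in place and calls replay(). A self-contained sketch of the same end-to-end pattern (the toy model, shapes, and loader are illustrative):

    import torch

    model = torch.nn.Linear(16, 4).cuda()
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = torch.nn.CrossEntropyLoss()
    static_x = torch.randn(8, 16, device='cuda')         # static input buffer
    static_y = torch.randint(0, 4, (8,), device='cuda')  # static target buffer
    loader = [(torch.randn(8, 16, device='cuda'),
               torch.randint(0, 4, (8,), device='cuda')) for _ in range(5)]

    s = torch.cuda.Stream()                    # warm up on a side stream
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            opt.zero_grad(set_to_none=True)
            loss_fn(model(static_x), static_y).backward()
            opt.step()
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    opt.zero_grad(set_to_none=True)
    with torch.cuda.graph(g):                  # record one full training step
        static_loss = loss_fn(model(static_x), static_y)
        static_loss.backward()
        opt.step()

    for x, y in loader:                        # feed new batches through the graph
        static_x.copy_(x)                      # refill static buffers in place
        static_y.copy_(y)
        g.replay()                             # re-run the recorded step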