pull/1271/merge
X Wang 3 years ago committed by GitHub
commit 089d9046bf

@@ -51,6 +51,11 @@ except ImportError as e:
FlopCountAnalysis = None
has_fvcore_profiling = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@@ -99,6 +104,10 @@ parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--cuda-graph', default=False, action='store_true',
help="Enable CUDA Graph support")
parser.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
# train optimizer parameters
@@ -188,14 +197,16 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False,
cuda_graph=False, precision='float32', fuser='', num_warm_iter=10, num_bench_iter=50,
use_train_size=False, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
self.cuda_graph = cuda_graph
if fuser:
set_jit_fuser(fuser)
@@ -220,6 +231,9 @@ class BenchmarkRunner:
if torchscript:
self.model = torch.jit.script(self.model)
self.scripted = True
if aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
self.model = memory_efficient_fusion(self.model)
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size']
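
Note: `memory_efficient_fusion` from `functorch.compile` wraps an `nn.Module` for AOT Autograd with min-cut rematerialization and is meant to be combined with `--fuser nvfuser` rather than `--torchscript`. A minimal, self-contained sketch of the wrapping, independent of this diff (model and shapes are illustrative):

```python
import torch
from functorch.compile import memory_efficient_fusion  # requires functorch

# Illustrative model; any nn.Module with static shapes is wrapped the same way.
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.ReLU(), torch.nn.Linear(256, 10),
).cuda()

# Returns a wrapped module with the same call signature; forward and backward
# both run through the AOT-compiled graphs.
fused_model = memory_efficient_fusion(model)

x = torch.randn(8, 128, device='cuda')
fused_model(x).sum().backward()
```
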
@@ -248,11 +262,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
self.model.eval()
def run(self):
def _step():
def _step(sync=True):
t_step_start = self.time_fn()
with self.amp_autocast():
output = self.model(self.example_inputs)
t_step_end = self.time_fn(True)
t_step_end = self.time_fn(sync)
return t_step_end - t_step_start
_logger.info(
@@ -265,12 +279,28 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
for _ in range(self.num_warm_iter):
_step()
if self.cuda_graph:
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
_step(sync=False)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
_step(sync=False)
total_step = 0.
num_samples = 0
t_run_start = self.time_fn()
t_run_start = self.time_fn(True)
for i in range(self.num_bench_iter):
delta_fwd = _step()
total_step += delta_fwd
if self.cuda_graph:
g.replay()
total_step = self.time_fn(True) - t_run_start
else:
delta_fwd = _step()
total_step += delta_fwd
num_samples += self.batch_size
num_steps = i + 1
if num_steps % self.log_freq == 0:
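
The capture sequence above (warm up on a side stream, record into a `torch.cuda.CUDAGraph`, then replay in the benchmark loop) follows the standard PyTorch CUDA Graphs pattern. A minimal, self-contained inference sketch of that pattern, separate from this diff (model and shapes are illustrative):

```python
import torch

model = torch.nn.Linear(64, 64).cuda().eval()
static_input = torch.randn(8, 64, device='cuda')

# Warm up on a side stream so lazy initialization is not recorded into the graph.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture one forward pass; tensors created during capture stay valid across replays.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay with fresh data: copy into the captured input buffer, then replay.
static_input.copy_(torch.randn(8, 64, device='cuda'))
g.replay()
torch.cuda.synchronize()
print(static_output.sum().item())
```
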
@@ -332,7 +362,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
def run(self):
def _step(detail=False):
def _step(detail=False, sync=True):
self.optimizer.zero_grad() # can this be ignored?
t_start = self.time_fn()
t_fwd_end = t_start
@@ -348,7 +378,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
if detail:
t_bwd_end = self.time_fn(True)
self.optimizer.step()
t_end = self.time_fn(True)
t_end = self.time_fn(sync)
if detail:
delta_fwd = t_fwd_end - t_start
delta_bwd = t_bwd_end - t_fwd_end
@@ -367,7 +397,21 @@ class TrainBenchmarkRunner(BenchmarkRunner):
for _ in range(self.num_warm_iter):
_step()
t_run_start = self.time_fn()
if self.cuda_graph:
assert self.detail is False, "--detail mode is not supported with CUDA Graph in the training benchmark"
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(3):
_step(sync=False)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
self.optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
_step(sync=False)
t_run_start = self.time_fn(True)
if self.detail:
total_fwd = 0.
total_bwd = 0.
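
For the training benchmark the same idea applies, with two extra details visible above: a few warm-up steps on a side stream so optimizer state and autograd buffers are allocated before capture, and `optimizer.zero_grad(set_to_none=True)` immediately before capture so the `.grad` tensors are allocated from the graph's private memory pool. A minimal, self-contained sketch of that sequence, with a toy model, loss, and optimizer standing in for the benchmarked ones:

```python
import torch

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
static_x = torch.randn(8, 64, device='cuda')
static_y = torch.randn(8, 64, device='cuda')

def train_step():
    loss = torch.nn.functional.mse_loss(model(static_x), static_y)
    loss.backward()
    optimizer.step()
    return loss

# Warm-up iterations on a side stream allocate optimizer state and autograd buffers.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        train_step()
torch.cuda.current_stream().wait_stream(s)

# set_to_none=True drops the .grad tensors so backward() recreates them from
# the graph's private memory pool during capture.
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = train_step()

static_x.copy_(torch.randn(8, 64, device='cuda'))  # new data, same buffers
g.replay()
```
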
@@ -405,9 +449,13 @@ class TrainBenchmarkRunner(BenchmarkRunner):
total_step = 0.
num_samples = 0
for i in range(self.num_bench_iter):
delta_step = _step(False)
if self.cuda_graph:
g.replay()
total_step = self.time_fn(True) - t_run_start
else:
delta_step = _step(False)
total_step += delta_step
num_samples += self.batch_size
total_step += delta_step
num_steps = (i + 1)
if num_steps % self.log_freq == 0:
_logger.info(
@@ -506,7 +554,8 @@ def benchmark(args):
args.precision = 'amp'
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')
f'torchscript {"enabled" if args.torchscript else "disabled"}. '
f'cuda graph {"enabled" if args.cuda_graph else "disabled"}')
bench_kwargs = vars(args).copy()
bench_kwargs.pop('amp')

@@ -17,7 +17,7 @@ from .clip_grad import dispatch_clip_grad
class ApexScaler:
state_dict_key = "amp"
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True):
with amp.scale_loss(loss, optimizer) as scaled_loss:
scaled_loss.backward(create_graph=create_graph)
if clip_grad is not None:
@@ -32,6 +32,8 @@ class ApexScaler:
if 'load_state_dict' in amp.__dict__:
amp.load_state_dict(state_dict)
def step_and_update(self):
pass
class NativeScaler:
state_dict_key = "amp_scaler"
@@ -39,17 +41,25 @@ class NativeScaler:
def __init__(self):
self._scaler = torch.cuda.amp.GradScaler()
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False):
def __call__(self, loss, optimizer, clip_grad=None, clip_mode='norm', parameters=None, create_graph=False, step_and_update=True):
self._scaler.scale(loss).backward(create_graph=create_graph)
self._optimizer = optimizer
if clip_grad is not None:
assert parameters is not None
self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
dispatch_clip_grad(parameters, clip_grad, mode=clip_mode)
self._scaler.step(optimizer)
self._scaler.update()
if step_and_update:
self._scaler.step(optimizer)
self._scaler.update()
def state_dict(self):
return self._scaler.state_dict()
def load_state_dict(self, state_dict):
self._scaler.load_state_dict(state_dict)
def step_and_update(self):
# need to separate this step from "scaler.scale(loss)" since scaler.step() syncs the GPU with the CPU,
# which is not allowed during CUDA graph capture
self._scaler.step(self._optimizer)
self._scaler.update()
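
The split above matters because `GradScaler.step()`/`update()` perform a GPU-to-CPU sync (the inf/NaN check), so they cannot run inside a captured region, while the scaled backward pass can be captured and the optimizer step deferred until after replay. A rough usage sketch under that assumption, with a toy model (the `step_and_update` keyword and method only exist with this patch applied):

```python
import torch
from timm.utils import NativeScaler  # with this patch's step_and_update split

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = NativeScaler()
static_input = torch.randn(8, 64, device='cuda')
static_target = torch.randn(8, 64, device='cuda')

def _step(step_and_update=True):
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(static_input), static_target)
    # scaler.scale(loss).backward() is capture-safe; the optimizer step is
    # skipped when capturing because scaler.step()/update() sync CPU and GPU.
    scaler(loss, optimizer, step_and_update=step_and_update)

# Warm up on a side stream, then capture without the scaler step.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        _step()
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    _step(step_and_update=False)

# Per iteration: refresh static inputs, replay, then step outside the graph.
static_input.copy_(torch.randn(8, 64, device='cuda'))
g.replay()
scaler.step_and_update()
```
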

@@ -61,6 +61,12 @@ try:
except ImportError:
has_wandb = False
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('train')
@@ -129,6 +135,10 @@ group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--cuda-graph', default=False, action='store_true',
help="Enable CUDA Graph support")
group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` but without `--torchscript`)")
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
@@ -354,6 +364,10 @@ def main():
args.world_size = 1
args.rank = 0 # global rank
if args.distributed:
if args.cuda_graph:
os.environ['NCCL_ASYNC_ERROR_HANDLING'] = '0'
if 'LOCAL_RANK' in os.environ:
args.local_rank = int(os.environ['LOCAL_RANK'])
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
@@ -365,6 +379,9 @@ def main():
_logger.info('Training with a single process on 1 GPUs.')
assert args.rank >= 0
if args.cuda_graph:
_logger.info("Training with CUDA Graph support is enabled.")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
if args.amp:
@@ -445,6 +462,9 @@ def main():
assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
@@ -494,7 +514,13 @@ def main():
else:
if args.local_rank == 0:
_logger.info("Using native Torch DistributedDataParallel.")
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
# Wrapping the DDP model in a side stream is needed for CUDA graph workloads.
# Since it does no harm in regular eager mode, we apply it unconditionally to avoid an extra if-statement.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
model = NativeDDP(model, device_ids=[args.local_rank], broadcast_buffers=not args.no_ddp_bb)
torch.cuda.current_stream().wait_stream(s)
# NOTE: EMA model does not need to be wrapped by DDP
# setup learning rate schedule and starting epoch
@@ -636,14 +662,51 @@ def main():
f.write(args_text)
try:
cuda_graph_ = None
cg_static_input = None
cg_static_target = None
cg_static_loss = None
for epoch in range(start_epoch, num_epochs):
if args.distributed and hasattr(loader_train.sampler, 'set_epoch'):
loader_train.sampler.set_epoch(epoch)
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
if not args.cuda_graph:
# The default eager mode without cuda graph.
train_metrics = train_one_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn)
else:
# Replay an already captured cuda graph from previous epochs.
if cuda_graph_ is not None:
train_metrics = train_one_epoch( # cuda graph replay
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cuda_graph=cuda_graph_, cg_stage='replay',
cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
# Capture a cuda graph.
else:
input_, target_, loss_ = train_one_epoch( # cuda graph get_static_shapes
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cg_stage='get_static_shapes')
cg_static_input = torch.empty_like(input_)
cg_static_target = torch.empty_like(target_)
cg_static_loss = torch.empty_like(loss_)
cuda_graph_ = train_one_epoch( # cuda graph capture
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cuda_graph=None, cg_stage='capture',
cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
train_metrics = train_one_epoch( # cuda graph replay
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
amp_autocast=amp_autocast, loss_scaler=loss_scaler, model_ema=model_ema, mixup_fn=mixup_fn,
cuda_graph=cuda_graph_, cg_stage='replay',
cg_static_input=cg_static_input, cg_static_target=cg_static_target, cg_static_loss=cg_static_loss)
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
if args.local_rank == 0:
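
Putting the epoch logic together: on the first CUDA-graph epoch, `train_one_epoch` runs once in the `get_static_shapes` stage (one batch, to size the static buffers), once in the `capture` stage (recording the graph against those buffers), and then in the `replay` stage for the actual epoch; later epochs reuse the captured graph and only replay. Per batch, replay reduces to copying data into the static tensors and replaying the graph. A minimal, self-contained sketch of that protocol with a toy model and a synthetic loader (an illustration of the pattern, not code from this diff):

```python
import torch

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loader = [(torch.randn(8, 64, device='cuda'), torch.randn(8, 64, device='cuda'))
          for _ in range(10)]  # stand-in for loader_train; a fixed batch shape is required

# 'get_static_shapes': one batch defines the static buffers the graph will read from.
proto_x, proto_y = loader[0]
static_x, static_y = torch.empty_like(proto_x), torch.empty_like(proto_y)
static_x.copy_(proto_x)
static_y.copy_(proto_y)

def train_step():
    loss = torch.nn.functional.mse_loss(model(static_x), static_y)
    loss.backward()
    optimizer.step()
    return loss

# 'capture': warm up on a side stream, then record one step against the static buffers.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        optimizer.zero_grad(set_to_none=True)
        train_step()
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
    static_loss = train_step()

# 'replay': every batch is copied into the static buffers, then the recorded graph replays.
for x, y in loader:
    static_x.copy_(x)
    static_y.copy_(y)
    g.replay()
    print(float(static_loss))  # reading the loss synchronizes, like loss.item() in the diff
```
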
@@ -682,7 +745,9 @@ def main():
def train_one_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir=None, amp_autocast=suppress,
loss_scaler=None, model_ema=None, mixup_fn=None):
loss_scaler=None, model_ema=None, mixup_fn=None,
cuda_graph=None, cg_stage=None,
cg_static_input=None, cg_static_target=None, cg_static_loss=None):
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
if args.prefetcher and loader.mixup_enabled:
@@ -700,22 +765,16 @@ def train_one_epoch(
end = time.time()
last_idx = len(loader) - 1
num_updates = epoch * len(loader)
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
def _step(input_, target_, loss_=None, scaler_step_and_update=True):
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
input = input_.contiguous(memory_format=torch.channels_last)
else:
input = input_
with amp_autocast():
output = model(input)
loss = loss_fn(output, target)
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
loss = loss_fn(output, target_)
optimizer.zero_grad()
if loss_scaler is not None:
@@ -723,7 +782,8 @@ def train_one_epoch(
loss, optimizer,
clip_grad=args.clip_grad, clip_mode=args.clip_mode,
parameters=model_parameters(model, exclude_head='agc' in args.clip_mode),
create_graph=second_order)
create_graph=second_order,
step_and_update=scaler_step_and_update)
else:
loss.backward(create_graph=second_order)
if args.clip_grad is not None:
@@ -735,6 +795,55 @@ def train_one_epoch(
if model_ema is not None:
model_ema.update(model)
if loss_ is not None:
loss_.copy_(loss)
return None
else:
return loss
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
data_time_m.update(time.time() - end)
if not args.prefetcher:
input, target = input.cuda(), target.cuda()
if mixup_fn is not None:
input, target = mixup_fn(input, target)
if cg_stage is None: # The default non-CUDAGraph mode
loss = _step(input, target)
elif cg_stage == 'get_static_shapes':
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
optimizer.zero_grad(set_to_none=True)
loss = _step(input, target)
torch.cuda.current_stream().wait_stream(s)
return (input, target, loss)
elif cg_stage == 'capture':
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
for _ in range(11):
optimizer.zero_grad(set_to_none=True)
_step(input, target)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
optimizer.zero_grad(set_to_none=True)
with torch.cuda.graph(g):
_step(cg_static_input, cg_static_target, cg_static_loss, scaler_step_and_update=False)
return g
elif cg_stage == 'replay':
cg_static_input.copy_(input)
cg_static_target.copy_(target)
cuda_graph.replay()
loss = cg_static_loss
if loss_scaler is not None:
loss_scaler.step_and_update()
if not args.distributed:
losses_m.update(loss.item(), input.size(0))
torch.cuda.synchronize()
num_updates += 1
batch_time_m.update(time.time() - end)
