Some further create_optimizer_v2 tweaks, remove some redundant code, add back safe model str. Benchmark step times per batch.

pull/533/head
Ross Wightman 4 years ago
parent 2bb65bd875
commit 37c71a5609

benchmark.py

@@ -217,17 +217,18 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
                 delta_fwd = _step()
                 total_step += delta_fwd
                 num_samples += self.batch_size
-                if (i + 1) % self.log_freq == 0:
+                num_steps = i + 1
+                if num_steps % self.log_freq == 0:
                     _logger.info(
-                        f"Infer [{i + 1}/{self.num_bench_iter}]."
+                        f"Infer [{num_steps}/{self.num_bench_iter}]."
                         f" {num_samples / total_step:0.2f} samples/sec."
-                        f" {1000 * total_step / num_samples:0.3f} ms/sample.")
+                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
 
             t_run_end = self.time_fn(True)
             t_run_elapsed = t_run_end - t_run_start
 
         results = dict(
             samples_per_sec=round(num_samples / t_run_elapsed, 2),
-            step_time=round(1000 * total_step / num_samples, 3),
+            step_time=round(1000 * total_step / self.num_bench_iter, 3),
             batch_size=self.batch_size,
             img_size=self.input_size[-1],
             param_count=round(self.param_count / 1e6, 2),
@@ -235,7 +236,7 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
 
         _logger.info(
             f"Inference benchmark of {self.model_name} done. "
-            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/sample")
+            f"{results['samples_per_sec']:.2f} samples/sec, {results['step_time']:.2f} ms/step")
 
         return results
@@ -254,8 +255,8 @@ class TrainBenchmarkRunner(BenchmarkRunner):
 
         self.optimizer = create_optimizer_v2(
             self.model,
-            opt_name=kwargs.pop('opt', 'sgd'),
-            lr=kwargs.pop('lr', 1e-4))
+            optimizer_name=kwargs.pop('opt', 'sgd'),
+            learning_rate=kwargs.pop('lr', 1e-4))
 
     def _gen_target(self, batch_size):
         return torch.empty(
@@ -309,23 +310,24 @@ class TrainBenchmarkRunner(BenchmarkRunner):
                 total_fwd += delta_fwd
                 total_bwd += delta_bwd
                 total_opt += delta_opt
-                if (i + 1) % self.log_freq == 0:
+                num_steps = (i + 1)
+                if num_steps % self.log_freq == 0:
                     total_step = total_fwd + total_bwd + total_opt
                     _logger.info(
-                        f"Train [{i + 1}/{self.num_bench_iter}]."
+                        f"Train [{num_steps}/{self.num_bench_iter}]."
                         f" {num_samples / total_step:0.2f} samples/sec."
-                        f" {1000 * total_fwd / num_samples:0.3f} ms/sample fwd,"
-                        f" {1000 * total_bwd / num_samples:0.3f} ms/sample bwd,"
-                        f" {1000 * total_opt / num_samples:0.3f} ms/sample opt."
+                        f" {1000 * total_fwd / num_steps:0.3f} ms/step fwd,"
+                        f" {1000 * total_bwd / num_steps:0.3f} ms/step bwd,"
+                        f" {1000 * total_opt / num_steps:0.3f} ms/step opt."
                     )
             total_step = total_fwd + total_bwd + total_opt
             t_run_elapsed = self.time_fn() - t_run_start
             results = dict(
                 samples_per_sec=round(num_samples / t_run_elapsed, 2),
-                step_time=round(1000 * total_step / num_samples, 3),
-                fwd_time=round(1000 * total_fwd / num_samples, 3),
-                bwd_time=round(1000 * total_bwd / num_samples, 3),
-                opt_time=round(1000 * total_opt / num_samples, 3),
+                step_time=round(1000 * total_step / self.num_bench_iter, 3),
+                fwd_time=round(1000 * total_fwd / self.num_bench_iter, 3),
+                bwd_time=round(1000 * total_bwd / self.num_bench_iter, 3),
+                opt_time=round(1000 * total_opt / self.num_bench_iter, 3),
                 batch_size=self.batch_size,
                 img_size=self.input_size[-1],
                 param_count=round(self.param_count / 1e6, 2),
@@ -337,15 +339,16 @@ class TrainBenchmarkRunner(BenchmarkRunner):
                 delta_step = _step(False)
                 num_samples += self.batch_size
                 total_step += delta_step
-                if (i + 1) % self.log_freq == 0:
+                num_steps = (i + 1)
+                if num_steps % self.log_freq == 0:
                     _logger.info(
-                        f"Train [{i + 1}/{self.num_bench_iter}]."
+                        f"Train [{num_steps}/{self.num_bench_iter}]."
                         f" {num_samples / total_step:0.2f} samples/sec."
-                        f" {1000 * total_step / num_samples:0.3f} ms/sample.")
+                        f" {1000 * total_step / num_steps:0.3f} ms/step.")
             t_run_elapsed = self.time_fn() - t_run_start
             results = dict(
                 samples_per_sec=round(num_samples / t_run_elapsed, 2),
-                step_time=round(1000 * total_step / num_samples, 3),
+                step_time=round(1000 * total_step / self.num_bench_iter, 3),
                 batch_size=self.batch_size,
                 img_size=self.input_size[-1],
                 param_count=round(self.param_count / 1e6, 2),
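
Note on the reporting change above: accumulated timings are now divided by the number of steps (batches) rather than by the number of samples, so step_time and the fwd/bwd/opt splits are ms per batch. A minimal sketch of the arithmetic; batch_size, num_bench_iter and total_step below are made-up example values, not measurements from this commit:

    # Illustrative values only.
    batch_size = 256
    num_bench_iter = 40
    total_step = 12.8  # seconds accumulated over all benchmark steps

    ms_per_step = 1000 * total_step / num_bench_iter                    # new report: ms/step
    ms_per_sample = 1000 * total_step / (num_bench_iter * batch_size)   # old report: ms/sample
    samples_per_sec = (num_bench_iter * batch_size) / total_step        # unchanged
    print(f"{samples_per_sec:0.2f} samples/sec,"
          f" {ms_per_step:0.3f} ms/step ({ms_per_sample:0.3f} ms/sample)")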

timm/optim/optim_factory.py

@@ -44,14 +44,17 @@ def optimizer_kwargs(cfg):
     """ cfg/argparse to kwargs helper
     Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
     """
-    kwargs = dict(opt_name=cfg.opt, lr=cfg.lr, weight_decay=cfg.weight_decay)
+    kwargs = dict(
+        optimizer_name=cfg.opt,
+        learning_rate=cfg.lr,
+        weight_decay=cfg.weight_decay,
+        momentum=cfg.momentum)
     if getattr(cfg, 'opt_eps', None) is not None:
         kwargs['eps'] = cfg.opt_eps
     if getattr(cfg, 'opt_betas', None) is not None:
         kwargs['betas'] = cfg.opt_betas
     if getattr(cfg, 'opt_args', None) is not None:
         kwargs.update(cfg.opt_args)
-    kwargs['momentum'] = cfg.momentum
     return kwargs
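
For reference, a small sketch of what the reworked optimizer_kwargs helper returns for an argparse-style config after this change; the SimpleNamespace stands in for the real args object and the field values are arbitrary:

    from types import SimpleNamespace

    from timm.optim.optim_factory import optimizer_kwargs

    # Stand-in for argparse args; only the fields optimizer_kwargs reads are set.
    cfg = SimpleNamespace(
        opt='adamw', lr=1e-3, weight_decay=0.05, momentum=0.9,
        opt_eps=1e-8, opt_betas=None, opt_args=None)
    print(optimizer_kwargs(cfg))
    # With the renaming above, the result should look like:
    # {'optimizer_name': 'adamw', 'learning_rate': 0.001, 'weight_decay': 0.05,
    #  'momentum': 0.9, 'eps': 1e-08}
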
@@ -59,20 +62,17 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """
-    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
-    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
-        opt_args['eps'] = args.opt_eps
-    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
-        opt_args['betas'] = args.opt_betas
-    if hasattr(args, 'opt_args') and args.opt_args is not None:
-        opt_args.update(args.opt_args)
-    return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args)
+    return create_optimizer_v2(
+        model,
+        **optimizer_kwargs(cfg=args),
+        filter_bias_and_bn=filter_bias_and_bn,
+    )
 
 
 def create_optimizer_v2(
         model: nn.Module,
-        opt_name: str = 'sgd',
-        lr: Optional[float] = None,
+        optimizer_name: str = 'sgd',
+        learning_rate: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
         filter_bias_and_bn: bool = True,
@@ -86,8 +86,8 @@ def create_optimizer_v2(
     Args:
         model (nn.Module): model containing parameters to optimize
-        opt_name: name of optimizer to create
-        lr: initial learning rate
+        optimizer_name: name of optimizer to create
+        learning_rate: initial learning rate
         weight_decay: weight decay to apply in optimizer
         momentum: momentum for momentum based optimizers (others may use betas via kwargs)
         filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
@@ -96,7 +96,7 @@ def create_optimizer_v2(
     Returns:
         Optimizer
     """
-    opt_lower = opt_name.lower()
+    opt_lower = optimizer_name.lower()
     if weight_decay and filter_bias_and_bn:
         skip = {}
         if hasattr(model, 'no_weight_decay'):
@@ -108,7 +108,7 @@ def create_optimizer_v2(
     if 'fused' in opt_lower:
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
 
-    opt_args = dict(lr=lr, weight_decay=weight_decay, **kwargs)
+    opt_args = dict(lr=learning_rate, weight_decay=weight_decay, **kwargs)
     opt_split = opt_lower.split('_')
     opt_lower = opt_split[-1]
     if opt_lower == 'sgd' or opt_lower == 'nesterov':
@@ -132,7 +132,7 @@ def create_optimizer_v2(
     elif opt_lower == 'adadelta':
         optimizer = optim.Adadelta(parameters, **opt_args)
     elif opt_lower == 'adafactor':
-        if not lr:
+        if not learning_rate:
             opt_args['lr'] = None
         optimizer = Adafactor(parameters, **opt_args)
     elif opt_lower == 'adahessian':
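
Usage sketch for the factory with the argument names as renamed in this commit (optimizer_name/learning_rate instead of opt_name/lr); the toy model and hyperparameter values are arbitrary examples, not part of the change:

    import torch.nn as nn

    from timm.optim.optim_factory import create_optimizer_v2

    # Any nn.Module works; a tiny one keeps the example self-contained.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    optimizer = create_optimizer_v2(
        model,
        optimizer_name='adamw',    # was opt_name
        learning_rate=1e-3,        # was lr
        weight_decay=0.05,
        filter_bias_and_bn=True,   # keep bias/BN/1d params out of weight decay
    )
    print(type(optimizer).__name__)  # AdamW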

train.py

@@ -552,7 +552,7 @@ def main():
     else:
         exp_name = '-'.join([
             datetime.now().strftime("%Y%m%d-%H%M%S"),
-            args.model,
+            safe_model_name(args.model),
             str(data_config['input_size'][-1])
         ])
         output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
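
The train.py change restores name sanitising before the model string becomes part of the output path. A rough sketch of the intent, assuming timm's safe_model_name helper (which strips or replaces characters that are awkward in directory names); the model string and the comment about its sanitised form are illustrative assumptions:

    from datetime import datetime

    from timm.models import safe_model_name

    # Hypothetical model string containing a source prefix and a '/'.
    model_name = 'hf_hub:timm/eca_nfnet_l0'
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        safe_model_name(model_name),  # path-safe variant, no ':' or '/' characters
        '224',                        # stands in for data_config['input_size'][-1]
    ])
    print(exp_name)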
