@ -19,7 +19,7 @@ from contextlib import suppress
from functools import partial
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False,
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
# Destination file for the CSV summary of results
parser.add_argument(
    '--results-file',
    default='',
    type=str,
    metavar='FILENAME',
    help='Output csv file for validation results (summary)',
)
# Iteration counts for the warmup phase and the benchmark phase
parser.add_argument(
    '--num-warm-iter',
    default=10,
    type=int,
    metavar='N',
    help='Number of warmup iterations (default: 10)',
)
parser.add_argument(
    '--num-bench-iter',
    default=40,
    type=int,
    metavar='N',
    help='Number of benchmark iterations (default: 40)',
)
# common inference / train args
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
# The names listed in the help text must match what resolve_precision() accepts;
# it rejects anything other than amp / float32 / float16 / bfloat16.
parser.add_argument('--precision', default='float32', type=str,
                    help='Numeric precision. One of (amp, float32, float16, bfloat16)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None):
return time.perf_counter()
def count_params(model: nn.Module) -> int:
    """Return the total number of elements across all parameters of ``model``.

    Args:
        model: Module whose ``parameters()`` are counted.

    Returns:
        Sum of ``numel()`` over every parameter tensor (0 if there are none).
    """
    # Generator expression: no need to materialise an intermediate list for sum()
    return sum(p.numel() for p in model.parameters())
def resolve_precision(precision: str):
    """Translate a precision name into an AMP flag and model/data dtypes.

    Args:
        precision: One of 'amp', 'float32', 'float16' or 'bfloat16'.

    Returns:
        A ``(use_amp, model_dtype, data_dtype)`` tuple. ``use_amp`` is True only
        for 'amp'. Both dtypes are ``torch.float32`` for 'amp' and 'float32',
        and the matching reduced-precision dtype for 'float16' / 'bfloat16'.

    Raises:
        ValueError: If ``precision`` is not one of the supported names.
    """
    # Reduced-precision modes use one dtype for both the model and its inputs
    reduced_dtypes = {
        'float16': torch.float16,
        'bfloat16': torch.bfloat16,
    }
    supported = ('amp', 'float32') + tuple(reduced_dtypes)
    if precision not in supported:
        # Raise explicitly instead of using `assert`: asserts are removed under
        # `python -O`, which would let a bad name silently run as float32.
        raise ValueError(
            f'Unsupported precision {precision!r}, expected one of {supported}')

    # 'amp' keeps float32 dtypes and only switches the use_amp flag on
    use_amp = precision == 'amp'
    dtype = reduced_dtypes.get(precision, torch.float32)
    return use_amp, dtype, dtype
class BenchmarkRunner:
def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
num_warm_iter=10, num_bench_iter=50, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
self.model = create_model(
num_classes=kwargs.pop('num_classes', None),
global_pool=kwargs.pop('gp', 'fast'),
memory_format=torch.channels_last if self.channels_last else None)
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
self.channels_last = kwargs.pop('channels_last', False)
self.use_amp = kwargs.pop('use_amp', '')
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
if torchscript:
self.model = torch.jit.script(self.model)
@ -147,16 +171,17 @@ class BenchmarkRunner:
self.batch_size = kwargs.pop('batch_size', 256)
self.example_inputs = None
self.num_warm_iter = 10
self.num_bench_iter = 50
self.log_freq = 10
self.num_warm_iter = num_warm_iter
self.num_bench_iter = num_bench_iter
self.log_freq = num_bench_iter // 5
if 'cuda' in self.device:
self.time_fn = partial(cuda_timestamp, device=self.device)
self.time_fn = timestamp
def _init_input(self):
self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
self.example_inputs = torch.randn(
(self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
if self.channels_last:
self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
if self.use_amp == 'apex':
self.model = amp.initialize(self.model, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def run(self):
def _step():
@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner):
self.loss = nn.CrossEntropyLoss().to(self.device)
self.target_shape = tuple()
self.optimizer = create_optimizer(
self.optimizer = create_optimizer_v2(
opt_name=kwargs.pop('opt', 'sgd'),
lr=kwargs.pop('lr', 1e-4))
if self.use_amp == 'apex':
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
if self.channels_last:
self.model = self.model.to(memory_format=torch.channels_last)
def _gen_target(self, batch_size):
return torch.empty(
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
samples_per_sec=round(num_samples / t_run_elapsed, 2),
step_time=round(1000 * total_step / num_samples, 3),
param_count=round(self.param_count / 1e6, 2),
@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
def benchmark(args):
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
_logger.warning("Neither APEX or Native Torch AMP is available.")
if args.native_amp:
args.use_amp = 'native'
_logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
args.use_amp = 'apex'
_logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
args.use_amp = ''
_logger.info('Benchmarking in float32. AMP not enabled.')
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
args.precision = 'amp'
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')
bench_kwargs = vars(args).copy()
model = bench_kwargs.pop('model')
batch_size = bench_kwargs.pop('batch_size')