|
|
|
@ -19,7 +19,7 @@ from contextlib import suppress
|
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
|
|
from timm.models import create_model, is_model, list_models
|
|
|
|
|
from timm.optim import create_optimizer
|
|
|
|
|
from timm.optim import create_optimizer_v2
|
|
|
|
|
from timm.data import resolve_data_config
|
|
|
|
|
from timm.utils import AverageMeter, setup_default_logging
|
|
|
|
|
|
|
|
|
@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False,
|
|
|
|
|
help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
|
|
|
|
|
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
|
|
|
|
help='Output csv file for validation results (summary)')
|
|
|
|
|
parser.add_argument('--num-warm-iter', default=10, type=int,
|
|
|
|
|
metavar='N', help='Number of warmup iterations (default: 10)')
|
|
|
|
|
parser.add_argument('--num-bench-iter', default=40, type=int,
|
|
|
|
|
metavar='N', help='Number of benchmark iterations (default: 40)')
|
|
|
|
|
|
|
|
|
|
# common inference / train args
|
|
|
|
|
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
|
|
|
|
@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
|
|
|
|
|
parser.add_argument('--channels-last', action='store_true', default=False,
|
|
|
|
|
help='Use channels_last memory layout')
|
|
|
|
|
parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
|
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
|
|
|
|
|
parser.add_argument('--apex-amp', action='store_true', default=False,
|
|
|
|
|
help='Use NVIDIA Apex AMP mixed precision')
|
|
|
|
|
parser.add_argument('--native-amp', action='store_true', default=False,
|
|
|
|
|
help='Use Native Torch AMP mixed precision')
|
|
|
|
|
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
|
|
|
|
|
parser.add_argument('--precision', default='float32', type=str,
|
|
|
|
|
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
|
|
|
|
|
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
|
|
|
|
help='convert model torchscript for inference')
|
|
|
|
|
|
|
|
|
@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None):
|
|
|
|
|
return time.perf_counter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def count_params(model):
|
|
|
|
|
def count_params(model: nn.Module):
|
|
|
|
|
return sum([m.numel() for m in model.parameters()])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_precision(precision: str):
|
|
|
|
|
assert precision in ('amp', 'float16', 'bfloat16', 'float32')
|
|
|
|
|
use_amp = False
|
|
|
|
|
model_dtype = torch.float32
|
|
|
|
|
data_dtype = torch.float32
|
|
|
|
|
if precision == 'amp':
|
|
|
|
|
use_amp = True
|
|
|
|
|
elif precision == 'float16':
|
|
|
|
|
model_dtype = torch.float16
|
|
|
|
|
data_dtype = torch.float16
|
|
|
|
|
elif precision == 'bfloat16':
|
|
|
|
|
model_dtype = torch.bfloat16
|
|
|
|
|
data_dtype = torch.bfloat16
|
|
|
|
|
return use_amp, model_dtype, data_dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class BenchmarkRunner:
|
|
|
|
|
def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
|
|
|
|
|
def __init__(
|
|
|
|
|
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
|
|
|
|
|
num_warm_iter=10, num_bench_iter=50, **kwargs):
|
|
|
|
|
self.model_name = model_name
|
|
|
|
|
self.detail = detail
|
|
|
|
|
self.device = device
|
|
|
|
|
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
|
|
|
|
|
self.channels_last = kwargs.pop('channels_last', False)
|
|
|
|
|
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
|
|
|
|
|
|
|
|
|
|
self.model = create_model(
|
|
|
|
|
model_name,
|
|
|
|
|
num_classes=kwargs.pop('num_classes', None),
|
|
|
|
|
in_chans=3,
|
|
|
|
|
global_pool=kwargs.pop('gp', 'fast'),
|
|
|
|
|
scriptable=torchscript).to(device=self.device)
|
|
|
|
|
scriptable=torchscript)
|
|
|
|
|
self.model.to(
|
|
|
|
|
device=self.device,
|
|
|
|
|
dtype=self.model_dtype,
|
|
|
|
|
memory_format=torch.channels_last if self.channels_last else None)
|
|
|
|
|
self.num_classes = self.model.num_classes
|
|
|
|
|
self.param_count = count_params(self.model)
|
|
|
|
|
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
|
|
|
|
|
|
|
|
|
|
self.channels_last = kwargs.pop('channels_last', False)
|
|
|
|
|
self.use_amp = kwargs.pop('use_amp', '')
|
|
|
|
|
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
|
|
|
|
|
if torchscript:
|
|
|
|
|
self.model = torch.jit.script(self.model)
|
|
|
|
|
|
|
|
|
@ -147,16 +171,17 @@ class BenchmarkRunner:
|
|
|
|
|
self.batch_size = kwargs.pop('batch_size', 256)
|
|
|
|
|
|
|
|
|
|
self.example_inputs = None
|
|
|
|
|
self.num_warm_iter = 10
|
|
|
|
|
self.num_bench_iter = 50
|
|
|
|
|
self.log_freq = 10
|
|
|
|
|
self.num_warm_iter = num_warm_iter
|
|
|
|
|
self.num_bench_iter = num_bench_iter
|
|
|
|
|
self.log_freq = num_bench_iter // 5
|
|
|
|
|
if 'cuda' in self.device:
|
|
|
|
|
self.time_fn = partial(cuda_timestamp, device=self.device)
|
|
|
|
|
else:
|
|
|
|
|
self.time_fn = timestamp
|
|
|
|
|
|
|
|
|
|
def _init_input(self):
|
|
|
|
|
self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
|
|
|
|
|
self.example_inputs = torch.randn(
|
|
|
|
|
(self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
|
|
|
|
|
if self.channels_last:
|
|
|
|
|
self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
|
|
|
|
|
def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
|
|
|
|
|
super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
|
|
|
|
|
self.model.eval()
|
|
|
|
|
if self.use_amp == 'apex':
|
|
|
|
|
self.model = amp.initialize(self.model, opt_level='O1')
|
|
|
|
|
if self.channels_last:
|
|
|
|
|
self.model = self.model.to(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
def run(self):
|
|
|
|
|
def _step():
|
|
|
|
@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner):
|
|
|
|
|
self.loss = nn.CrossEntropyLoss().to(self.device)
|
|
|
|
|
self.target_shape = tuple()
|
|
|
|
|
|
|
|
|
|
self.optimizer = create_optimizer(
|
|
|
|
|
self.optimizer = create_optimizer_v2(
|
|
|
|
|
self.model,
|
|
|
|
|
opt_name=kwargs.pop('opt', 'sgd'),
|
|
|
|
|
lr=kwargs.pop('lr', 1e-4))
|
|
|
|
|
|
|
|
|
|
if self.use_amp == 'apex':
|
|
|
|
|
self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
|
|
|
|
|
if self.channels_last:
|
|
|
|
|
self.model = self.model.to(memory_format=torch.channels_last)
|
|
|
|
|
|
|
|
|
|
def _gen_target(self, batch_size):
|
|
|
|
|
return torch.empty(
|
|
|
|
|
(batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
|
|
|
|
@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
|
|
|
|
|
samples_per_sec=round(num_samples / t_run_elapsed, 2),
|
|
|
|
|
step_time=round(1000 * total_step / num_samples, 3),
|
|
|
|
|
batch_size=self.batch_size,
|
|
|
|
|
img_size=self.input_size[-1],
|
|
|
|
|
param_count=round(self.param_count / 1e6, 2),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
|
|
|
|
|
|
|
|
|
|
def benchmark(args):
|
|
|
|
|
if args.amp:
|
|
|
|
|
if has_native_amp:
|
|
|
|
|
args.native_amp = True
|
|
|
|
|
elif has_apex:
|
|
|
|
|
args.apex_amp = True
|
|
|
|
|
else:
|
|
|
|
|
_logger.warning("Neither APEX or Native Torch AMP is available.")
|
|
|
|
|
if args.native_amp:
|
|
|
|
|
args.use_amp = 'native'
|
|
|
|
|
_logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
|
|
|
|
|
elif args.apex_amp:
|
|
|
|
|
args.use_amp = 'apex'
|
|
|
|
|
_logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
|
|
|
|
|
else:
|
|
|
|
|
args.use_amp = ''
|
|
|
|
|
_logger.info('Benchmarking in float32. AMP not enabled.')
|
|
|
|
|
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
|
|
|
|
|
args.precision = 'amp'
|
|
|
|
|
_logger.info(f'Benchmarking in {args.precision} precision. '
|
|
|
|
|
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
|
|
|
|
|
f'torchscript {"enabled" if args.torchscript else "disabled"}')
|
|
|
|
|
|
|
|
|
|
bench_kwargs = vars(args).copy()
|
|
|
|
|
bench_kwargs.pop('amp')
|
|
|
|
|
model = bench_kwargs.pop('model')
|
|
|
|
|
batch_size = bench_kwargs.pop('batch_size')
|
|
|
|
|
|
|
|
|
|