Don't run profile if model is torchscripted

more_datasets
Ross Wightman 3 years ago
parent 7da1b0b61c
commit 71f00bfe9e

@ -205,8 +205,10 @@ class BenchmarkRunner:
self.num_classes = self.model.num_classes
self.param_count = count_params(self.model)
_logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
self.scripted = False
if torchscript:
self.model = torch.jit.script(self.model)
self.scripted = True
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size']
@ -275,14 +277,14 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
img_size=self.input_size[-1],
param_count=round(self.param_count / 1e6, 2),
)
if has_deepspeed_profiling:
macs, _ = profile_deepspeed(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
elif has_fvcore_profiling:
macs, activations = profile_fvcore(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
results['macts'] = round(activations / 1e6, 2)
if not self.scripted:
if has_deepspeed_profiling:
macs, _ = profile_deepspeed(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
elif has_fvcore_profiling:
macs, activations = profile_fvcore(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
results['macts'] = round(activations / 1e6, 2)
_logger.info(
f"Inference benchmark of {self.model_name} done. "

Loading…
Cancel
Save