Support either deepspeed or fvcore for flop profiling

pull/933/head
Ross Wightman 3 years ago
parent 66253790d4
commit f7325c7b71

@ -18,11 +18,6 @@ from collections import OrderedDict
from contextlib import suppress from contextlib import suppress
from functools import partial from functools import partial
try:
from deepspeed.profiling.flops_profiler import get_model_profile
except ImportError as e:
get_model_profile = None
from timm.models import create_model, is_model, list_models from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2 from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config from timm.data import resolve_data_config
@ -43,6 +38,20 @@ try:
except AttributeError: except AttributeError:
pass pass
# Optional dependency: DeepSpeed's flops profiler.
# `has_deepspeed_profiling` records whether the import succeeded, so the rest
# of the script can choose a profiler without re-attempting the import.
try:
    from deepspeed.profiling.flops_profiler import get_model_profile
    has_deepspeed_profiling = True
except ImportError:
    # Keep the name bound so an unguarded reference fails with a clear
    # "NoneType is not callable" rather than a NameError.
    get_model_profile = None
    has_deepspeed_profiling = False
# Optional dependency: fvcore's flop counter.
# `has_fvcore_profiling` records whether the import succeeded.
try:
    from fvcore.nn import FlopCountAnalysis, flop_count_str
    has_fvcore_profiling = True
except ImportError:
    # Bind *both* imported names on failure so neither is left undefined.
    FlopCountAnalysis = None
    flop_count_str = None
    has_fvcore_profiling = False
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate') _logger = logging.getLogger('validate')
@ -147,9 +156,8 @@ def resolve_precision(precision: str):
return use_amp, model_dtype, data_dtype return use_amp, model_dtype, data_dtype
def profile(model, input_size=(3, 224, 224), detailed=False): def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
batch_size = 1 macs, _ = get_model_profile(
macs, params = get_model_profile(
model=model, model=model,
input_res=(batch_size,) + input_size, # input shape or input to the input_constructor input_res=(batch_size,) + input_size, # input shape or input to the input_constructor
input_constructor=None, # if specified, a constructor taking input_res is used as input to the model input_constructor=None, # if specified, a constructor taking input_res is used as input to the model
@ -159,7 +167,16 @@ def profile(model, input_size=(3, 224, 224), detailed=False):
as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k) as_string=False, # print raw numbers (e.g. 1000) or as human-readable strings (e.g. 1k)
output_file=None, # path to the output file. If None, the profiler prints to stdout. output_file=None, # path to the output file. If None, the profiler prints to stdout.
ignore_modules=None) # the list of modules to ignore in the profiling ignore_modules=None) # the list of modules to ignore in the profiling
return macs, params return macs
def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
    """Count the operations of one forward pass of `model` with fvcore.

    A tensor of ones shaped ``(batch_size, *input_size)`` is built on the same
    device and with the same dtype as the model's first parameter and handed
    to ``FlopCountAnalysis``.

    Args:
        model: Module to analyse.
        input_size: Per-sample input shape (any sequence of ints); the batch
            dimension is prepended.
        batch_size: Size of the leading batch dimension of the dummy input.
        detailed: When true, print the ``flop_count_str`` breakdown to stdout.

    Returns:
        The value of ``FlopCountAnalysis.total()`` for the dummy input.
        NOTE(review): callers in this file store it as ``macs`` -- confirm
        against fvcore's counting convention.
    """
    # Look at one reference parameter only; the default avoids StopIteration
    # on a parameter-free module, in which case torch's defaults are used.
    ref_param = next(model.parameters(), None)
    device = ref_param.device if ref_param is not None else None
    dtype = ref_param.dtype if ref_param is not None else None

    # tuple() lets callers pass input_size as a list as well as a tuple.
    example_input = torch.ones((batch_size,) + tuple(input_size), device=device, dtype=dtype)
    fca = FlopCountAnalysis(model, example_input)
    if detailed:
        print(flop_count_str(fca))
    return fca.total()
class BenchmarkRunner: class BenchmarkRunner:
@ -257,8 +274,11 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
param_count=round(self.param_count / 1e6, 2), param_count=round(self.param_count / 1e6, 2),
) )
if get_model_profile is not None: if has_deepspeed_profiling:
macs, _ = profile(self.model, self.input_size) macs = profile_deepspeed(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2)
elif has_fvcore_profiling:
macs = profile_fvcore(self.model, self.input_size)
results['gmacs'] = round(macs / 1e9, 2) results['gmacs'] = round(macs / 1e9, 2)
_logger.info( _logger.info(
@ -390,21 +410,33 @@ class TrainBenchmarkRunner(BenchmarkRunner):
class ProfileRunner(BenchmarkRunner): class ProfileRunner(BenchmarkRunner):
def __init__(self, model_name, device='cuda', **kwargs): def __init__(self, model_name, device='cuda', profiler='', **kwargs):
super().__init__(model_name=model_name, device=device, **kwargs) super().__init__(model_name=model_name, device=device, **kwargs)
if not profiler:
if has_deepspeed_profiling:
profiler = 'deepspeed'
elif has_fvcore_profiling:
profiler = 'fvcore'
assert profiler, "One of deepspeed or fvcore needs to be installed for profiling to work."
self.profiler = profiler
self.model.eval() self.model.eval()
def run(self): def run(self):
_logger.info( _logger.info(
f'Running profiler on {self.model_name} w/ ' f'Running profiler on {self.model_name} w/ '
f'input size {self.input_size} and batch size 1.') f'input size {self.input_size} and batch size {self.batch_size}.')
macs, params = profile(self.model, self.input_size, detailed=True) macs = 0
if self.profiler == 'deepspeed':
macs = profile_deepspeed(self.model, self.input_size, batch_size=self.batch_size, detailed=True)
elif self.profiler == 'fvcore':
macs = profile_fvcore(self.model, self.input_size, batch_size=self.batch_size, detailed=True)
results = dict( results = dict(
gmacs=round(macs / 1e9, 2), gmacs=round(macs / 1e9, 2),
batch_size=self.batch_size,
img_size=self.input_size[-1], img_size=self.input_size[-1],
param_count=round(params / 1e6, 2), param_count=round(self.param_count / 1e6, 2),
) )
_logger.info( _logger.info(
@ -462,9 +494,16 @@ def benchmark(args):
elif args.bench == 'train': elif args.bench == 'train':
bench_fns = TrainBenchmarkRunner, bench_fns = TrainBenchmarkRunner,
prefixes = 'train', prefixes = 'train',
elif args.bench == 'profile': elif args.bench.startswith('profile'):
assert get_model_profile is not None, "deepspeed needs to be installed for profile" # specific profiler used if included in bench mode string, otherwise default to deepspeed, fallback to fvcore
if 'deepspeed' in args.bench:
assert has_deepspeed_profiling, "deepspeed must be installed to use deepspeed flop counter"
bench_kwargs['profiler'] = 'deepspeed'
elif 'fvcore' in args.bench:
assert has_fvcore_profiling, "fvcore must be installed to use fvcore flop counter"
bench_kwargs['profiler'] = 'fvcore'
bench_fns = ProfileRunner, bench_fns = ProfileRunner,
batch_size = 1
model_results = OrderedDict(model=model) model_results = OrderedDict(model=model)
for prefix, bench_fn in zip(prefixes, bench_fns): for prefix, bench_fn in zip(prefixes, bench_fns):
@ -520,12 +559,10 @@ def main():
results = sorted(results, key=lambda x: x[sort_key], reverse=True) results = sorted(results, key=lambda x: x[sort_key], reverse=True)
if len(results): if len(results):
write_results(results_file, results) write_results(results_file, results)
import json
json_str = json.dumps(results, indent=4)
print(json_str)
else: else:
benchmark(args) results = benchmark(args)
json_str = json.dumps(results, indent=4)
print(json_str)
def write_results(results_file, results): def write_results(results_file, results):

Loading…
Cancel
Save