@ -71,6 +71,8 @@ parser.add_argument('--bench', default='both', type=str,
help = " Benchmark mode. One of ' inference ' , ' train ' , ' both ' . Defaults to ' both ' " )
help = " Benchmark mode. One of ' inference ' , ' train ' , ' both ' . Defaults to ' both ' " )
parser . add_argument ( ' --detail ' , action = ' store_true ' , default = False ,
parser . add_argument ( ' --detail ' , action = ' store_true ' , default = False ,
help = ' Provide train fwd/bwd/opt breakdown detail if True. Defaults to False ' )
help = ' Provide train fwd/bwd/opt breakdown detail if True. Defaults to False ' )
parser . add_argument ( ' --no-retry ' , action = ' store_true ' , default = False ,
help = ' Do not decay batch size and retry on error. ' )
parser . add_argument ( ' --results-file ' , default = ' ' , type = str , metavar = ' FILENAME ' ,
parser . add_argument ( ' --results-file ' , default = ' ' , type = str , metavar = ' FILENAME ' ,
help = ' Output csv file for validation results (summary) ' )
help = ' Output csv file for validation results (summary) ' )
parser . add_argument ( ' --num-warm-iter ' , default = 10 , type = int ,
parser . add_argument ( ' --num-warm-iter ' , default = 10 , type = int ,
@ -169,10 +171,9 @@ def resolve_precision(precision: str):
def profile_deepspeed ( model , input_size = ( 3 , 224 , 224 ) , batch_size = 1 , detailed = False ) :
def profile_deepspeed ( model , input_size = ( 3 , 224 , 224 ) , batch_size = 1 , detailed = False ) :
macs, _ = get_model_profile (
_, macs, _ = get_model_profile (
model = model ,
model = model ,
input_res = ( batch_size , ) + input_size , # input shape or input to the input_constructor
input_shape = ( batch_size , ) + input_size , # input shape/resolution
input_constructor = None , # if specified, a constructor taking input_res is used as input to the model
print_profile = detailed , # prints the model graph with the measured profile attached to each module
print_profile = detailed , # prints the model graph with the measured profile attached to each module
detailed = detailed , # print the detailed profile
detailed = detailed , # print the detailed profile
warm_up = 10 , # the number of warm-ups before measuring the time of each module
warm_up = 10 , # the number of warm-ups before measuring the time of each module
@ -197,8 +198,19 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner :
class BenchmarkRunner :
def __init__ (
def __init__ (
self , model_name , detail = False , device = ' cuda ' , torchscript = False , aot_autograd = False , precision = ' float32 ' ,
self ,
fuser = ' ' , num_warm_iter = 10 , num_bench_iter = 50 , use_train_size = False , * * kwargs ) :
model_name ,
detail = False ,
device = ' cuda ' ,
torchscript = False ,
aot_autograd = False ,
precision = ' float32 ' ,
fuser = ' ' ,
num_warm_iter = 10 ,
num_bench_iter = 50 ,
use_train_size = False ,
* * kwargs
) :
self . model_name = model_name
self . model_name = model_name
self . detail = detail
self . detail = detail
self . device = device
self . device = device
@ -256,7 +268,13 @@ class BenchmarkRunner:
class InferenceBenchmarkRunner ( BenchmarkRunner ) :
class InferenceBenchmarkRunner ( BenchmarkRunner ) :
def __init__ ( self , model_name , device = ' cuda ' , torchscript = False , * * kwargs ) :
def __init__ (
self ,
model_name ,
device = ' cuda ' ,
torchscript = False ,
* * kwargs
) :
super ( ) . __init__ ( model_name = model_name , device = device , torchscript = torchscript , * * kwargs )
super ( ) . __init__ ( model_name = model_name , device = device , torchscript = torchscript , * * kwargs )
self . model . eval ( )
self . model . eval ( )
@ -325,7 +343,13 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
class TrainBenchmarkRunner ( BenchmarkRunner ) :
class TrainBenchmarkRunner ( BenchmarkRunner ) :
def __init__ ( self , model_name , device = ' cuda ' , torchscript = False , * * kwargs ) :
def __init__ (
self ,
model_name ,
device = ' cuda ' ,
torchscript = False ,
* * kwargs
) :
super ( ) . __init__ ( model_name = model_name , device = device , torchscript = torchscript , * * kwargs )
super ( ) . __init__ ( model_name = model_name , device = device , torchscript = torchscript , * * kwargs )
self . model . train ( )
self . model . train ( )
@ -492,7 +516,7 @@ def decay_batch_exp(batch_size, factor=0.5, divisor=16):
return max ( 0 , int ( out_batch_size ) )
return max ( 0 , int ( out_batch_size ) )
def _try_run ( model_name , bench_fn , initial_batch_size, bench_kwargs ) :
def _try_run ( model_name , bench_fn , bench_kwargs, initial_batch_size , no_batch_size_retry = False ) :
batch_size = initial_batch_size
batch_size = initial_batch_size
results = dict ( )
results = dict ( )
error_str = ' Unknown '
error_str = ' Unknown '
@ -507,8 +531,11 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):
if ' channels_last ' in error_str :
if ' channels_last ' in error_str :
_logger . error ( f ' { model_name } not supported in channels_last, skipping. ' )
_logger . error ( f ' { model_name } not supported in channels_last, skipping. ' )
break
break
_logger . warning ( f ' " { error_str } " while running benchmark. Reducing batch size to { batch_size } for retry. ' )
_logger . error ( f ' " { error_str } " while running benchmark. ' )
if no_batch_size_retry :
break
batch_size = decay_batch_exp ( batch_size )
batch_size = decay_batch_exp ( batch_size )
_logger . warning ( f ' Reducing batch size to { batch_size } for retry. ' )
results [ ' error ' ] = error_str
results [ ' error ' ] = error_str
return results
return results
@ -550,7 +577,13 @@ def benchmark(args):
model_results = OrderedDict ( model = model )
model_results = OrderedDict ( model = model )
for prefix , bench_fn in zip ( prefixes , bench_fns ) :
for prefix , bench_fn in zip ( prefixes , bench_fns ) :
run_results = _try_run ( model , bench_fn , initial_batch_size = batch_size , bench_kwargs = bench_kwargs )
run_results = _try_run (
model ,
bench_fn ,
bench_kwargs = bench_kwargs ,
initial_batch_size = batch_size ,
no_batch_size_retry = args . no_retry ,
)
if prefix and ' error ' not in run_results :
if prefix and ' error ' not in run_results :
run_results = { ' _ ' . join ( [ prefix , k ] ) : v for k , v in run_results . items ( ) }
run_results = { ' _ ' . join ( [ prefix , k ] ) : v for k , v in run_results . items ( ) }
model_results . update ( run_results )
model_results . update ( run_results )