diff --git a/benchmark.py b/benchmark.py index ccd9b4fa..f1604a04 100755 --- a/benchmark.py +++ b/benchmark.py @@ -199,7 +199,11 @@ class BenchmarkRunner: num_classes=kwargs.pop('num_classes', None), in_chans=3, global_pool=kwargs.pop('gp', 'fast'), - scriptable=torchscript) + scriptable=torchscript, + drop_rate=kwargs.pop('drop', 0.), + drop_path_rate=kwargs.pop('drop_path', None), + drop_block_rate=kwargs.pop('drop_block', None), + ) self.model.to( device=self.device, dtype=self.model_dtype,