Add drop args to benchmark.py

pull/1239/head
Ross Wightman 2 years ago
parent d829858550
commit 1c21cac8f9

@ -199,7 +199,11 @@ class BenchmarkRunner:
num_classes=kwargs.pop('num_classes', None), num_classes=kwargs.pop('num_classes', None),
in_chans=3, in_chans=3,
global_pool=kwargs.pop('gp', 'fast'), global_pool=kwargs.pop('gp', 'fast'),
scriptable=torchscript) scriptable=torchscript,
drop_rate=kwargs.pop('drop', 0.),
drop_path_rate=kwargs.pop('drop_path', None),
drop_block_rate=kwargs.pop('drop_block', None),
)
self.model.to( self.model.to(
device=self.device, device=self.device,
dtype=self.model_dtype, dtype=self.model_dtype,

Loading…
Cancel
Save