|
|
@ -199,7 +199,11 @@ class BenchmarkRunner:
|
|
|
|
num_classes=kwargs.pop('num_classes', None),
|
|
|
|
num_classes=kwargs.pop('num_classes', None),
|
|
|
|
in_chans=3,
|
|
|
|
in_chans=3,
|
|
|
|
global_pool=kwargs.pop('gp', 'fast'),
|
|
|
|
global_pool=kwargs.pop('gp', 'fast'),
|
|
|
|
scriptable=torchscript)
|
|
|
|
scriptable=torchscript,
|
|
|
|
|
|
|
|
drop_rate=kwargs.pop('drop', 0.),
|
|
|
|
|
|
|
|
drop_path_rate=kwargs.pop('drop_path', None),
|
|
|
|
|
|
|
|
drop_block_rate=kwargs.pop('drop_block', None),
|
|
|
|
|
|
|
|
)
|
|
|
|
self.model.to(
|
|
|
|
self.model.to(
|
|
|
|
device=self.device,
|
|
|
|
device=self.device,
|
|
|
|
dtype=self.model_dtype,
|
|
|
|
dtype=self.model_dtype,
|
|
|
|