diff --git a/benchmark.py b/benchmark.py index f348fcb9..1362eeab 100755 --- a/benchmark.py +++ b/benchmark.py @@ -225,11 +225,12 @@ class BenchmarkRunner: self.num_classes = self.model.num_classes self.param_count = count_params(self.model) _logger.info('Model %s created, param count: %d' % (model_name, self.param_count)) + + data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.scripted = False if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256)