diff --git a/benchmark.py b/benchmark.py index e2370dcc..f348fcb9 100755 --- a/benchmark.py +++ b/benchmark.py @@ -229,14 +229,14 @@ class BenchmarkRunner: if torchscript: self.model = torch.jit.script(self.model) self.scripted = True - if aot_autograd: - assert has_functorch, "functorch is needed for --aot-autograd" - self.model = memory_efficient_fusion(self.model) - data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size) self.input_size = data_config['input_size'] self.batch_size = kwargs.pop('batch_size', 256) + if aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + self.model = memory_efficient_fusion(self.model) + self.example_inputs = None self.num_warm_iter = num_warm_iter self.num_bench_iter = num_bench_iter