Move aot-autograd opt after model metadata used to setup data config in benchmark.py

pull/1294/head
Ross Wightman 2 years ago
parent ca991c1fa5
commit 2d7ab06503

@ -229,14 +229,14 @@ class BenchmarkRunner:
if torchscript:
self.model = torch.jit.script(self.model)
self.scripted = True
if aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
self.model = memory_efficient_fusion(self.model)
data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
self.input_size = data_config['input_size']
self.batch_size = kwargs.pop('batch_size', 256)
if aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
self.model = memory_efficient_fusion(self.model)
self.example_inputs = None
self.num_warm_iter = num_warm_iter
self.num_bench_iter = num_bench_iter

Loading…
Cancel
Save