diff --git a/validate.py b/validate.py index fd55d408..a4d41868 100755 --- a/validate.py +++ b/validate.py @@ -181,7 +181,7 @@ def validate(args): if args.torchscript: torch.jit.optimized_execution(True) - model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) + model = torch.jit.script(model) if args.aot_autograd: assert has_functorch, "functorch is needed for --aot-autograd" model = memory_efficient_fusion(model)