From 500c190860bb80da348dd719dd8f0b73e44f0854 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 7 Jul 2022 15:15:25 -0700 Subject: [PATCH] Add --aot-autograd (functorch efficient mem fusion) support to validate.py --- validate.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/validate.py b/validate.py index 27b88299..708ac2e5 100755 --- a/validate.py +++ b/validate.py @@ -38,6 +38,12 @@ try: except AttributeError: pass +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('validate') @@ -101,8 +107,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', - help='convert model torchscript for inference') +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)") parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', @@ -162,7 +171,10 @@ def validate(args): if args.torchscript: torch.jit.optimized_execution(True) - model = torch.jit.script(model) + model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) + if args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) model = model.cuda() if args.apex_amp: