Add --aot-autograd (functorch efficient mem fusion) support to validate.py

pull/1327/head
Ross Wightman 2 years ago
parent 28e0152043
commit 500c190860

@ -38,6 +38,12 @@ try:
except AttributeError:
pass
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')
@ -101,8 +107,11 @@ parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
@ -162,7 +171,10 @@ def validate(args):
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size']))
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda()
if args.apex_amp:

Loading…
Cancel
Save