Merge pull request #1294 from xwang233/add-aot-autograd

Add AOT Autograd support
Ross Wightman, committed via GitHub
commit db8e33c69f
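
This change adds an --aot-autograd option to benchmark.py and train.py. The new flag is mutually exclusive with --torchscript; when set, the model is wrapped with functorch.compile.memory_efficient_fusion, and the functorch import is guarded by a try/except so the package is only required when the option is actually used. A hypothetical invocation (model name and flag combination are illustrative, not taken from this PR):

    python benchmark.py --model resnet50 --aot-autograd --fuser nvfuser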

benchmark.py

@@ -51,6 +51,12 @@ except ImportError as e:
     FlopCountAnalysis = None
     has_fvcore_profiling = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
@@ -95,10 +101,13 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
-parser.add_argument('--torchscript', dest='torchscript', action='store_true',
-                    help='convert model torchscript for inference')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+scripting_group = parser.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
+                             help='convert model torchscript for inference')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 
 # train optimizer parameters
@@ -188,7 +197,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
+            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
             fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
@@ -220,11 +229,14 @@ class BenchmarkRunner:
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)
+        if aot_autograd:
+            assert has_functorch, "functorch is needed for --aot-autograd"
+            self.model = memory_efficient_fusion(self.model)
 
         self.example_inputs = None
         self.num_warm_iter = num_warm_iter
         self.num_bench_iter = num_bench_iter
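
For reference, a minimal standalone sketch (not part of this PR) of what the --aot-autograd path does to the model: functorch's memory_efficient_fusion wraps an nn.Module so that its forward and backward graphs are extracted ahead of time and handed to the active fuser. The toy model below is an assumption for illustration; it presumes functorch is installed and a CUDA device is available.

import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

# stand-in model; BenchmarkRunner would pass a timm model here
model = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 10)).cuda()

fused = memory_efficient_fusion(model)  # same call the patch makes on self.model

x = torch.randn(8, 64, device='cuda', requires_grad=True)
out = fused(x)            # forward/backward graphs are traced and compiled on first call
out.sum().backward()      # backward runs through the ahead-of-time compiled graph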

train.py

@@ -61,6 +61,13 @@ try:
 except ImportError:
     has_wandb = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
@@ -123,8 +130,11 @@ group.add_argument('-vb', '--validation-batch-size', type=int, default=None, met
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
-group.add_argument('--torchscript', dest='torchscript', action='store_true',
+scripting_group = group.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                             help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
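
The scripting_group above relies on argparse's mutually exclusive groups, which is why --torchscript and --aot-autograd cannot be passed together. A minimal standalone sketch of that pattern (illustrative, not from the PR):

import argparse

parser = argparse.ArgumentParser()
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', action='store_true')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true')

args = parser.parse_args(['--aot-autograd'])  # ok: args.aot_autograd is True
# parser.parse_args(['--torchscript', '--aot-autograd'])  # argparse error: options are mutually exclusive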
@@ -445,6 +455,9 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    if args.aot_autograd:
+        assert has_functorch, "functorch is needed for --aot-autograd"
+        model = memory_efficient_fusion(model)
 
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
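
In train.py the wrapping happens right before the optimizer is created, so create_optimizer_v2 receives the fused module; since --torchscript and --aot-autograd are mutually exclusive, only one of the two transforms is ever applied. A hypothetical training invocation (dataset path and model name are placeholders, not from the PR):

    python train.py /path/to/imagenet --model resnet50 --aot-autograd --fuser nvfuser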
