From f0f9eccda8dcf6fb546e762c544459a2771606c2 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 17 Jan 2022 13:54:25 -0800
Subject: [PATCH] Add --fuser arg to train/validate/benchmark scripts to select jit fuser type

---
 benchmark.py           |  9 ++++++---
 timm/utils/__init__.py |  2 +-
 timm/utils/jit.py      | 32 ++++++++++++++++++++++++++++++++
 train.py               |  5 +++++
 validate.py            | 10 +++++-----
 5 files changed, 49 insertions(+), 9 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index ccd9b4fa..17c095a8 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -21,7 +21,7 @@ from functools import partial
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
 from timm.data import resolve_data_config
-from timm.utils import AverageMeter, setup_default_logging
+from timm.utils import setup_default_logging, set_jit_fuser
 
 
 has_apex = False
@@ -95,7 +95,8 @@ parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
-
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 
 
 # train optimizer parameters
@@ -186,7 +187,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 class BenchmarkRunner:
     def __init__(
             self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
-            num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device
@@ -194,6 +195,8 @@ class BenchmarkRunner:
         self.channels_last = kwargs.pop('channels_last', False)
         self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
 
+        if fuser:
+            set_jit_fuser(fuser)
         self.model = create_model(
             model_name,
             num_classes=kwargs.pop('num_classes', None),
diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py
index 11de9c9c..b8cef321 100644
--- a/timm/utils/__init__.py
+++ b/timm/utils/__init__.py
@@ -3,7 +3,7 @@ from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
-from .jit import set_jit_legacy
+from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg
diff --git a/timm/utils/jit.py b/timm/utils/jit.py
index 185ab7a0..6039823f 100644
--- a/timm/utils/jit.py
+++ b/timm/utils/jit.py
@@ -2,6 +2,8 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import os
+
 import torch
 
 
@@ -16,3 +18,33 @@ def set_jit_legacy():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_override_can_fuse_on_gpu(True)
     #torch._C._jit_set_texpr_fuser_enabled(True)
+
+
+def set_jit_fuser(fuser):
+    if fuser == "te":
+        # default fuser should be == 'te'
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+    elif fuser == "old" or fuser == "legacy":
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+    elif fuser == "nvfuser" or fuser == "nvf":
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_can_fuse_on_cpu()
+        torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_nvfuser_guard_mode(True)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        assert False, f"Invalid jit fuser ({fuser})"
diff --git a/train.py b/train.py
index 6e3b058b..849f40e3 100755
--- a/train.py
+++ b/train.py
@@ -295,6 +295,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--log-wandb', action='store_true', default=False,
                     help='log training and validation metrics to wandb')
 
@@ -364,6 +366,9 @@ def main():
 
     random_seed(args.seed, args.rank)
 
+    if args.fuser:
+        set_jit_fuser(args.fuser)
+
     model = create_model(
         args.model,
         pretrained=args.pretrained,
diff --git a/validate.py b/validate.py
index d69d076f..bbb1e8dc 100755
--- a/validate.py
+++ b/validate.py
@@ -21,7 +21,7 @@ from contextlib import suppress
 
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
 
 has_apex = False
 try:
@@ -102,8 +102,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
-parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
-                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -133,8 +133,8 @@ def validate(args):
     else:
         _logger.info('Validating in float32. AMP not enabled.')
 
-    if args.legacy_jit:
-        set_jit_legacy()
+    if args.fuser:
+        set_jit_fuser(args.fuser)
 
     # create model
     model = create_model(
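
For orientation, a minimal sketch (not part of the commit) of how the exported helper composes with TorchScript. `set_jit_fuser` and `create_model` are the real exports touched above; the model name is illustrative, and the 'nvfuser' path assumes a CUDA-enabled PyTorch build of this era (~1.10/1.11) where nvFuser is available:

    import torch
    from timm.models import create_model
    from timm.utils import set_jit_fuser

    # Select the fuser before any scripting or profiling runs happen;
    # the torch._C toggles above are process-global JIT state.
    set_jit_fuser('nvfuser')  # or 'te' / 'old'; the scripts skip the call when --fuser is ''

    model = create_model('resnet50').cuda().eval()
    scripted = torch.jit.script(model)
    # Fused kernels are only generated after the initial profiling runs,
    # so time the model after a few warmup iterations.

The command-line equivalent, with a hypothetical dataset path:

    python validate.py /path/to/imagenet --model resnet50 --torchscript --fuser nvfuser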