Add --fuser arg to train/validate/benchmark scripts to select jit fuser type

pull/1094/head
Ross Wightman 2 years ago
parent 010b486590
commit f0f9eccda8

benchmark.py

@@ -21,7 +21,7 @@ from functools import partial
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.data import resolve_data_config
-from timm.utils import AverageMeter, setup_default_logging
+from timm.utils import setup_default_logging, set_jit_fuser
has_apex = False
@@ -95,7 +95,8 @@ parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
# train optimizer parameters
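
The wiring is the same in all three scripts: the flag defaults to an empty string, and the fuser is only switched when a non-empty value is given. A minimal standalone sketch of the shared pattern (assembled for illustration, not verbatim from any one script):

    import argparse
    from timm.utils import set_jit_fuser

    parser = argparse.ArgumentParser()
    parser.add_argument('--fuser', default='', type=str,
                        help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
    args = parser.parse_args()

    if args.fuser:  # empty default leaves PyTorch's fuser selection untouched
        set_jit_fuser(args.fuser)
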
@@ -186,7 +187,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
class BenchmarkRunner:
def __init__(
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
-            num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
self.model_name = model_name
self.detail = detail
self.device = device
@@ -194,6 +195,8 @@ class BenchmarkRunner:
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
+        if fuser:
+            set_jit_fuser(fuser)
self.model = create_model(
model_name,
num_classes=kwargs.pop('num_classes', None),
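
Note the ordering here: BenchmarkRunner switches the fuser before create_model runs, so any subsequent torchscript conversion of the model is compiled under the selected fuser.
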

timm/utils/__init__.py

@@ -3,7 +3,7 @@ from .checkpoint_saver import CheckpointSaver
from .clip_grad import dispatch_clip_grad
from .cuda import ApexScaler, NativeScaler
from .distributed import distribute_bn, reduce_tensor
-from .jit import set_jit_legacy
+from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
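
With this re-export, callers can pull the helper from the package root (from timm.utils import set_jit_fuser) rather than reaching into timm.utils.jit, matching how set_jit_legacy was already exposed.
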

timm/utils/jit.py

@@ -2,6 +2,8 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
+import os
import torch
@@ -16,3 +18,33 @@ def set_jit_legacy():
torch._C._jit_set_profiling_mode(False)
torch._C._jit_override_can_fuse_on_gpu(True)
#torch._C._jit_set_texpr_fuser_enabled(True)
+
+
+def set_jit_fuser(fuser):
+    if fuser == "te":
+        # default fuser should be == 'te'
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+    elif fuser == "old" or fuser == "legacy":
+        # restore the legacy (pre-profiling-executor) GPU fuser
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+    elif fuser == "nvfuser" or fuser == "nvf":
+        # configure nvfuser via env vars before enabling it
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        # note: the next two calls only query fuser state; their results are unused
+        torch._C._jit_can_fuse_on_cpu()
+        torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_nvfuser_guard_mode(True)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        assert False, f"Invalid jit fuser ({fuser})"
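
For reference, a minimal usage sketch (assumes a CUDA-enabled PyTorch build that exposes these internal torch._C hooks, and a timm install that includes this commit): select the fuser first, then script the model so fusion kicks in during the profiled warmup runs.

    import torch
    from timm.models import create_model
    from timm.utils import set_jit_fuser

    set_jit_fuser('te')  # or 'old' / 'nvfuser'
    model = create_model('resnet50').cuda().eval()
    model = torch.jit.script(model)
    with torch.no_grad():
        for _ in range(3):  # profiling executor fuses after a few runs
            model(torch.randn(2, 3, 224, 224, device='cuda'))
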

train.py

@@ -295,6 +295,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
@@ -364,6 +366,9 @@ def main():
random_seed(args.seed, args.rank)
+    if args.fuser:
+        set_jit_fuser(args.fuser)
model = create_model(
args.model,
pretrained=args.pretrained,

validate.py

@@ -21,7 +21,7 @@ from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
has_apex = False
try:
@@ -102,8 +102,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
-parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
-                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -133,8 +133,8 @@ def validate(args):
else:
_logger.info('Validating in float32. AMP not enabled.')
-    if args.legacy_jit:
-        set_jit_legacy()
+    if args.fuser:
+        set_jit_fuser(args.fuser)
# create model
model = create_model(
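
Note the behavior change in validate.py: the old --legacy-jit flag is gone, and its effect is now reached with --fuser old (or --fuser legacy), which routes through the same set_jit_fuser helper as the other fuser choices.
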
