From ff6a919cf5f0a325236cf57c07548f779123173f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 25 Aug 2022 17:00:54 -0700
Subject: [PATCH] Add --fast-norm arg to benchmark.py, train.py, validate.py

---
 benchmark.py            | 8 ++++++--
 timm/models/__init__.py | 1 +
 train.py                | 6 +++++-
 validate.py             | 6 +++++-
 4 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/benchmark.py b/benchmark.py
index 4679a009..4a89441b 100755
--- a/benchmark.py
+++ b/benchmark.py
@@ -19,7 +19,7 @@ import torch.nn as nn
 import torch.nn.parallel
 
 from timm.data import resolve_data_config
-from timm.models import create_model, is_model, list_models
+from timm.models import create_model, is_model, list_models, set_fast_norm
 from timm.optim import create_optimizer_v2
 from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
 
@@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                              help='convert model torchscript for inference')
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
-
+scripting_group.add_argument('--fast-norm', default=False, action='store_true',
+                             help='enable experimental fast-norm')
 
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -598,6 +599,9 @@ def main():
     model_cfgs = []
     model_names = []
 
+    if args.fast_norm:
+        set_fast_norm()
+
     if args.model_list:
         args.model = ''
         with open(args.model_list) as f:
diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 51a38d0c..5ff79595 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
 from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
+from .layers import set_fast_norm
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
diff --git a/train.py b/train.py
index e5d40566..ee137217 100755
--- a/train.py
+++ b/train.py
@@ -33,7 +33,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
 from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
 from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
-    convert_splitbn_model, convert_sync_batchnorm, model_parameters
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -135,6 +135,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+group.add_argument('--fast-norm', default=False, action='store_true',
+                   help='enable experimental fast-norm')
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
                    help='Enable gradient checkpointing through model blocks/stages')
 
@@ -395,6 +397,8 @@ def main():
 
     if args.fuser:
         utils.set_jit_fuser(args.fuser)
+    if args.fast_norm:
+        set_fast_norm()
 
     model = create_model(
         args.model,
diff --git a/validate.py b/validate.py
index a4d41868..6244f052 100755
--- a/validate.py
+++ b/validate.py
@@ -20,7 +20,7 @@ import torch.nn.parallel
 from collections import OrderedDict
 from contextlib import suppress
 
-from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
+from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
     decay_batch_step, check_batch_size_retry
@@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--fast-norm', default=False, action='store_true',
+                    help='enable experimental fast-norm')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@@ -150,6 +152,8 @@ def validate(args):
 
     if args.fuser:
         set_jit_fuser(args.fuser)
+    if args.fast_norm:
+        set_fast_norm()
 
     # create model
     model = create_model(
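
---

Not part of the patch itself: a minimal usage sketch of what the new flag wires up. The dataset path and model name below are placeholders, and the low-level behavior of fast-norm lives in `timm.models.layers` and may vary by version; the only thing taken from the patch is that the scripts call `set_fast_norm()` before `create_model()` when `--fast-norm` is passed.

    # CLI (placeholder path/model):
    #   python validate.py /path/to/imagenet --model vit_base_patch16_224 --fast-norm
    from timm.models import create_model, set_fast_norm

    set_fast_norm()  # flip the global fast-norm flag first, mirroring the scripts' ordering
    model = create_model('vit_base_patch16_224', pretrained=False)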