Add --fast-norm arg to benchmark.py, train.py, validate.py

pull/1415/head
Ross Wightman 2 years ago
parent 769ab4b98a
commit ff6a919cf5

@ -19,7 +19,7 @@ import torch.nn as nn
import torch.nn.parallel
from timm.data import resolve_data_config
from timm.models import create_model, is_model, list_models
from timm.models import create_model, is_model, list_models, set_fast_norm
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
help='convert model torchscript for inference')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
scripting_group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@ -598,6 +599,9 @@ def main():
model_cfgs = []
model_names = []
if args.fast_norm:
set_fast_norm()
if args.model_list:
args.model = ''
with open(args.model_list) as f:

@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .layers import set_fast_norm
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@ -33,7 +33,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
convert_splitbn_model, convert_sync_batchnorm, model_parameters
convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
@ -135,6 +135,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
@ -395,6 +397,8 @@ def main():
if args.fuser:
utils.set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
model = create_model(
args.model,

@ -20,7 +20,7 @@ import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
decay_batch_step, check_batch_size_retry
@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
@ -150,6 +152,8 @@ def validate(args):
if args.fuser:
set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
# create model
model = create_model(

Loading…
Cancel
Save