|
|
|
@ -20,7 +20,7 @@ import torch.nn.parallel
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from contextlib import suppress
|
|
|
|
|
|
|
|
|
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
|
|
|
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
|
|
|
|
|
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
|
|
|
|
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
|
|
|
|
|
decay_batch_step, check_batch_size_retry
|
|
|
|
@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
|
|
|
|
|
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
|
|
|
|
|
parser.add_argument('--fuser', default='', type=str,
|
|
|
|
|
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
|
|
|
|
parser.add_argument('--fast-norm', default=False, action='store_true',
|
|
|
|
|
help='enable experimental fast-norm')
|
|
|
|
|
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
|
|
|
|
help='Output csv file for validation results (summary)')
|
|
|
|
|
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
|
|
|
|
@ -150,6 +152,8 @@ def validate(args):
|
|
|
|
|
|
|
|
|
|
if args.fuser:
|
|
|
|
|
set_jit_fuser(args.fuser)
|
|
|
|
|
if args.fast_norm:
|
|
|
|
|
set_fast_norm()
|
|
|
|
|
|
|
|
|
|
# create model
|
|
|
|
|
model = create_model(
|
|
|
|
|