|
|
|
@ -18,6 +18,12 @@ import torch.nn as nn
|
|
|
|
|
import torch.nn.parallel
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from apex import amp
|
|
|
|
|
has_apex = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
has_apex = False
|
|
|
|
|
|
|
|
|
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
|
|
|
|
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
|
|
|
|
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
|
|
|
@ -61,8 +67,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
|
parser.add_argument('--pin-mem', action='store_true', default=False,
|
|
|
|
|
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
|
|
|
|
|
parser.add_argument('--fp16', action='store_true', default=False,
|
|
|
|
|
help='Use half precision (fp16)')
|
|
|
|
|
parser.add_argument('--amp', action='store_true', default=False,
|
|
|
|
|
help='Use AMP mixed precision')
|
|
|
|
|
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
|
|
|
|
|
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
|
|
|
|
|
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|
|
|
@ -98,13 +104,13 @@ def validate(args):
|
|
|
|
|
torch.jit.optimized_execution(True)
|
|
|
|
|
model = torch.jit.script(model)
|
|
|
|
|
|
|
|
|
|
if args.num_gpu > 1:
|
|
|
|
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
|
|
|
|
if args.amp:
|
|
|
|
|
model = amp.initialize(model.cuda(), opt_level='O1')
|
|
|
|
|
else:
|
|
|
|
|
model = model.cuda()
|
|
|
|
|
|
|
|
|
|
if args.fp16:
|
|
|
|
|
model = model.half()
|
|
|
|
|
if args.num_gpu > 1:
|
|
|
|
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
|
|
|
|
|
|
|
|
|
|
criterion = nn.CrossEntropyLoss().cuda()
|
|
|
|
|
|
|
|
|
@ -127,7 +133,6 @@ def validate(args):
|
|
|
|
|
num_workers=args.workers,
|
|
|
|
|
crop_pct=crop_pct,
|
|
|
|
|
pin_memory=args.pin_mem,
|
|
|
|
|
fp16=args.fp16,
|
|
|
|
|
tf_preprocessing=args.tf_preprocessing)
|
|
|
|
|
|
|
|
|
|
batch_time = AverageMeter()
|
|
|
|
|