Replace fp16 with amp support for validate.py script

pull/136/head
Ross Wightman 5 years ago
parent e6f24e5578
commit 02a30411ad

@ -18,6 +18,12 @@ import torch.nn as nn
import torch.nn.parallel import torch.nn.parallel
from collections import OrderedDict from collections import OrderedDict
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
@ -61,8 +67,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False, parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--fp16', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='Use half precision (fp16)') help='Use AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False, parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed') help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true', parser.add_argument('--use-ema', dest='use_ema', action='store_true',
@ -98,13 +104,13 @@ def validate(args):
torch.jit.optimized_execution(True) torch.jit.optimized_execution(True)
model = torch.jit.script(model) model = torch.jit.script(model)
if args.num_gpu > 1: if args.amp:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = amp.initialize(model.cuda(), opt_level='O1')
else: else:
model = model.cuda() model = model.cuda()
if args.fp16: if args.num_gpu > 1:
model = model.half() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda() criterion = nn.CrossEntropyLoss().cuda()
@ -127,7 +133,6 @@ def validate(args):
num_workers=args.workers, num_workers=args.workers,
crop_pct=crop_pct, crop_pct=crop_pct,
pin_memory=args.pin_mem, pin_memory=args.pin_mem,
fp16=args.fp16,
tf_preprocessing=args.tf_preprocessing) tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter() batch_time = AverageMeter()

Loading…
Cancel
Save