From 02a30411ad516b559635216e38bdf1fc2dc1d31f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 27 Apr 2020 12:14:40 -0700 Subject: [PATCH] Replace fp16 with amp support for validate.py script --- validate.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/validate.py b/validate.py index 34ce95c0..f8ac7c55 100755 --- a/validate.py +++ b/validate.py @@ -18,6 +18,12 @@ import torch.nn as nn import torch.nn.parallel from collections import OrderedDict +try: + from apex import amp + has_apex = True +except ImportError: + has_apex = False + from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging @@ -61,8 +67,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') -parser.add_argument('--fp16', action='store_true', default=False, - help='Use half precision (fp16)') +parser.add_argument('--amp', action='store_true', default=False, + help='Use AMP mixed precision') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', @@ -98,13 +104,13 @@ def validate(args): torch.jit.optimized_execution(True) model = torch.jit.script(model) - if args.num_gpu > 1: - model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + if args.amp: + model = amp.initialize(model.cuda(), opt_level='O1') else: model = model.cuda() - if args.fp16: - model = model.half() + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) criterion = nn.CrossEntropyLoss().cuda() @@ -127,7 +133,6 @@ def validate(args): num_workers=args.workers, crop_pct=crop_pct, pin_memory=args.pin_mem, - fp16=args.fp16, tf_preprocessing=args.tf_preprocessing) batch_time = AverageMeter()