Add explicit half/fp16 support to loader and validation script

pull/19/head
Ross Wightman 6 years ago
parent 5684c6af32
commit 6cdf35e670

@@ -21,10 +21,15 @@ class PrefetchLoader:
rand_erase_prob=0.,
rand_erase_mode='const',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
std=IMAGENET_DEFAULT_STD,
fp16=False):
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if rand_erase_prob > 0.:
self.random_erasing = RandomErasing(
probability=rand_erase_prob, mode=rand_erase_mode)
@@ -39,7 +44,10 @@ class PrefetchLoader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
@@ -94,6 +102,7 @@ def create_loader(
distributed=False,
crop_pct=None,
collate_fn=None,
fp16=False,
tf_preprocessing=False,
):
if isinstance(input_size, tuple):
@@ -151,6 +160,7 @@ def create_loader(
rand_erase_prob=rand_erase_prob if is_training else 0.,
rand_erase_mode=rand_erase_mode,
mean=mean,
std=std)
std=std,
fp16=fp16)
return loader

@@ -156,12 +156,7 @@ def accuracy(output, target, topk=(1,)):
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res.append(correct_k.mul_(100.0 / batch_size))
return res
return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk]
def get_outdir(path, *paths, inc=False):

@@ -1 +1 @@
__version__ = '0.1.7'
__version__ = '0.1.8'

@@ -50,7 +50,11 @@ parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--fp16', action='store_true', default=False,
help='Use half precision (fp16)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
@@ -59,6 +63,7 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
# create model
model = create_model(
@@ -81,6 +86,9 @@ def validate(args):
else:
model = model.cuda()
if args.fp16:
model = model.half()
criterion = nn.CrossEntropyLoss().cuda()
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
@@ -88,12 +96,13 @@ def validate(args):
Dataset(args.data, load_bytes=args.tf_preprocessing),
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=True,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
crop_pct=crop_pct,
fp16=args.fp16,
tf_preprocessing=args.tf_preprocessing)
batch_time = AverageMeter()
@@ -105,8 +114,11 @@ def validate(args):
end = time.time()
with torch.no_grad():
for i, (input, target) in enumerate(loader):
target = target.cuda()
input = input.cuda()
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
if args.fp16:
input = input.half()
# compute output
output = model(input)
@@ -125,7 +137,7 @@ def validate(args):
if i % args.log_freq == 0:
logging.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(

Loading…
Cancel
Save