diff --git a/timm/data/loader.py b/timm/data/loader.py index 6a19b805..6b6e2b39 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -21,10 +21,15 @@ class PrefetchLoader: rand_erase_prob=0., rand_erase_mode='const', mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD): + std=IMAGENET_DEFAULT_STD, + fp16=False): self.loader = loader self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + self.fp16 = fp16 + if fp16: + self.mean = self.mean.half() + self.std = self.std.half() if rand_erase_prob > 0.: self.random_erasing = RandomErasing( probability=rand_erase_prob, mode=rand_erase_mode) @@ -39,7 +44,10 @@ class PrefetchLoader: with torch.cuda.stream(stream): next_input = next_input.cuda(non_blocking=True) next_target = next_target.cuda(non_blocking=True) - next_input = next_input.float().sub_(self.mean).div_(self.std) + if self.fp16: + next_input = next_input.half().sub_(self.mean).div_(self.std) + else: + next_input = next_input.float().sub_(self.mean).div_(self.std) if self.random_erasing is not None: next_input = self.random_erasing(next_input) @@ -94,6 +102,7 @@ def create_loader( distributed=False, crop_pct=None, collate_fn=None, + fp16=False, tf_preprocessing=False, ): if isinstance(input_size, tuple): @@ -151,6 +160,7 @@ def create_loader( rand_erase_prob=rand_erase_prob if is_training else 0., rand_erase_mode=rand_erase_mode, mean=mean, - std=std) + std=std, + fp16=fp16) return loader diff --git a/timm/utils.py b/timm/utils.py index 36355c2b..7de38a80 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -156,12 +156,7 @@ def accuracy(output, target, topk=(1,)): _, pred = output.topk(maxk, 1, True, True) pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res + return [correct[:k].view(-1).float().sum(0) * 100. / batch_size for k in topk] def get_outdir(path, *paths, inc=False): diff --git a/timm/version.py b/timm/version.py index 124e4620..c3bb2961 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.1.7' +__version__ = '0.1.8' diff --git a/validate.py b/validate.py index a859a098..d90d54fd 100644 --- a/validate.py +++ b/validate.py @@ -50,7 +50,11 @@ parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', help='disable test time pool') -parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', +parser.add_argument('--no-prefetcher', action='store_true', default=False, + help='disable fast prefetcher') +parser.add_argument('--fp16', action='store_true', default=False, + help='Use half precision (fp16)') +parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') @@ -59,6 +63,7 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true', def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint + args.prefetcher = not args.no_prefetcher # create model model = create_model( @@ -81,6 +86,9 @@ def validate(args): else: model = model.cuda() + if args.fp16: + model = model.half() + criterion = nn.CrossEntropyLoss().cuda() crop_pct = 1.0 if test_time_pool else data_config['crop_pct'] @@ -88,12 +96,13 @@ def validate(args): Dataset(args.data, load_bytes=args.tf_preprocessing), input_size=data_config['input_size'], batch_size=args.batch_size, - use_prefetcher=True, + use_prefetcher=args.prefetcher, interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], num_workers=args.workers, crop_pct=crop_pct, + fp16=args.fp16, tf_preprocessing=args.tf_preprocessing) batch_time = AverageMeter() @@ -105,8 +114,11 @@ def validate(args): end = time.time() with torch.no_grad(): for i, (input, target) in enumerate(loader): - target = target.cuda() - input = input.cuda() + if args.no_prefetcher: + target = target.cuda() + input = input.cuda() + if args.fp16: + input = input.half() # compute output output = model(input) @@ -125,7 +137,7 @@ def validate(args): if i % args.log_freq == 0: logging.info( 'Test: [{0:>4d}/{1}] ' - 'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) ' + 'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) ' 'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) ' 'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) ' 'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(