From 4666cc9aed91fc674b5b85ad5b85e3876abc2885 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 2 Jan 2020 16:22:06 -0800 Subject: [PATCH] Add --pin-mem arg to enable dataloader pin_memory (showing more benefit in some scenarios now), also add --torchscript arg to validate.py for testing models with jit.script --- timm/data/loader.py | 2 ++ train.py | 4 ++++ validate.py | 9 +++++++++ 3 files changed, 15 insertions(+) diff --git a/timm/data/loader.py b/timm/data/loader.py index 8c27f1bb..bbb71eca 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -148,6 +148,7 @@ def create_loader( distributed=False, crop_pct=None, collate_fn=None, + pin_memory=False, fp16=False, tf_preprocessing=False, ): @@ -183,6 +184,7 @@ def create_loader( num_workers=num_workers, sampler=sampler, collate_fn=collate_fn, + pin_memory=pin_memory, drop_last=is_training, ) if use_prefetcher: diff --git a/train.py b/train.py index a47f1b4d..b8f37f41 100644 --- a/train.py +++ b/train.py @@ -149,6 +149,8 @@ parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA amp for mixed precision training') +parser.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', @@ -330,6 +332,7 @@ def main(): num_workers=args.workers, distributed=args.distributed, collate_fn=collate_fn, + pin_memory=args.pin_mem, ) eval_dir = os.path.join(args.data, 'val') @@ -352,6 +355,7 @@ def main(): num_workers=args.workers, distributed=args.distributed, crop_pct=data_config['crop_pct'], + pin_memory=args.pin_mem, ) if args.mixup > 0.: diff --git a/validate.py b/validate.py index 21dfcf89..004393ab 100644 --- a/validate.py +++ b/validate.py @@ -52,12 +52,16 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', help='disable test time pool') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') +parser.add_argument('--pin-mem', action='store_true', default=False, + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') parser.add_argument('--fp16', action='store_true', default=False, help='Use half precision (fp16)') parser.add_argument('--tf-preprocessing', action='store_true', default=False, help='Use Tensorflow preprocessing pipeline (require CPU TF installed') parser.add_argument('--use-ema', dest='use_ema', action='store_true', help='use ema version of weights if present') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') def validate(args): @@ -81,6 +85,10 @@ def validate(args): data_config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, data_config, args) + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: @@ -107,6 +115,7 @@ def validate(args): std=data_config['std'], num_workers=args.workers, crop_pct=crop_pct, + pin_memory=args.pin_mem, fp16=args.fp16, tf_preprocessing=args.tf_preprocessing)