Add --pin-mem arg to enable DataLoader pin_memory (which now shows more benefit in some scenarios). Also add --torchscript arg to validate.py for testing models with torch.jit.script.

pull/82/head
Ross Wightman 5 years ago
parent 53001dd292
commit 4666cc9aed

@ -148,6 +148,7 @@ def create_loader(
distributed=False, distributed=False,
crop_pct=None, crop_pct=None,
collate_fn=None, collate_fn=None,
pin_memory=False,
fp16=False, fp16=False,
tf_preprocessing=False, tf_preprocessing=False,
): ):
@ -183,6 +184,7 @@ def create_loader(
num_workers=num_workers, num_workers=num_workers,
sampler=sampler, sampler=sampler,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=is_training, drop_last=is_training,
) )
if use_prefetcher: if use_prefetcher:

@ -149,6 +149,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging') help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False, parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training') help='use NVIDIA amp for mixed precision training')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-prefetcher', action='store_true', default=False, parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH', parser.add_argument('--output', default='', type=str, metavar='PATH',
@ -330,6 +332,7 @@ def main():
num_workers=args.workers, num_workers=args.workers,
distributed=args.distributed, distributed=args.distributed,
collate_fn=collate_fn, collate_fn=collate_fn,
pin_memory=args.pin_mem,
) )
eval_dir = os.path.join(args.data, 'val') eval_dir = os.path.join(args.data, 'val')
@ -352,6 +355,7 @@ def main():
num_workers=args.workers, num_workers=args.workers,
distributed=args.distributed, distributed=args.distributed,
crop_pct=data_config['crop_pct'], crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
) )
if args.mixup > 0.: if args.mixup > 0.:

@ -52,12 +52,16 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool') help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False, parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher') help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--fp16', action='store_true', default=False, parser.add_argument('--fp16', action='store_true', default=False,
help='Use half precision (fp16)') help='Use half precision (fp16)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False, parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed') help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true', parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present') help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
def validate(args): def validate(args):
@ -81,6 +85,10 @@ def validate(args):
data_config = resolve_data_config(vars(args), model=model) data_config = resolve_data_config(vars(args), model=model)
model, test_time_pool = apply_test_time_pool(model, data_config, args) model, test_time_pool = apply_test_time_pool(model, data_config, args)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
if args.num_gpu > 1: if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
@ -107,6 +115,7 @@ def validate(args):
std=data_config['std'], std=data_config['std'],
num_workers=args.workers, num_workers=args.workers,
crop_pct=crop_pct, crop_pct=crop_pct,
pin_memory=args.pin_mem,
fp16=args.fp16, fp16=args.fp16,
tf_preprocessing=args.tf_preprocessing) tf_preprocessing=args.tf_preprocessing)

Loading…
Cancel
Save