From 0356e773f5405eea1032e5c4c0be528128e5684e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 10 Feb 2021 14:31:18 -0800
Subject: [PATCH] Default to native PyTorch AMP instead of APEX amp.

Too many APEX issues cropping up lately.
---
 timm/models/helpers.py |  2 +-
 train.py               |  8 ++++----
 validate.py            | 13 +++++++++----
 3 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 33744eb5..d9b501da 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -177,7 +177,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
     if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
     state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
diff --git a/train.py b/train.py
index f0fcd2af..2f9c3744 100755
--- a/train.py
+++ b/train.py
@@ -310,11 +310,11 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
+        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
+        if has_native_amp:
             args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
     if args.apex_amp and has_apex:
         use_amp = 'apex'
     elif args.native_amp and has_native_amp:
diff --git a/validate.py b/validate.py
index 83f66fa5..ca69df08 100755
--- a/validate.py
+++ b/validate.py
@@ -116,15 +116,20 @@ def validate(args):
     args.prefetcher = not args.no_prefetcher
     amp_autocast = suppress  # do nothing
     if args.amp:
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
+        if has_native_amp:
             args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
         else:
-            _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
+            _logger.warning("Neither APEX or Native Torch AMP is available.")
     assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
     if args.native_amp:
         amp_autocast = torch.cuda.amp.autocast
+        _logger.info('Validating in mixed precision with native PyTorch AMP.')
+    elif args.apex_amp:
+        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
+    else:
+        _logger.info('Validating in float32. AMP not enabled.')
 
     if args.legacy_jit:
         set_jit_legacy()
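
Notes:

The flipped preference in both train.py and validate.py hinges on the
has_apex / has_native_amp probes defined near the top of those files. A
minimal sketch of that detection pattern, assuming the try/except feature
probing timm uses (shown here for context only, not part of this patch):

    import torch

    # APEX is an external NVIDIA dependency; absence just disables that path.
    try:
        from apex import amp
        has_apex = True
    except ImportError:
        has_apex = False

    # Native AMP landed in torch.cuda.amp as of PyTorch 1.6; probe for the
    # attribute rather than pinning a version number.
    has_native_amp = False
    try:
        if getattr(torch.cuda.amp, 'autocast') is not None:
            has_native_amp = True
    except AttributeError:
        pass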
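
In validate.py, amp_autocast is used as a context-manager factory, which is
why contextlib.suppress (called with no arguments) works as the do-nothing
default seen in the hunk above. A sketch of that usage pattern; model and
input are illustrative placeholders, not names from this patch:

    import torch
    from contextlib import suppress

    model = torch.nn.Linear(8, 2).eval()    # placeholder model
    input = torch.randn(4, 8)               # placeholder batch

    native_amp = torch.cuda.is_available()  # stand-in for the resolved args.native_amp
    amp_autocast = suppress                 # suppress() == no-op context manager
    if native_amp:
        amp_autocast = torch.cuda.amp.autocast

    # Same call shape whether AMP is on or off.
    with torch.no_grad(), amp_autocast():
        output = model(input)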
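
On the training side, defaulting to 'native' means a step pairs autocast with
torch.cuda.amp.GradScaler rather than APEX's amp.scale_loss. A hedged sketch
of one native-AMP train step under that assumption; model, loader, loss_fn,
and optimizer below are placeholders, not code from train.py:

    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(8, 2).to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(4, 8, device=device),
               torch.randint(0, 2, (4,), device=device))]

    # enabled=False makes both autocast and GradScaler transparent no-ops,
    # so the same loop runs in float32 on CPU-only machines.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == 'cuda'))

    for input, target in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast(enabled=(device == 'cuda')):
            loss = loss_fn(model(input), target)
        scaler.scale(loss).backward()  # scale loss to avoid fp16 gradient underflow
        scaler.step(optimizer)         # unscales grads; skips step on inf/nan
        scaler.update()                # adjusts the scale factor for the next step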