Default to native PyTorch AMP instead of APEX amp. Too many APEX issues cropping up lately.

pull/419/head
Ross Wightman 4 years ago
parent b4e216e377
commit 0356e773f5
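
For readers unfamiliar with the two backends: native AMP is the `torch.cuda.amp` API built into PyTorch 1.6+, while APEX AMP is NVIDIA's out-of-tree `apex.amp` package. A minimal sketch of a native-AMP training step (illustrative only, not code from this repo; model, optimizer, and data are stand-ins):

    import torch
    import torch.nn as nn

    # Illustrative native-AMP training step on stand-in model/data.
    device = 'cuda'
    model = nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 gradient underflow

    input = torch.randn(4, 8, device=device)
    target = torch.randint(0, 2, (4,), device=device)

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():  # forward pass runs in mixed precision
        loss = loss_fn(model(input), target)
    scaler.scale(loss).backward()    # backward on the scaled loss
    scaler.step(optimizer)           # unscales grads, then optimizer step
    scaler.update()                  # adjusts the scale factor for the next step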

@@ -177,7 +177,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
     if cfg is None or 'url' not in cfg or not cfg['url']:
-        _logger.warning("Pretrained model URL does not exist, using random initialization.")
+        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
     state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
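
The guard above can be exercised in isolation; a hedged sketch (the function name and bare `model` argument are hypothetical stand-ins for illustration):

    import logging

    _logger = logging.getLogger(__name__)

    def load_pretrained_sketch(model, cfg=None):
        # Mirrors the hunk above: with no usable 'url' in the config,
        # warn and return, leaving the model randomly initialized.
        if cfg is None:
            cfg = getattr(model, 'default_cfg', None)
        if cfg is None or 'url' not in cfg or not cfg['url']:
            _logger.warning("No pretrained weights exist for this model. Using random initialization.")
            return
        # ... loading via load_state_dict_from_url would follow here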

@@ -310,11 +310,11 @@ def main():
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     if args.amp:
-        # for backwards compat, `--amp` arg tries apex before native amp
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
+        # `--amp` chooses native amp before apex (APEX ver not actively maintained)
+        if has_native_amp:
+            args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
     if args.apex_amp and has_apex:
         use_amp = 'apex'
     elif args.native_amp and has_native_amp:
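
The `has_native_amp` / `has_apex` flags referenced here are import-time feature checks; a sketch of the typical detection (exact form in the repo may differ):

    import torch

    # APEX AMP: optional NVIDIA package, may not be installed.
    try:
        from apex import amp  # noqa: F401
        has_apex = True
    except ImportError:
        has_apex = False

    # Native AMP: torch.cuda.amp.autocast ships with PyTorch 1.6+.
    has_native_amp = hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast')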

@@ -116,15 +116,20 @@ def validate(args):
     args.prefetcher = not args.no_prefetcher
     amp_autocast = suppress  # do nothing
     if args.amp:
-        if has_apex:
-            args.apex_amp = True
-        elif has_native_amp:
-            args.native_amp = True
+        if has_native_amp:
+            args.native_amp = True
+        elif has_apex:
+            args.apex_amp = True
         else:
-            _logger.warning("Neither APEX or Native Torch AMP is available, using FP32.")
+            _logger.warning("Neither APEX or Native Torch AMP is available.")
     assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
     if args.native_amp:
         amp_autocast = torch.cuda.amp.autocast
+        _logger.info('Validating in mixed precision with native PyTorch AMP.')
+    elif args.apex_amp:
+        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
+    else:
+        _logger.info('Validating in float32. AMP not enabled.')
 
     if args.legacy_jit:
         set_jit_legacy()
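
Downstream of this hunk, `amp_autocast` wraps the forward pass as a context manager, which is why `contextlib.suppress` works as the FP32 no-op default; a sketch with stand-in model and data:

    from contextlib import suppress
    import torch

    # Either a real autocast context or a no-op, chosen once up front.
    amp_autocast = torch.cuda.amp.autocast if torch.cuda.is_available() else suppress

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = torch.nn.Linear(8, 2).to(device).eval()
    input = torch.randn(4, 8, device=device)
    with torch.no_grad():
        with amp_autocast():  # no-op when AMP is unavailable
            output = model(input)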
