Add force-cpu flag for train/validate, fix CPU fallback for device init, remove old force cpu flag for EMA model weights

pull/1239/head
Ross Wightman 3 years ago
parent 2ee398d501
commit f2e14685a8

@ -20,8 +20,12 @@ def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
elif is_cuda_available():
denv = DeviceEnvCuda(**kwargs)
# CPU fallback
if denv is None:
denv = DeviceEnv()
if is_xla_available('CPU'):
denv = DeviceEnvXla(device_type='CPU', **kwargs)
else:
denv = DeviceEnv()
_logger.info(f'Initialized device {denv.device}. '
f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.')

@ -60,8 +60,8 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50"')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
@ -215,8 +215,6 @@ parser.add_argument('--split-bn', action='store_true',
# Model Exponential Moving Average
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
@ -252,6 +250,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
help='use the multi-epochs-loader to save time at the beginning of every epoch')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--force-cpu', action='store_true', default=False,
help='Force CPU to be used even if HW accelerator exists.')
parser.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
@ -277,7 +277,7 @@ def main():
setup_default_logging()
args, args_text = _parse_args()
dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last)
dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp, channels_last=args.channels_last)
if dev_env.distributed:
_logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.'
% (dev_env.global_rank, dev_env.world_size))
@ -364,7 +364,7 @@ def main():
services.monitor,
dev_env)
if train_state.model_ema is not None and not args.model_ema_force_cpu:
if train_state.model_ema is not None:
if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'):
distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env)

@ -35,8 +35,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
help='model architecture (default: resnet50)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -87,13 +87,15 @@ parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--force-cpu', action='store_true', default=False,
help='Force CPU to be used even if HW accelerator exists.')
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
dev_env = initialize_device(amp=args.amp)
dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp)
# create model
model = create_model(

Loading…
Cancel
Save