diff --git a/timm/bits/device_env_factory.py b/timm/bits/device_env_factory.py index bb92daab..620e400d 100644 --- a/timm/bits/device_env_factory.py +++ b/timm/bits/device_env_factory.py @@ -20,8 +20,12 @@ def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv: elif is_cuda_available(): denv = DeviceEnvCuda(**kwargs) + # CPU fallback if denv is None: - denv = DeviceEnv() + if is_xla_available('CPU'): + denv = DeviceEnvXla(device_type='CPU', **kwargs) + else: + denv = DeviceEnv() _logger.info(f'Initialized device {denv.device}. ' f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.') diff --git a/train.py b/train.py index 217f9a88..cad41bca 100755 --- a/train.py +++ b/train.py @@ -60,8 +60,8 @@ parser.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') parser.add_argument('--val-split', metavar='NAME', default='validation', help='dataset validation split (default: validation)') -parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', - help='Name of model to train (default: "countception"') +parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', + help='Name of model to train (default: "resnet50"') parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', @@ -215,8 +215,6 @@ parser.add_argument('--split-bn', action='store_true', # Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') -parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, - help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') @@ -252,6 +250,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa help='use the multi-epochs-loader to save time at the beginning of every epoch') parser.add_argument('--torchscript', dest='torchscript', action='store_true', help='convert model torchscript for inference') +parser.add_argument('--force-cpu', action='store_true', default=False, + help='Force CPU to be used even if HW accelerator exists.') parser.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') @@ -277,7 +277,7 @@ def main(): setup_default_logging() args, args_text = _parse_args() - dev_env = initialize_device(amp=args.amp, channels_last=args.channels_last) + dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp, channels_last=args.channels_last) if dev_env.distributed: _logger.info('Training in distributed mode with multiple processes, 1 device per process. Process %d, total %d.' % (dev_env.global_rank, dev_env.world_size)) @@ -364,7 +364,7 @@ def main(): services.monitor, dev_env) - if train_state.model_ema is not None and not args.model_ema_force_cpu: + if train_state.model_ema is not None: if dev_env.distributed and args.dist_bn in ('broadcast', 'reduce'): distribute_bn(train_state.model_ema, args.dist_bn == 'reduce', dev_env) diff --git a/validate.py b/validate.py index f4dc84e8..f9189171 100755 --- a/validate.py +++ b/validate.py @@ -35,8 +35,8 @@ parser.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') parser.add_argument('--split', metavar='NAME', default='validation', help='dataset split (default: validation)') -parser.add_argument('--model', '-m', metavar='NAME', default='dpn92', - help='model architecture (default: dpn92)') +parser.add_argument('--model', '-m', metavar='NAME', default='resnet50', + help='model architecture (default: resnet50)') parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 2)') parser.add_argument('-b', '--batch-size', default=256, type=int, @@ -87,13 +87,15 @@ parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME', help='Real labels JSON file for imagenet evaluation') parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME', help='Valid label indices txt file for validation of partial label space') +parser.add_argument('--force-cpu', action='store_true', default=False, + help='Force CPU to be used even if HW accelerator exists.') def validate(args): # might as well try to validate something args.pretrained = args.pretrained or not args.checkpoint - dev_env = initialize_device(amp=args.amp) + dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp) # create model model = create_model(