diff --git a/timm/bits/README.md b/timm/bits/README.md index 76071164..c51d4348 100644 --- a/timm/bits/README.md +++ b/timm/bits/README.md @@ -83,8 +83,15 @@ With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and With all the above done, you should be ready to train... below is one particular train command I've just recently been using for some trials on vision MLP models... +Make sure the TPU config for PyTorch XLA on TPU-VM is set: ``` - python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --lr 8.8e-4 -b 256 + export XRT_TPU_CONFIG="localservice;0;localhost:51011" +``` + +Then, launch fighters! + +``` + python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256 ``` NOTE: build my TFDS dataset at ver 5.0.0 and it defaults to a newer version now. Change accordingly. diff --git a/train.py b/train.py index c484ad0d..cca814fd 100755 --- a/train.py +++ b/train.py @@ -536,6 +536,10 @@ def setup_data(args, default_cfg, dev_env, mixup_active): use_multi_epochs_loader=args.use_multi_epochs_loader ) + eval_workers = args.workers + if 'tfds' in args.dataset: + # FIXME reduce validation issues when using TFDS w/ workers and distributed training + eval_workers = min(2, args.workers) loader_eval = create_loader( dataset_eval, input_size=data_config['input_size'], @@ -544,7 +548,7 @@ def setup_data(args, default_cfg, dev_env, mixup_active): interpolation=data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], - num_workers=args.workers, + num_workers=eval_workers, crop_pct=data_config['crop_pct'], pin_memory=args.pin_mem, )