Worker hack for TFDS eval, add TPU env var setting.

pull/1239/head
Ross Wightman 3 years ago
parent f411724de4
commit c3db5f5801

@ -83,8 +83,15 @@ With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and
With all the above done, you should be ready to train... below is one particular train command I've just recently been using for some trials on vision MLP models...
Make sure the TPU config for PyTorch XLA on TPU-VM is set:
```
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --lr 8.8e-4 -b 256
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
```
Then, launch fighters!
```
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
NOTE: build my TFDS dataset at ver 5.0.0 and it defaults to a newer version now. Change accordingly.

@ -536,6 +536,10 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
use_multi_epochs_loader=args.use_multi_epochs_loader
)
eval_workers = args.workers
if 'tfds' in args.dataset:
# FIXME reduce validation issues when using TFDS w/ workers and distributed training
eval_workers = min(2, args.workers)
loader_eval = create_loader(
dataset_eval,
input_size=data_config['input_size'],
@ -544,7 +548,7 @@ def setup_data(args, default_cfg, dev_env, mixup_active):
interpolation=data_config['interpolation'],
mean=data_config['mean'],
std=data_config['std'],
num_workers=args.workers,
num_workers=eval_workers,
crop_pct=data_config['crop_pct'],
pin_memory=args.pin_mem,
)

Loading…
Cancel
Save