Update README.md

Ross Wightman 3 years ago committed by GitHub
parent 5c5cadfe4c
commit 847b4af144

# Timm Bits
## Intro
A collection of reusable components and lightweight abstractions for training and evaluating neural networks with PyTorch and PyTorch XLA.
This is an early WIP (consider it pre-alpha); the first priority is getting up and running on TPUs w/ PyTorch XLA. Expect significant changes, rewrites, additions, and of course bugs.
The current train.py and validate.py scripts are evolving to use the timm.bits components.
## Bits Design Brief
`bits` is designed to be a lightweight and modular set of training abstractions. It shares concepts with other libraries (fastai, ignite, lightning, keras, etc.) but is not modeled after any specific one. It is supposed to be a 'bit different', hackable, and is purposely not trying to serve every use case or be everything to everyone.
`timm` models will always be usable in pure PyTorch w/o `bits` or any dependencies besides the model utils and helpers for pretrained models, feature extraction, and default data config.
I may break `bits` out into a separate project if there is any interest beyond my own use for timm image and video model training.
The layers:
* DeviceEnv - a dataclass abstraction that handles PyTorch CPU, GPU, and XLA device differences, incl distributed functions, parallel wrappers, etc. There is more than a passing similarity to HuggingFace Accelerate, but it was developed in parallel and differs in the details and separation of concerns.
* Updater - a dataclass that combines the backward pass, optimizer step, grad scaling, and grad accumulation in a device specific abstraction.
  * Currently, basic single optimizer, single forward/backward Updaters are included for GPU and XLA.
  * Deepspeed will need its own Updater(s) since its Engine is a monolith of epic proportions that breaks all separation of concerns in PyTorch (UGH!). NOTE: Deepspeed is not working yet, nor is it a priority.
* Monitor - pulls together console logging, csv summaries, tensorboard, and WandB summaries into one module for monitoring your training.
* Checkpoint Manager - keeps track of your checkpoints
* Metrics - yet another set of metrics, although this may be replaced w/ an external set of classes. Uses the same update / reset / compute interface as Ignite and Lightning (in theory interchangeable w/ a thin adapter). Metrics keep state on GPU / TPU to avoid device -> cpu transfers (esp. for XLA).
* Task (not implemented yet) - combines your model(s) w/ losses in a task specific module; will also allow a task factory for easy build of appropriate metrics.
* TrainState - dataclasses to hold your tasks (models), updater state, etc.
* Train Loop Functions (still in the train.py script, not refined) - a set of functions for the train step, 'after step', and evaluation, using all of the components mentioned above.
How is this different than other options?
* I'm very much trying to avoid a monolithic trainer / learner / model wrapping type class with numerous hooks and overrides to keep track of (avoiding granular inversion of control!).
* The goal is to provide reusable modules that can (hopefully) be mixed and matched w/ other code.
* Many of the components are based on Python dataclasses to reduce boilerplate.
* The train loop components are (will be) functional with easy to follow flow control, and are intended to be replaced when something different is needed, not augmented with hooks via callbacks or inheritance at every conceivable touch point.
## Quick Start for PyTorch XLA on TPU-VM
Most initial users will likely be interested in training timm models w/ PyTorch XLA on TPU-VM instances; this quick start will get you moving.
One thing to watch, be very careful that you don't use a GS based dataset in a d…
### Install TFDS (if using GS buckets)
```
pip3 install tensorflow-datasets
```
Some tpu-vm instances may have a tensorflow version pre-installed that conflicts with tensorflow-datasets, especially the GS bucket reading support. If training crashes with errors about being unable to read from buckets, tensorflow symbol errors, tensorflow-datasets missing functions, etc., try removing the pre-installed tensorflow and installing one from PyPI.
```
sudo pip3 uninstall tf-nightly
pip3 install tensorflow-cpu
```
You may run into some numpy / pytorch version dependency issues here; try capping the tensorflow version at `==2.4.1` in the above command.
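For example, a pinned install with that cap (adjust or drop the pin if your torch / numpy combination needs something else) would look like:
```
pip3 install "tensorflow-cpu==2.4.1"
```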
### Get your dataset into buckets
The TFDS dataset pages (https://www.tensorflow.org/datasets/catalog/imagenet2012) have the details and download / preparation instructions for each dataset.
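Once a dataset has been prepared locally with TFDS, one way to move it into a bucket is with `gsutil`; the bucket name and local path below are placeholders, substitute your own:
```
gsutil -m cp -r ~/tensorflow_datasets/imagenet2012 gs://my-imagenet-bucket/
```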
With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and buffering, and the instance memory will be used up quickly. I highly recommend using a custom allocator via `LD_PRELOAD`. tcmalloc may now be a default in the tpu-vm instances (check first). jemalloc also worked well for me. If LD_PRELOAD is not set in your env, do the following:
```
sudo apt update
sudo apt install google-perftools
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
```
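If you want to check whether an allocator is already preloaded before installing tcmalloc, a couple of quick checks (the paths are the usual Debian/Ubuntu locations) can help:
```
echo $LD_PRELOAD
ls /usr/lib/x86_64-linux-gnu/ | grep -i -E 'tcmalloc|jemalloc'
```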
### Train, train, train
With all the above done, you should be ready to train... below is one particular example.
Make sure the TPU config for PyTorch XLA on TPU-VM is set:
```
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
```
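As an optional sanity check, you can confirm PyTorch XLA sees the TPU cores before launching (this assumes the torch_xla install on the TPU-VM provides `xm.get_xla_supported_devices`):
```
python3 -c "import torch_xla.core.xla_model as xm; print(xm.get_xla_supported_devices())"
```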
Then, launch fighters!
```
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
NOTE: I built my TFDS dataset at ver 5.0.0 and TFDS defaults to a newer version now, so change the version in the `--dataset` arg accordingly.
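TFDS stores prepared data under `<dataset_name>/<version>/`, so listing the dataset directory in your bucket shows which version(s) you actually built (the bucket name is the placeholder used above):
```
gsutil ls gs://my-imagenet-bucket/imagenet2012/
```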
## Quick Start w/ GPU
`timm bits` should work great on your multi-GPU setups just like the old `timm` training script with either TFDS based datasets or a local folder.
The equivalent training command to the XLA setup above, if you were on an 8-GPU machine and using TFDS, would be:
```
./distributed_train.sh 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
Or this for imagenet in a local folder:
```
./distributed_train.sh 8 train.py /path/to/imagenet --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
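For the local folder case, the usual ImageNet-style layout is assumed, with a folder per split and a subfolder per class, something like the sketch below (directory names are illustrative):
```
/path/to/imagenet/
  train/
    n01440764/*.JPEG
    n01443537/*.JPEG
    ...
  val/
    n01440764/*.JPEG
    ...
```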
## Gotchas and Known Issues
* When PyTorch XLA crashes or you hit a TPU OOM, etc., lots of processes get orphaned. Get in the habit of killing all python processes before starting a new train run.
* `alias fml='pkill -f python3'`
