From f411724de475989bec2873572fe9d2f0a7b7740e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 4 Jun 2021 12:49:53 -0700
Subject: [PATCH] Fix checkpoint delete issue. Add README about bits and initial Pytorch XLA usage on TPU-VM. Add some FIXMEs and fold train_cfg into train_state by default.

---
 timm/bits/README.md             | 102 +++++++++++++++++++++++++++++++-
 timm/bits/checkpoint_manager.py |   8 +--
 timm/bits/train_cfg.py          |   4 +-
 timm/bits/train_services.py     |   4 +-
 timm/bits/train_state.py        |  22 +++++--
 train.py                        |  85 +++++++++++++++++---------
 6 files changed, 183 insertions(+), 42 deletions(-)

diff --git a/timm/bits/README.md b/timm/bits/README.md
index 02ba6dc6..76071164 100644
--- a/timm/bits/README.md
+++ b/timm/bits/README.md
@@ -1,8 +1,104 @@
 # Timm Bits
-A collection of reusable components and lightweight abstractions for training and evaluating NN.
+## Intro
+A collection of reusable components and lightweight abstractions for training and evaluating neural networks with PyTorch.
 
-This is an early WIP with the primary goal to get up and running on TPUs first. Expect significant changes, rewrites, additions...
+This is an early WIP (consider it pre-alpha) with the primary goal of getting up and running on TPUs w/ PyTorch XLA as the first priority. Expect significant changes, rewrites, additions, and of course bugs.
 
-The current train.py and validate.py scipts are evolving to use the timm.bits components, they will also change significantly.
+The current train.py and validate.py scripts are evolving to use the timm.bits components; they will also change significantly.
 
+## Bits Design Brief
+
+`bits` is designed to be a lightweight and modular set of training abstractions. It certainly shares concepts with other libraries (fastai, ignite, lightning, keras, etc.) but is not modeled after any specific one. It is supposed to be a 'bit different', hackable, and not everything to everyone.
+
+`timm` models will always be usable in pure PyTorch w/o `bits` or anything besides the utils / helpers for pretrained models, feature extraction, and default data config. I may break `bits` out into a separate project if there is any interest beyond my own use for timm image and video model training.
+
+The layers:
+* Device - the DeviceEnv dataclass abstraction deals with PyTorch CPU, GPU, and XLA device differences, incl. distributed helpers, wrappers, etc. There is more than a passing similarity to HuggingFace Accelerate, but it was developed in parallel and differs in some of the details.
+* Updater - a dataclass that combines the backward pass, optimizer step, grad scaling, and grad accumulation in a possibly device-specific abstraction.
+  * Currently basic single optimizer, single forward/backward Updaters are included for GPU, XLA.
+  * Deepspeed will need its own Updater(s) since its Engine is a monolith of epic proportions that breaks all separations of concern in PyTorch (UGH!). NOTE Deepspeed is not working yet, nor is it a priority.
+* Monitor - pulls together all console logging, csv summaries, tensorboard, and WandB summaries into one module for monitoring your training.
+* Checkpoint Manager - keeps track of your checkpoints.
+* Metrics - yet another set of metrics, although this may be replaced w/ an external set of classes. Uses the same update / reset / compute interface as Ignite and Lightning (in theory interchangeable w/ an adapter); a rough sketch of that interface follows this list. Metrics keep state on GPU / TPU to avoid device -> cpu transfers (esp. for XLA).
+* Task (not implemented yet) - combines your model(s) w/ losses in a task-specific module; will also allow a task factory for easy building of related metrics.
+* Train State - dataclasses to hold your tasks (models), updater state, etc.
+* Train Loop Functions (still in train.py script, not refined) - a set of functions for the train step, 'after step', and evaluate, using all of the components mentioned above.
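+
+For illustration, here is a rough sketch of that update / reset / compute metric interface with state held as device tensors; the class below is a simplified stand-in written for this README, not the actual `timm.bits` implementation:
+
+```
+import torch
+
+class AvgMetric:
+    """Running average metric: update / reset / compute, state kept as device tensors."""
+    def __init__(self, device=None):
+        self.device = device
+        self.reset()
+
+    def reset(self):
+        # accumulators live on the training device, so there is no per-step device -> cpu sync
+        self.sum = torch.zeros((), device=self.device)
+        self.count = torch.zeros((), device=self.device)
+
+    def update(self, value: torch.Tensor, n: int = 1):
+        self.sum += value.detach() * n
+        self.count += n
+
+    def compute(self) -> torch.Tensor:
+        # the (potentially expensive) read back to host happens here, e.g. only at log intervals
+        return self.sum / self.count
+```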
+
+How is this different than other options?
+* I'm very much trying to avoid a monolithic trainer / learner / model wrapping type class with billions of hooks (avoiding granular inversion of control!).
+* The goal is to provide reusable modules that can (hopefully) be mixed and matched w/ other code.
+* Many of the components are based on Python dataclasses to reduce boilerplate.
+* The train loop components are (will be) functional with easy to follow flow control, and are intended to be replaced when something different is needed, not augmented with extremely granular hooks.
+
+
+## Quick Start
+
+Most initial users will likely be interested in training timm models w/ PyTorch XLA on TPU-VM instances; this quick start will get you moving.
+
+If you haven't noticed, this code is on a branch, so make sure you check out the `bits_and_tpu` branch of `timm` before doing this. You can also test locally on your GPU, either with XLA + GPU in a container or with the usual PyTorch w/ GPU.
+
+## Setup Python environment
+
+This setup assumes you've SSH'd into your TPU-VM after setting it up (https://cloud.google.com/tpu/docs/users-guide-tpu-vm). Don't forget to do this in a TMUX session or you'll be sad if you lose your connection!
+
+The TPU-VM instances I've been using have a usable version of PyTorch XLA 1.8.1 installed in the python3 environment; we will be using that.
+
+I've found that leveraging TFDS w/ datasets in TFRecord format, streamed from Google Storage buckets, is the most practical / cost-effective solution. I've written a PyTorch IterableDataset wrapper around TFDS, so we will install Tensorflow datasets and use that.
+
+One thing to watch: be very careful that you don't use a GS based dataset in a different continent from your TPU-VM instances. I burned through a few thousand USD leaving some wires crossed for 1 day. Otherwise the cost of training w/ buckets in the same region is quite low.
+
+### Install TFDS (if using GS buckets)
+
+```
+ pip3 install tensorflow-datasets
+```
+
+In some earlier tpu-vm instances the installed tensorflow version had issues with the GS bucket reading support and I often ended up installing a different version. This could conflict with other use cases, so only do it if needed.
+
+```
+ pip3 install --upgrade tensorflow-cpu
+```
+
+You may run into some numpy / pytorch version dependency issues here; try capping the version of tensorflow at 2.4.1 in the above command.
+
+### Get your dataset into buckets
+
+You will need to host your dataset in buckets. I have tried creating custom datasets for this setup, but have used a number of the datasets already available in TFDS, such as ImageNet, Flowers, Caltech Birds, and Oxford Pets.
+
+The TFDS dataset pages (https://www.tensorflow.org/datasets/catalog/imagenet2012) have directions for the various datasets. I recommend building them on a different VM or local machine and then uploading to your training bucket. Many of them will auto-download and build the tfrecord shards for you. ImageNet needs to be downloaded manually.
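+
+As a rough sketch (the dataset and bucket path below are placeholders), building one of the smaller TFDS datasets directly into a bucket path can look like the following; you can also build into a local data dir and upload the resulting shards to your bucket (e.g. with `gsutil`):
+
+```
+import tensorflow_datasets as tfds
+
+# downloads the source data and writes the tfrecord shards under data_dir
+builder = tfds.builder('oxford_flowers102', data_dir='gs://my-training-bucket/tfds')
+builder.download_and_prepare()
+```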
+
+### Use a custom allocator
+
+With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and buffering. The instance memory will be used up quickly. I highly recommend using a custom allocator via `LD_PRELOAD`. tcmalloc may now be a default in the tpu-vm instances (check first). jemalloc also worked well for me. If LD_PRELOAD is not set in your env, do the following:
+
+```
+ sudo apt update
+ sudo apt install google-perftools
+ export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
+```
+
+## Train, train, train
+
+With all the above done, you should be ready to train... below is one particular train command I've recently been using for some trials on vision MLP models...
+
+```
+ python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --lr 8.8e-4 -b 256
+```
+
+NOTE: I built my TFDS dataset at ver 5.0.0 and it defaults to a newer version now. Change the version in the `--dataset` arg accordingly.
+
+## Gotchas and Known Issues
+* When PyTorch XLA crashes (e.g. you hit a TPU OOM), lots of processes get orphaned. Get in the habit of killing all python processes before starting a new train run.
+  * `alias fml='pkill -f python3'`
+* For TFDS use, due to the way PyTorch IterableDatasets work at the loader level, each worker process builds batches independently -- they are not dequeued and collated across workers. For validation especially, getting all the samples evenly divided across BOTH the distributed processes AND the dataset workers is a bit annoying. For now, keeping the num_workers arg (`-j`) low is advisable, especially for very small validation sets. This can limit your throughput though.
+* Random erasing for on-device XLA tensors doesn't work. XLA isn't compatible with the array slicing approach of my RE impl, so it's currently done by default after moving tensors to the device. I need to fix this.
+* There are a number of models using ops that aren't lowered to XLA; this will REALLY slow things down, to the point of being unusable. There are flags you can set to debug this, see the PyTorch XLA troubleshooting page (https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md).
+  * For NFNet models, force the ScaledStdConv `use_layernorm` arg to True; it is lowered, while the `std_mean` op is not.
+* This code doesn't currently work when bfloat16 is forced via the `XLA_USE_BF16=1` env var; it will mess up metrics tensors that overflow in bfloat16. Better control of model activation vs weight precision vs other tensors is a TODO.
+* Your first batch, and generally your first epoch, will be slow with PyTorch XLA; after that things pick up and move along quickly. Be patient.
+
+## Bugs and Discussion
+
+If you find bugs (there are likely many), feel free to file an issue with `[BITS]` as the prefix. Open a discussion if you have design ideas, again with `[BITS]` in the title.
\ No newline at end of file
diff --git a/timm/bits/checkpoint_manager.py b/timm/bits/checkpoint_manager.py
index b051e126..b2c692cb 100644
--- a/timm/bits/checkpoint_manager.py
+++ b/timm/bits/checkpoint_manager.py
@@ -86,7 +86,7 @@ class CheckpointManager:
         try:
             if os.path.exists(dst):
                 os.unlink(dst)  # required for Windows support.
-        except Exception as e:
+        except (OSError, NotImplementedError) as e:
             self.can_hardlink = False
         os.replace(src, dst)
 
@@ -98,7 +98,7 @@ class CheckpointManager:
                 os.unlink(dst)
                 os.link(src, dst)
                 return
-        except Exception as e:
+        except (OSError, NotImplementedError) as e:
             self.can_hardlink = False
         shutil.copy2(src, dst)
 
@@ -153,8 +153,8 @@ class CheckpointManager:
         for d in to_delete:
             try:
                 _logger.debug("Cleaning checkpoint: {}".format(d))
-                os.remove(d[0])
-            except Exception as e:
+                os.remove(d.path)
+            except OSError as e:
                 _logger.error("Exception '{}' while deleting checkpoint".format(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]
diff --git a/timm/bits/train_cfg.py b/timm/bits/train_cfg.py
index d7b35faf..df627809 100644
--- a/timm/bits/train_cfg.py
+++ b/timm/bits/train_cfg.py
@@ -4,9 +4,9 @@ from dataclasses import dataclass
 @dataclass
 class TrainCfg:
     """ Train Loop Configuration
-    Dataclass to propagate training configuration values
+    Dataclass to hold training configuration values
     """
-    num_epochs: int = 0
+    num_epochs: int = 100
     log_interval: int = 50
     recovery_interval: int = 0
     accumulate_steps: int = 0
diff --git a/timm/bits/train_services.py b/timm/bits/train_services.py
index 5ead002d..d36d8c22 100644
--- a/timm/bits/train_services.py
+++ b/timm/bits/train_services.py
@@ -8,6 +8,6 @@ from .checkpoint_manager import CheckpointManager
 class TrainServices:
     """ Train Loop Services
     """
-    logger: Monitor = None
-    checkpoint_manager: CheckpointManager = None
+    monitor: Monitor = None
+    checkpoint: CheckpointManager = None
diff --git a/timm/bits/train_state.py b/timm/bits/train_state.py
index 9c47b5fd..91fcf76f 100644
--- a/timm/bits/train_state.py
+++ b/timm/bits/train_state.py
@@ -6,6 +6,7 @@ from torch import nn as nn
 from timm.scheduler import Scheduler
 from timm.utils import get_state_dict, unwrap_model
 
+from .train_cfg import TrainCfg
 from .updater import Updater
 
 
@@ -18,6 +19,9 @@ class TrainState:
     lr_scheduler: Scheduler = None
     model_ema: nn.Module = None
 
+    train_cfg: TrainCfg = TrainCfg()
+    # FIXME collect & include other cfg like data & model so it's in one spot for checkpoints / logging / debugging?
+
     epoch: int = 0
     step_count: int = 0
     step_count_global: int = 0
 
@@ -28,23 +32,33 @@ class TrainState:
     def state_dict(self, unwrap_fn=unwrap_model):
         state = dict(
+            # train loop state (counters, etc), saved and restored
             epoch=self.epoch,
             step_count=self.step_count,
             step_count_global=self.step_count_global,
+
+            # model params / state, saved and restored
             model=get_state_dict(self.model, unwrap_fn),
             model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
+
+            # configuration, saved but currently not restored, determined by args / config file for each run
+            train_cfg=vars(self.train_cfg)
         )
-        # FIXME lr_scheduler state save?
-        state.update(self.updater.state_dict())
+        # FIXME include lr_scheduler state?
+        state.update(self.updater.state_dict())  # updater (optimizer, scaler, etc) state added to state
         return state
 
-    def load_state_dict(self, state_dict, unwrap_fn=unwrap_model):
+    def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
+        # restore train loop state
         self.epoch = state_dict['epoch']
         self.step_count = state_dict['step_count']
         self.step_count_global = state_dict['step_count_global']
 
+        # restore model params / state
         unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
         if 'model_ema' in state_dict and self.model_ema is not None:
             unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))
 
-        self.updater.load_state_dict(state_dict)
+        # restore optimizer state
+        if load_opt:
+            self.updater.load_state_dict(state_dict)
diff --git a/train.py b/train.py
index c6142542..c484ad0d 100755
--- a/train.py
+++ b/train.py
@@ -287,7 +287,8 @@ def main():
     random_seed(args.seed, 0)  # Set all random seeds the same for model/state init (mandatory for XLA)
 
-    train_state, train_cfg = setup_train_task(args, dev_env, mixup_active)
+    train_state = setup_train_task(args, dev_env, mixup_active)
+    train_cfg = train_state.train_cfg
 
     # Set random seeds across ranks differently for train
     # FIXME perhaps keep the same and just set diff seeds for dataloader worker process? what about TFDS?
@@ -326,9 +327,12 @@ def main():
             f.write(args_text)
 
     services = TrainServices(
-        logger=Monitor(
-            output_dir=output_dir, logger=_logger, hparams=vars(args), output_enabled=dev_env.primary),
-        checkpoint_manager=checkpoint_manager,
+        monitor=Monitor(
+            output_dir=output_dir,
+            logger=_logger,
+            hparams=vars(args),
+            output_enabled=dev_env.primary),
+        checkpoint=checkpoint_manager,
     )
 
     try:
@@ -341,7 +345,6 @@ def main():
 
             train_metrics = train_one_epoch(
                 state=train_state,
-                cfg=train_cfg,
                 services=services,
                 loader=loader_train,
                 dev_env=dev_env,
@@ -356,7 +359,7 @@ def main():
                 train_state.model,
                 train_state.eval_loss,
                 loader_eval,
-                services.logger,
+                services.monitor,
                 dev_env)
 
             if train_state.model_ema is not None and not args.model_ema_force_cpu:
@@ -367,7 +370,7 @@ def main():
                     train_state.model_ema.module,
                     train_state.eval_loss,
                     loader_eval,
-                    services.logger,
+                    services.monitor,
                     dev_env,
                     phase_suffix='EMA')
                 eval_metrics = ema_eval_metrics
@@ -376,8 +379,10 @@ def main():
                 # step LR for next epoch
                 train_state.lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])
 
-            if services.logger is not None:
-                services.logger.write_summary(index=epoch, results=dict(train=train_metrics, eval=eval_metrics))
+            if services.monitor is not None:
+                services.monitor.write_summary(
+                    index=epoch,
+                    results=dict(train=train_metrics, eval=eval_metrics))
 
             if checkpoint_manager is not None:
                 # save proper checkpoint with eval metric
@@ -459,18 +464,21 @@ def setup_train_task(args, dev_env: DeviceEnv, mixup_active: bool):
     if dev_env.primary:
         _logger.info('Scheduled epochs: {}'.format(num_epochs))
 
+    train_cfg = TrainCfg(
+        num_epochs=num_epochs,
+        log_interval=args.log_interval,
+        recovery_interval=args.recovery_interval,
+    )
+
     train_state = replace(
         train_state,
         lr_scheduler=lr_scheduler,
         train_loss=train_loss_fn,
-        eval_loss=eval_loss_fn)
-
-    train_cfg = TrainCfg(
-        num_epochs=num_epochs,
-        log_interval=args.log_interval,
-        recovery_interval=args.recovery_interval)
+        eval_loss=eval_loss_fn,
+        train_cfg=train_cfg,
+    )
 
-    return train_state, train_cfg
+    return train_state
 
 
 def setup_data(args, default_cfg, dev_env, mixup_active):
@@ -545,13 +553,12 @@ def train_one_epoch(
        state: TrainState,
-        cfg: TrainCfg,
         services: TrainServices,
         loader,
         dev_env: DeviceEnv,
 ):
     tracker = Tracker()
-    loss_meter = AvgTensor()
+    loss_meter = AvgTensor()  # FIXME move loss meter into task specific TaskMetric
 
     state.model.train()
     state.updater.reset()  # zero-grad
 
@@ -573,7 +580,6 @@ def train_one_epoch(
         state.updater.after_step(
             after_train_step,
             state,
-            cfg,
             services,
             dev_env,
             step_idx,
@@ -594,7 +600,6 @@ def train_one_epoch(
 
 def after_train_step(
         state: TrainState,
-        cfg: TrainCfg,
         services: TrainServices,
         dev_env: DeviceEnv,
         step_idx: int,
@@ -603,6 +608,27 @@ def after_train_step(
         loss_meter: AvgTensor,
         tensors: Tuple[torch.Tensor, ...],
 ):
+    """
+    After the core loss / backward / gradient apply step, we perform all non-gradient related
+    activities here, including updating meters and metrics, performing logging, and writing checkpoints.
+
+    Many / most of these operations require tensors to be moved to the CPU; they should not be done
+    every step, and for XLA use they should be done via the optimizer step_closure. This function includes
+    everything that should be executed within the step closure.
+
+    Args:
+        state:
+        services:
+        dev_env:
+        step_idx:
+        step_end_idx:
+        tracker:
+        loss_meter:
+        tensors:
+
+    Returns:
+
+    """
     end_step = step_idx == step_end_idx
 
     with torch.no_grad():
@@ -610,16 +636,18 @@ def after_train_step(
         loss_meter.update(loss, output.shape[0])
 
         if state.model_ema is not None:
+            # FIXME should ema update be included here or in train / updater step? does it matter?
             state.model_ema.update(state.model)
 
     state = replace(state, step_count_global=state.step_count_global + 1)
+    cfg = state.train_cfg
 
-    if services.logger is not None and end_step or (step_idx + 1) % cfg.log_interval == 0:
+    if services.monitor is not None and (end_step or (step_idx + 1) % cfg.log_interval == 0):
         global_batch_size = dev_env.world_size * output.shape[0]
         loss_avg = loss_meter.compute()
-        if services.logger is not None:
+        if services.monitor is not None:
             lr_avg = state.updater.get_average_lr()
-            services.logger.log_step(
+            services.monitor.log_step(
                 'Train',
                 step=step_idx,
                 step_end=step_end_idx,
@@ -629,11 +657,12 @@ def after_train_step(
                 lr=lr_avg,
             )
 
-    if services.checkpoint_manager is not None and cfg.recovery_interval and (
+    if services.checkpoint is not None and cfg.recovery_interval and (
             end_step or (step_idx + 1) % cfg.recovery_interval == 0):
-        services.checkpoint_manager.save_recovery(state.epoch, batch_idx=step_idx)
+        services.checkpoint.save_recovery(state.epoch, batch_idx=step_idx)
 
     if state.lr_scheduler is not None:
+        # FIXME perform scheduler update here or via updater after_step call?
         state.lr_scheduler.step_update(num_updates=state.step_count_global)
 
 
@@ -649,7 +678,7 @@ def evaluate(
     tracker = Tracker()
     losses_m = AvgTensor()
-    accuracy_m = AccuracyTopK()
+    accuracy_m = AccuracyTopK()  # FIXME move loss and accuracy modules into task specific TaskMetric obj
 
     model.eval()
 
@@ -666,7 +695,9 @@ def evaluate(
                 output = output[0]
             loss = loss_fn(output, target)
-            dev_env.mark_step()  # FIXME
+            # FIXME explicitly marking the step for XLA use since I'm not using the parallel xm loader,
+            # need to investigate whether the parallel loader wrapper is helpful on tpu-vm or only useful for the 2-VM setup.
+            dev_env.mark_step()
             tracker.mark_iter_step_end()
             losses_m.update(loss, output.size(0))
             accuracy_m.update(output, target)