Compare commits
101 Commits
main...bits_and_tpu
Author | SHA1 | Date |
---|---|---|
Ross Wightman | a25bf974a9 | 2 years ago |
Ross Wightman | b4ea69c9ce | 2 years ago |
Ross Wightman | 38594ef7fd | 2 years ago |
Ross Wightman | 87bfb055c0 | 2 years ago |
Edoardo Debenedetti | 1dced6066c | 2 years ago |
Ross Wightman | f07dfe010a | 2 years ago |
Edoardo Debenedetti | 5a40c6a3c4 | 2 years ago |
Ross Wightman | bd6d377c74 | 2 years ago |
Ross Wightman | 6fe01993ad | 2 years ago |
Ross Wightman | 1186fc9c73 | 2 years ago |
Ross Wightman | dff33730b3 | 3 years ago |
Ross Wightman | 9c321be330 | 3 years ago |
Edoardo Debenedetti | c76d772670 | 3 years ago |
Ross Wightman | 754e11402a | 3 years ago |
Ross Wightman | 1ba0ec4c18 | 3 years ago |
Ross Wightman | 749856cf25 | 3 years ago |
Ross Wightman | 95739b45d7 | 3 years ago |
Ross Wightman | 5e1be34a60 | 3 years ago |
Ross Wightman | 59ffab537c | 3 years ago |
Ross Wightman | ef57561d51 | 3 years ago |
Ross Wightman | ab16a358bb | 3 years ago |
Ross Wightman | 7eeaf521a0 | 3 years ago |
Ross Wightman | 229ac6b8d8 | 3 years ago |
Ross Wightman | a444d4b891 | 3 years ago |
Ross Wightman | da2796ae82 | 3 years ago |
Ross Wightman | 3fce010ca8 | 3 years ago |
Ross Wightman | 15cc9eae3e | 3 years ago |
Ross Wightman | bb85b09d2a | 3 years ago |
Ross Wightman | 10fa42b143 | 3 years ago |
Ross Wightman | c639a86c67 | 3 years ago |
Ross Wightman | a16ea1e355 | 3 years ago |
Ross Wightman | fafece230b | 3 years ago |
Ross Wightman | 7148039f9f | 3 years ago |
Ross Wightman | f82fb6b608 | 3 years ago |
Ross Wightman | 066e490605 | 3 years ago |
Ross Wightman | 7eb7e73216 | 3 years ago |
Ross Wightman | 4c8bb295ab | 3 years ago |
Ross Wightman | 0012bf7fb5 | 3 years ago |
Ross Wightman | cbc4f33220 | 3 years ago |
Ross Wightman | 40f4745366 | 3 years ago |
Ross Wightman | 1c21cac8f9 | 3 years ago |
Ross Wightman | d829858550 | 3 years ago |
Ross Wightman | 57fca2b5b2 | 3 years ago |
Ross Wightman | 1f54a1fff7 | 3 years ago |
Ross Wightman | 4d7a5544f7 | 3 years ago |
Ross Wightman | 88a5b54802 | 3 years ago |
Ross Wightman | 66daee4f31 | 3 years ago |
Ross Wightman | 7bbbd5ef1b | 3 years ago |
Ross Wightman | ff0f709c20 | 3 years ago |
Ross Wightman | 69e90dcd8c | 3 years ago |
Ross Wightman | 820ae9925e | 3 years ago |
Ross Wightman | 0e212e8fe5 | 3 years ago |
Ross Wightman | cad170e494 | 3 years ago |
Ross Wightman | 809c7bb1ec | 3 years ago |
Ross Wightman | 871cef4198 | 3 years ago |
Ross Wightman | 4f338556d8 | 3 years ago |
Ross Wightman | d9b0b3d60f | 3 years ago |
Ross Wightman | 80ca078aed | 3 years ago |
Ross Wightman | 406c486ba2 | 3 years ago |
Ross Wightman | 07693f81b0 | 3 years ago |
Ross Wightman | 59a3409182 | 3 years ago |
Ross Wightman | a45186a6e8 | 3 years ago |
Ross Wightman | 690f31d02d | 3 years ago |
Ross Wightman | 3b6ba76126 | 3 years ago |
Ross Wightman | 1fdc7af8fd | 3 years ago |
Ross Wightman | 52c481ea8e | 3 years ago |
Ross Wightman | 25d52ea71d | 3 years ago |
Ross Wightman | 3581affb77 | 3 years ago |
Ross Wightman | c2f02b08b8 | 3 years ago |
Ross Wightman | f2e14685a8 | 3 years ago |
Ross Wightman | 2ee398d501 | 3 years ago |
Ross Wightman | f4fb068b11 | 3 years ago |
Ross Wightman | b0265ef8a6 | 3 years ago |
Ross Wightman | 0d82876132 | 3 years ago |
Ross Wightman | b76b48e8e9 | 3 years ago |
Ross Wightman | f98662b9c9 | 3 years ago |
Ross Wightman | cb621e0f00 | 3 years ago |
Ross Wightman | b974d85026 | 3 years ago |
Ross Wightman | c06c739901 | 3 years ago |
Ross Wightman | 40457e5691 | 3 years ago |
Ross Wightman | 5e95ced5a7 | 3 years ago |
Ross Wightman | 56ed0a0b63 | 3 years ago |
Ross Wightman | 847b4af144 | 3 years ago |
Ross Wightman | 5c5cadfe4c | 3 years ago |
Ross Wightman | ee2b8f49ee | 3 years ago |
Ross Wightman | cc870df7b8 | 3 years ago |
Ross Wightman | 6b2d9c2660 | 3 years ago |
Ross Wightman | c3db5f5801 | 3 years ago |
Ross Wightman | f411724de4 | 3 years ago |
Ross Wightman | b57a03bd0d | 3 years ago |
Ross Wightman | 91ab0b6ce5 | 3 years ago |
Ross Wightman | 5b9c69e80a | 4 years ago |
Ross Wightman | 4210d922d2 | 4 years ago |
Ross Wightman | 72ca831dd4 | 4 years ago |
Ross Wightman | cbd4ee737f | 4 years ago |
Ross Wightman | 6d90fcf282 | 4 years ago |
Ross Wightman | 74d2829341 | 4 years ago |
Ross Wightman | aa92d7b1c5 | 4 years ago |
Ross Wightman | 938716c753 | 4 years ago |
Ross Wightman | 76de984a5f | 4 years ago |
Ross Wightman | 12d9a6d4d2 | 4 years ago |
@@ -0,0 +1,66 @@
"""
Adaptation of (pre-elastic) torch.distributed.launch for PyTorch XLA.

`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.

"""


import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO

import torch_xla.distributed.xla_multiprocessing as xmp


def parse_args():
    """
    Helper function parsing the command line options
    @retval ArgumentParser
    """
    parser = ArgumentParser(
        description="PyTorch distributed training launch helper utility "
                    "that will spawn up multiple distributed processes")

    # Optional arguments for the launch helper
    parser.add_argument("--num-devices", type=int, default=1,
                        help="The number of XLA devices to use for distributed training")

    # positional
    parser.add_argument(
        "script", type=str,
        help="The full path to the single device training script to be launched "
             "in parallel, followed by all the arguments for the training script")

    # rest from the training program
    parser.add_argument('script_args', nargs=REMAINDER)
    return parser.parse_args()


def main():
    args = parse_args()

    # set PyTorch distributed related environmental variables
    # current_env = os.environ.copy()
    # current_env["MASTER_ADDR"] = args.master_addr
    # current_env["MASTER_PORT"] = str(args.master_port)
    # current_env["WORLD_SIZE"] = str(dist_world_size)
    # if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
    #     current_env["OMP_NUM_THREADS"] = str(1)

    # import the training script as a module so xmp can spawn its entry point in each process
    script_abs = os.path.abspath(args.script)
    script_base, script_rel = os.path.split(script_abs)
    sys.path.append(script_base)
    mod = importlib.import_module(os.path.splitext(script_rel)[0])

    sys.argv = [args.script] + args.script_args

    xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)


if __name__ == "__main__":
    main()
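The launcher above doesn't rely on `torch.distributed` env variables; it imports the target script as a module, rewrites `sys.argv`, and spawns `_mp_entry` in each XLA process. A minimal sketch of a compatible script (hypothetical, not part of the diff) could look like:

```
import sys


def main():
    # parse sys.argv (already rewritten by launch_xla.py) and run training
    print('training with args:', sys.argv[1:])


def _mp_entry(index):
    # called once per XLA process by xmp.spawn; `index` is the process ordinal
    main()


if __name__ == '__main__':
    main()
```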
(5 file diffs suppressed because they are too large)
@@ -0,0 +1,133 @@
# Timm Bits

## Intro
A collection of reusable components and lightweight abstractions for training and evaluating NNs with PyTorch and PyTorch XLA.

This is an early WIP (consider it pre-alpha) with the primary goal to get up and running on TPUs w/ PyTorch XLA as the first priority. Expect significant changes, rewrites, additions, and of course bugs.

The current train.py and validate.py scripts are evolving to use the timm.bits components; they will also change significantly.

## Bits Design Brief

`bits` is designed to be a lightweight and modular set of training abstractions. It shares concepts with other libraries (fastai, ignite, lightning, keras, etc) but is not modeled after any specific one. It is supposed to be a 'bit different', hackable, and is purposely not trying to serve every use case or be everything to everyone.

`timm` models will always be usable in pure PyTorch w/o `bits` or any dependencies besides the model utils and helpers for pretrained models, feature extraction, and default data config.

I may break `bits` out into a different project if there is any interest besides my own use for timm image and video model training.

The layers:
* DeviceEnv - a dataclass abstraction that handles PyTorch CPU, GPU, and XLA device differences, incl distributed functions, parallel wrappers, etc. There is more than a passing similarity to HuggingFace Accelerate, but developed in parallel and with some differences in the detail and separation of concerns.
* Updater - a dataclass that combines the backward pass, optimizer step, grad scaling, and grad accumulation in a device-specific abstraction.
  * Currently, basic single optimizer, single forward/backward Updaters are included for GPU and XLA.
  * DeepSpeed will need its own Updater(s) since its Engine is a monolith of epic proportions that breaks all separations of concern in PyTorch (UGH!). NOTE DeepSpeed is not working yet, nor is it a priority.
* Monitor - pulls together all console logging, csv summaries, tensorboard, and WandB summaries into one module for monitoring your training.
* Checkpoint Manager - keeps track of your checkpoints.
* Metrics - yet another set of metrics, although this may be replaced w/ an external set of classes. Uses the same update / reset / compute interface as Ignite and Lightning (in theory interchangeable w/ a thin adapter). Metrics keep state on GPU / TPU to avoid device -> cpu transfers (esp for XLA).
* Task (not implemented yet) - combine your model(s) w/ losses in a task-specific module; will also allow a task factory for easy build of appropriate metrics.
* TrainState - dataclasses to hold your tasks (models), updater state, etc.
* Train Loop Functions (still in train.py script, not refined) - a set of functions for the train step, 'after step', and evaluate, using all of the components mentioned. A rough sketch of how a few of these pieces fit together follows this list.
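For orientation, here is a rough, hypothetical sketch of how a script might pull a few of these pieces together (names taken from the `timm.bits` exports further down; the exact arguments are illustrative, not a stable API):

```
import torch.nn as nn
from timm.bits import initialize_device, CheckpointManager

dev_env = initialize_device()                # selects XLA, CUDA, or CPU and registers the global DeviceEnv
model = nn.Linear(768, 1000)                 # stand-in for a real timm model
model = dev_env.to_device(model)             # move to device, honoring channels_last / memory format config
if dev_env.distributed:
    model = dev_env.wrap_distributed(model)  # DDP on CUDA, no-op wrapper on XLA

saver = CheckpointManager(
    checkpoint_dir='./output',               # illustrative output location
    metric_name='top1',                      # rank checkpoints by validation top-1
    metric_decreasing=False,                 # higher top-1 is better
    max_history=5,
)
```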
How is this different than other options?
* I'm very much trying to avoid a monolithic trainer / learner / model wrapping type class with numerous hooks and overrides to keep track of (avoiding granular inversion of control!).
* The goal is to provide reusable modules that can (hopefully) be mixed and matched w/ other code.
* Many of the components are based on Python dataclasses to reduce boilerplate.
* The train loop components are (will be) functional with easy to follow flow control, and are intended to be replaced when something different is needed, not augmented with hooks via callbacks or inheritance at every conceivable touch point.


## Quick Start for PyTorch XLA on TPU-VM

Most initial users will likely be interested in training timm models w/ PyTorch XLA on TPU-VM instances; this quick start will get you moving.

If you haven't noticed, this code is on a branch, so make sure you checkout the `bits_and_tpu` branch of `timm` before doing this. You can test locally on your GPU too, either with XLA + GPU in a container or the usual PyTorch w/ GPU.

## Setup Python environment

This setup assumes you've SSH'd into your TPU-VM after setting it up (https://cloud.google.com/tpu/docs/users-guide-tpu-vm). Don't forget to do this in a TMUX session or you'll be sad if you lose your connection!

The TPU-VM instances I've been using have a usable version of PyTorch XLA 1.8.1 installed in the python3 environment; we will be using that.

I've found that leveraging TFDS w/ datasets in TFRecord format, streamed from Google Storage buckets, is the most practical / cost-effective solution. I've written a PyTorch IterableDataset wrapper around TFDS, so we will install Tensorflow datasets and use that. Traditional PyTorch datasets on local disks do work w/ bits for all of TPU-VM, GPU cloud instances, and your local machine. Setting up persistent disks wasn't the easiest thing to do on TPU-VMs, so TFDS was my default in that context.

One thing to watch: be very careful that you don't use a GS-based dataset in a different continent from your TPU-VM instances. I burned through a few thousand USD leaving some wires crossed for 1 day. Otherwise the cost of training w/ buckets in the same region is quite low.

### Install TFDS (if using GS buckets)

```
pip3 install tensorflow-datasets
```

Some TPU-VM instances may have a tensorflow version pre-installed that conflicts with tensorflow-datasets, especially the bucket reading support. If training crashes with errors about inability to read from buckets, tensorflow symbol errors, tensorflow datasets missing functions, etc, you should try removing the pre-installed tensorflow and installing one from pypi.

```
sudo pip3 uninstall tf-nightly
pip3 install tensorflow-cpu
```

You may run into some numpy / pytorch version dependency issues here; try capping the version of tensorflow at `==2.4.1` in the above command.


### Get your dataset into buckets

You will need to host your dataset in buckets. I have tried creating custom datasets for this setup, but have used a number of TFDS datasets such as ImageNet, Flowers, Caltech Birds, and Oxford Pets that are available in TFDS.

The TFDS dataset pages (https://www.tensorflow.org/datasets/catalog/imagenet2012) have directions for the various datasets; I recommend building them on a different VM or local machine and then uploading to your training bucket. Many of them will auto-download and build the tfrecord shards for you. ImageNet needs to be downloaded manually.
### Use a custom allocator

With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and buffering, and the instance memory will be used up quickly. I highly recommend using a custom allocator via `LD_PRELOAD`. tcmalloc may now be a default in the TPU-VM instances (check first). jemalloc also worked well for me. If LD_PRELOAD is not set in your env, do the following:

```
sudo apt update
sudo apt install google-perftools
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
```

## Train, train, train

With all the above done, you should be ready to train... below is one particular train command I've recently been using for some trials on vision MLP models...

Make sure the TPU config for PyTorch XLA on TPU-VM is set:
```
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
```

Then, launch fighters!

```
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```

NOTE: I built my TFDS dataset at ver 5.0.0 and it defaults to a newer version now. Change accordingly.

## Quick Start w/ GPU

`timm bits` should work great on your multi-GPU setups just like the old `timm` training script, with either TFDS-based datasets or a local folder.

The equivalent training command to the XLA setup above, if you were on an 8-GPU machine and using TFDS, would be:

```
./distributed_train.sh 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```

Or this for ImageNet in a local folder:
```
./distributed_train.sh 8 train.py /path/to/imagenet --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```

## Gotchas and Known Issues
* When PyTorch XLA crashes (you hit a TPU OOM, etc), lots of processes get orphaned. Get in the habit of killing all python processes before starting a new train run.
  * `alias fml='pkill -f python3'`
* For TFDS use, due to the way PyTorch IterableDatasets work at the loader level, each worker process builds batches independently -- they are not dequeued and collated across workers. For validation especially, getting all the samples evenly divided across BOTH the distributed processes AND the dataset workers is a bit annoying. For now, keeping the num_workers arg (-j) low is advisable, especially for very small validation sets. This can limit your throughput though.
* Random erasing works with PyTorch XLA, but it must be done on the images before they are moved into tensors on the XLA device. This changes the dataloader pipeline a bit and increases the size of the data being moved to device (float instead of int8), so it has an impact on dataloading speed.
* There are a number of models using ops that aren't lowered to XLA; this will REALLY slow things down, to the point of being unusable. There are flags you can set to debug this, see the PyTorch XLA troubleshooting page (https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md)
* This code doesn't currently work when bfloat16 is forced via the `XLA_USE_BF16=1` env var; it will mess up metrics tensors that overflow in bfloat16. Better control of model activation vs weight precision vs other tensors is a TODO.
* I haven't tested this code with pre TPU-VM (2-VM) setups, but it should work w/ correct config. I intend to make it work with Colab and Kaggle TPU notebooks soon.
* Your first batch, and generally first epoch, will be slow with PyTorch XLA; after that things pick up and move along quickly. Be patient.

## Bugs and Discussion

If you find bugs (there are likely many), feel free to file an issue with `[BITS]` as the title prefix. Open a discussion if you have design ideas, again using `[BITS]` in the title.

## Acknowledgements

The TPU-VMs I've used for creating and testing this code, and that I hope to use for many future `timm` models, were made available by the TPU Research Cloud (https://sites.research.google/trc/).
@@ -0,0 +1,26 @@
from .avg_scalar import AvgMinMaxScalar
from .avg_tensor import AvgTensor
from .checkpoint_manager import CheckpointManager
from .device_env import DeviceEnv, DeviceEnvType, get_global_device, set_global_device, is_global_device
from .device_env_cuda import DeviceEnvCuda
from .device_env_factory import initialize_device
from .device_env_xla import DeviceEnvXla
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
    all_reduce_sequence, all_gather_sequence
# from .evaluate import evaluate, eval_step
from .monitor import Monitor
from .metric import Metric, MetricValueT
from .metric_accuracy import AccuracyTopK
from .tracker import Tracker
# from .task_metrics import TaskMetrics, TaskMetricsClassify
from .train_cfg import TrainCfg
from .train_services import TrainServices
from .train_setup import setup_model_and_optimizer
from .train_state import TrainState
# from .task import TaskClassify
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_factory import create_updater
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
# from .train import train_one_epoch, Experiment
@@ -0,0 +1,24 @@
class AvgMinMaxScalar:
    """Computes and stores the average, min, max, and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.min = None
        self.max = None
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.min = val if self.min is None else min(self.min, val)
        self.max = val if self.max is None else max(self.max, val)
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
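A small usage sketch of the scalar meter above (not part of the diff):

```
# track per-batch loss on the host; avg/min/max update incrementally
meter = AvgMinMaxScalar()
for loss in (0.9, 0.7, 0.8):
    meter.update(loss, n=1)
print(meter.avg, meter.min, meter.max)  # ~0.8, 0.7, 0.9
```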
@@ -0,0 +1,54 @@
import torch


class AvgTensor:
    """Computes and stores the average and current value"""

    def __init__(self, accumulate_dtype=torch.float32):
        self.accumulate_dtype = accumulate_dtype
        self.sum = None
        self.count = None
        self.reset()
        # FIXME handle distributed operation

    def reset(self):
        self.sum = None
        self.count = None

    def update(self, val: torch.Tensor, n=1):
        if self.sum is None:
            self.sum = torch.zeros_like(val, dtype=self.accumulate_dtype)
            self.count = torch.tensor(0, dtype=torch.long, device=val.device)
        self.sum += (val * n)
        self.count += n

    def compute(self):
        return self.sum / self.count


class TensorEma:
    """Keeps an exponential moving average of a tensor value"""

    def __init__(
            self,
            smoothing_factor=0.9,
            init_zero=False,
            accumulate_dtype=torch.float32
    ):
        self.accumulate_dtype = accumulate_dtype
        self.smoothing_factor = smoothing_factor
        self.init_zero = init_zero
        self.val = None
        self.reset()
        # FIXME handle distributed operation

    def reset(self):
        self.val = None

    def update(self, val):
        if self.val is None:
            if self.init_zero:
                self.val = torch.zeros_like(val, dtype=self.accumulate_dtype)
            else:
                self.val = val.clone().to(dtype=self.accumulate_dtype)
        self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * self.val
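A quick sketch of how the device-resident meter above behaves (not part of the diff); the running sum and count stay tensors on the original device, so there is no device -> cpu sync until the result of `compute()` is read:

```
import torch

meter = AvgTensor()
meter.update(torch.tensor(2.0), n=4)   # e.g. mean loss of a batch of 4
meter.update(torch.tensor(1.0), n=4)
print(meter.compute())                 # tensor(1.5000)
```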
@@ -0,0 +1,109 @@
import logging
import os
from collections import OrderedDict
from typing import Dict, Any, Callable

import torch

from timm.utils import unwrap_model

from .device_env import DeviceEnv
from .train_state import TrainState

_logger = logging.getLogger(__name__)


def save_train_state(
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS?
        train_state: TrainState,
        extra_state: Dict[str, Any] = None,
        unwrap_fn: Callable = unwrap_model,
        dev_env: DeviceEnv = None,
        log_info: bool = True):

    assert not train_state.updater.deepspeed
    # DeepSpeed has a fully custom checkpoint saving setup, it is not possible to
    # specify a filename, checkpoints need to be saved from all ranks, etc
    # if train_state.updater.deepspeed:
    #     save_train_state_deepspeed(train_state, checkpoint_path)

    dev_env = dev_env or DeviceEnv.instance()
    state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
    if extra_state:
        state_dict.update(extra_state)
    if dev_env.type_xla:
        # XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
        state_dict = dev_env.state_dict_to_cpu(state_dict)
    torch.save(state_dict, checkpoint_path)


def load_train_state(
        train_state: TrainState,
        checkpoint_path: str,  # FIXME pass base path + file pattern + epoch / step separately for DS
        unwrap_fn: Callable = None,
        load_opt: bool = True,
        dev_env: DeviceEnv = None,
        log_info: bool = True
):
    unwrap_fn = unwrap_fn or unwrap_model
    if not os.path.isfile(checkpoint_path):
        _logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()

    if log_info:
        _logger.info('Restoring training state from checkpoint...')

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    assert isinstance(checkpoint, dict)

    if not checkpoint.get('version', 0) > 2:
        load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
        if log_info:
            _logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
        return

    train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn, load_opt=load_opt)
    if log_info:
        _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))


def _get_state_dict(checkpoint, state_dict_key='state_dict'):
    new_state_dict = OrderedDict()
    for k, v in checkpoint[state_dict_key].items():
        name = k[7:] if k.startswith('module') else k
        new_state_dict[name] = v
    return new_state_dict


def load_legacy_checkpoint(
        train_state: TrainState,
        checkpoint,
        load_opt=True,
        log_info=True):

    assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
    train_state.model.load_state_dict(_get_state_dict(checkpoint))

    if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
        if log_info:
            _logger.info('Restoring model (EMA) state from checkpoint...')
        unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))

    if load_opt:
        if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
            if log_info:
                _logger.info('Restoring optimizer state from checkpoint...')
            train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])

        scaler_state_dict_key = 'amp_scaler'
        if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
            if log_info:
                _logger.info('Restoring AMP loss scaler state from checkpoint...')
            train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])

    if 'epoch' in checkpoint:
        resume_epoch = checkpoint['epoch']
        if 'version' in checkpoint and checkpoint['version'] > 1:
            resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
        train_state.epoch = resume_epoch  # FIXME use replace if we make train_state read-only
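A hypothetical resume call using the helpers above (assumes `train_state` was already built, e.g. by `setup_model_and_optimizer`, and that a `last.pth.tar` exists in the output dir):

```
load_train_state(train_state, './output/last.pth.tar', load_opt=True)
print('resuming at epoch', train_state.epoch)
```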
@@ -0,0 +1,219 @@
""" Checkpoint Manager

Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.

Hacked together by / Copyright 2021 Ross Wightman
"""
import glob
import logging
import operator
import os
import shutil
from typing import Optional, Dict, Callable, List
from dataclasses import dataclass, replace


from .checkpoint import save_train_state
from .train_state import TrainState

_logger = logging.getLogger(__name__)


@dataclass
class CheckpointInfo:
    path: str = ''
    metrics: Dict[str, float] = None  # all metrics at time of checkpoint save
    metric_name: str = 'loss'
    metric_decreasing: bool = True
    epoch: int = 0
    global_step: int = 0

    @property
    def valid_key(self):
        return self.metric_name and self.metrics and self.metric_name in self.metrics

    @property
    def sort_key(self):
        return self.metrics[self.metric_name] if self.valid_key else self.epoch

    @property
    def decreasing_key(self):
        return self.metric_decreasing if self.valid_key else False


class CheckpointManager:
    def __init__(
            self,
            hparams=None,
            save_state_fn=None,
            checkpoint_dir='',
            recovery_dir='',
            checkpoint_tmpl=None,
            recovery_tmpl=None,
            metric_name='loss',
            metric_decreasing=True,
            max_history=10):

        # extra items to include in checkpoint
        self.hparams = hparams  # train arguments (config / hparams)  # FIXME this will change with new config system

        # state
        self.checkpoint_files: List[CheckpointInfo] = []  # CheckpointInfo entries in order of decreasing betterness
        self.best_checkpoint = None
        self.curr_recovery_file = ''
        self.prev_recovery_file = ''
        self.can_hardlink = True

        # util / helper fn
        self.save_state_fn = save_state_fn or save_train_state

        # file / folder config
        self.extension = '.pth.tar'
        self.checkpoint_dir = checkpoint_dir
        self.recovery_dir = recovery_dir
        self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
        self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension

        # ordering / history config
        self.metric_name = metric_name
        self.metric_decreasing = metric_decreasing
        self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
        self.max_history = max_history
        assert self.max_history >= 1

    def _replace(self, src, dst):
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    os.unlink(dst)  # required for Windows support.
            except (OSError, NotImplementedError) as e:
                self.can_hardlink = False
        os.replace(src, dst)

    def _duplicate(self, src, dst):
        if self.can_hardlink:
            try:
                if os.path.exists(dst):
                    # for Windows
                    os.unlink(dst)
                os.link(src, dst)
                return
            except (OSError, NotImplementedError) as e:
                self.can_hardlink = False
        shutil.copy2(src, dst)

    def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        extra_state = dict(
            # version < 2 increments epoch before save
            # version < 3, pre timm bits
            # version 3, first timm bits checkpoints
            version=3,
        )
        if self.hparams is not None:
            extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
        else:
            arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
            if arch is None:
                arch = type(train_state.model).__name__.lower()
            extra_state.update(dict(arch=arch))
        if metrics is not None:
            # save the metrics and how we originally sorted them in the checkpoint for future comparisons
            extra_state.update(dict(
                metrics=metrics,
                metric_name=self.metric_name,
                metric_decreasing=self.metric_decreasing
            ))

        self.save_state_fn(save_path, train_state, extra_state)

        checkpoint_info = CheckpointInfo(
            path=save_path,
            metrics=metrics,
            metric_name=self.metric_name,
            metric_decreasing=self.metric_decreasing,
            epoch=train_state.epoch,
            global_step=train_state.step_count_global,
        )
        return checkpoint_info

    def _update_checkpoints(self, info: CheckpointInfo):
        self.checkpoint_files.append(info)
        self.checkpoint_files = sorted(
            self.checkpoint_files,
            key=lambda x: x.sort_key,
            reverse=not info.decreasing_key,  # sort in descending order if a lower metric is not better
        )

    def _cleanup_checkpoints(self, trim=0):
        trim = min(len(self.checkpoint_files), trim)
        delete_index = self.max_history - trim
        if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
            return
        to_delete = self.checkpoint_files[delete_index:]
        for d in to_delete:
            try:
                _logger.debug("Cleaning checkpoint: {}".format(d))
                os.remove(d.path)
            except OSError as e:
                _logger.error("Exception '{}' while deleting checkpoint".format(e))
        self.checkpoint_files = self.checkpoint_files[:delete_index]

    def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
        # compare metrics against an existing checkpoint
        if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
            # always assume lhs metrics are better if there are no usable metrics to compare
            return True
        return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)

    def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
        assert train_state.epoch >= 0
        tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
        last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
        curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
        self._replace(tmp_save_path, last_save_path)

        worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
        if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
            if len(self.checkpoint_files) >= self.max_history:
                self._cleanup_checkpoints(1)

            filename = self.checkpoint_tmpl.format(index=train_state.epoch)
            save_path = os.path.join(self.checkpoint_dir, filename)
            curr_checkpoint = replace(curr_checkpoint, path=save_path)
            self._duplicate(last_save_path, save_path)
            self._update_checkpoints(curr_checkpoint)

            checkpoints_str = "Current checkpoints:\n"
            for c in self.checkpoint_files:
                checkpoints_str += f' {c.path}, {c.sort_key}\n'
            _logger.info(checkpoints_str)

            if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
                self.best_checkpoint = curr_checkpoint
                best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
                self._duplicate(last_save_path, best_save_path)

        return curr_checkpoint if self.best_checkpoint is None else self.best_checkpoint

    def save_recovery(self, train_state: TrainState):
        tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
        self._save(tmp_save_path, train_state)

        filename = self.recovery_tmpl.format(index=train_state.step_count_global)
        save_path = os.path.join(self.recovery_dir, filename)
        self._replace(tmp_save_path, save_path)

        if os.path.exists(self.prev_recovery_file):
            try:
                _logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
                os.remove(self.prev_recovery_file)
            except Exception as e:
                _logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
        self.prev_recovery_file = self.curr_recovery_file
        self.curr_recovery_file = save_path

    def find_recovery(self):
        # NOTE: glob on the static portion of recovery_tmpl (everything before the '{index}' field)
        recovery_prefix = self.recovery_tmpl.split('{')[0]
        recovery_path = os.path.join(self.recovery_dir, recovery_prefix)
        files = glob.glob(recovery_path + '*' + self.extension)
        files = sorted(files)
        return files[0] if len(files) else ''
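A hypothetical end-of-epoch snippet using the manager above (assumes `train_state` is a populated `TrainState` and `eval_metrics` is a dict like `{'top1': 76.2, 'loss': 1.01}`):

```
saver = CheckpointManager(checkpoint_dir='./output', metric_name='top1', metric_decreasing=False, max_history=5)
best = saver.save_checkpoint(train_state, metrics=eval_metrics)  # writes last.pth.tar + a ranked checkpoint-{epoch}.pth.tar
print('best so far:', best.path, best.sort_key)
```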
@@ -0,0 +1,191 @@
import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple, Dict, Any
from dataclasses import dataclass, field, InitVar

import torch
import torch.distributed as dist

TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]


class DeviceEnvType(Enum):
    """ Device Environment Types
    """
    CPU = "cpu"
    CUDA = "cuda"
    XLA = "xla"


def state_dict_apply(state_dict: Dict[str, Any], apply_fn, select_fn=lambda x: isinstance(x, torch.Tensor)):
    out_dict = {}
    for k, v in state_dict.items():
        if isinstance(v, dict):
            out_dict[k] = state_dict_apply(v, apply_fn, select_fn)
        else:
            out_dict[k] = apply_fn(v) if select_fn(v) else v
    return out_dict


@dataclass
class DeviceEnv:
    device_type: InitVar[Optional[str]] = None
    device_index: InitVar[Optional[int]] = None
    channels_last: InitVar[bool] = False

    device: torch.device = field(init=False)  # set from device_type + device_index or post_init logic
    world_size: Optional[int] = None  # set by post_init from env when None
    local_rank: Optional[int] = None  # set by post_init from env when None
    global_rank: Optional[int] = None  # set by post_init from env when None
    amp: bool = False
    autocast: Optional[Callable] = None  # set by post_init from env when None
    memory_format: Optional[torch.memory_format] = None
    dtype: Optional[torch.dtype] = None

    def __post_init__(
            self,
            device_type: Optional[str],
            device_index: Optional[int],
            channels_last: bool,
    ):
        device_type = device_type or 'cpu'
        self.device = torch.device(device_type) if device_index is None \
            else torch.device(device_type, device_index)
        self.world_size = 1 if self.world_size is None else self.world_size
        self.local_rank = 0 if self.local_rank is None else self.local_rank
        self.global_rank = 0 if self.global_rank is None else self.global_rank
        if self.autocast is None:
            self.autocast = suppress
        if channels_last:
            self.memory_format = torch.channels_last

    @staticmethod
    def is_instance():
        return is_global_device()

    @staticmethod
    def instance():
        # throws if called before global device is set / initialized
        return get_global_device()

    @property
    def type(self) -> DeviceEnvType:
        if self.device.type == 'cpu':
            return DeviceEnvType.CPU
        elif self.device.type == 'cuda':
            return DeviceEnvType.CUDA
        elif self.device.type == 'xla':
            return DeviceEnvType.XLA
        else:
            assert False, "Unexpected device type for base DevEnv impl."

    @property
    def type_cuda(self):
        # shortcut for common cuda device type
        return self.type == DeviceEnvType.CUDA

    @property
    def type_xla(self):
        # shortcut for common xla device type
        return self.type == DeviceEnvType.XLA

    @property
    def distributed(self):
        return self.world_size > 1

    @property
    def primary(self):
        return self.local_rank == 0

    @property
    def global_primary(self):
        return self.global_rank == 0

    def wrap_distributed(self, *modules):
        pass

    def wrap_parallel(self, *modules):
        pass

    def to_cpu(self, *modules: torch.nn.Module):
        moved = [m.cpu() for m in modules]
        return moved[0] if len(moved) == 1 else moved

    def to_device(self, *modules: torch.nn.Module):
        # FIXME handling dtype? Do we want separate dtype for data vs model?
        moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
        return moved[0] if len(moved) == 1 else moved

    def state_dict_to_cpu(self, state: Dict[str, Any]):
        cpu_state = state_dict_apply(state, apply_fn=lambda x: x.cpu())
        return cpu_state

    def state_dict_to_device(self, state: Dict[str, Any]):
        device_state = state_dict_apply(state, apply_fn=lambda x: x.to(self.device))
        return device_state

    def mark_step(self):
        pass  # NO-OP for non-XLA devices

    def synchronize(self, tensors: Optional[TensorList] = None):
        pass

    def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
        dist.all_reduce(tensor, op=op)
        if average:
            tensor.div_(self.world_size)
        return tensor

    def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
        reduce_tensor = tensor.clone()
        dist.all_reduce(reduce_tensor, op=op)
        if average:
            reduce_tensor = reduce_tensor / self.world_size
        return reduce_tensor

    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
        dist.all_gather(output_tensors, tensor)
        return torch.cat(output_tensors, cat_dim)

    def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
        input_tensors = torch.chunk(tensor, num_splits, split_dim)
        output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
        dist.all_to_all(output_tensors, input_tensors)
        return torch.cat(output_tensors, cat_dim)

    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        dist.broadcast(tensor, src=src_rank)
        return tensor

    def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
        if self.global_rank != src_rank:
            tensor = torch.empty_like(tensor)
        assert tensor is not None
        dist.broadcast(tensor, src=src_rank)
        return tensor

    def barrier(self):
        dist.barrier()


# Global device environment singleton instance
_global_device_env: Optional[DeviceEnv] = None


def is_global_device():
    return _global_device_env is not None


def get_global_device() -> DeviceEnv:
    if not is_global_device():
        raise RuntimeError('Please initialize device environment by calling initialize_device / set_global_device.')
    return _global_device_env


def set_global_device(device: DeviceEnv):
    global _global_device_env
    if _global_device_env is not None:
        raise RuntimeError('Global device is already set, it should NOT be set again.')
    _global_device_env = device
@@ -0,0 +1,68 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional

import torch
from torch.nn.parallel import DistributedDataParallel, DataParallel

from .device_env import DeviceEnv, DeviceEnvType, TensorList


def is_cuda_available():
    return torch.cuda.is_available()


@dataclass
class DeviceEnvCuda(DeviceEnv):

    def __post_init__(
            self,
            device_type: Optional[str],
            device_index: Optional[int],
            channels_last: bool,
    ):
        assert torch.cuda.device_count()
        torch.backends.cudnn.benchmark = True
        setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
        assert setup_world_size
        if setup_world_size > 1:
            # setup distributed
            assert device_index is None
            if self.local_rank is None:
                lr = os.environ.get('LOCAL_RANK', None)
                if lr is None:
                    raise RuntimeError(
                        'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
                self.local_rank = int(lr)
            self.device = torch.device('cuda:%d' % self.local_rank)
            torch.cuda.set_device(self.local_rank)
            torch.distributed.init_process_group(backend='nccl', init_method='env://')
            self.world_size = torch.distributed.get_world_size()
            assert self.world_size == setup_world_size
            self.global_rank = torch.distributed.get_rank()
        else:
            self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}')
            self.local_rank = 0
            self.world_size = 1
            self.global_rank = 0
        if self.autocast is None:
            self.autocast = torch.cuda.amp.autocast if self.amp else suppress
        if channels_last:
            self.memory_format = torch.channels_last

    @property
    def type(self) -> DeviceEnvType:
        return DeviceEnvType.CUDA

    def wrap_distributed(self, *modules, **kwargs):
        wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def wrap_parallel(self, *modules, **kwargs):
        assert not self.distributed
        wrapped = [DataParallel(m, **kwargs) for m in modules]
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def synchronize(self, tensors: Optional[TensorList] = None):
        torch.cuda.synchronize(self.device)
@@ -0,0 +1,36 @@
import logging

from .device_env import DeviceEnv, is_global_device, get_global_device, set_global_device
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available

_logger = logging.getLogger(__name__)


def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
    if is_global_device():
        return get_global_device()

    denv = None
    if not force_cpu:
        xla_device_type = kwargs.get('xla_device_type', None)
        if is_xla_available(xla_device_type):
            # XLA supports more than just TPU, will search in order TPU, GPU, CPU
            denv = DeviceEnvXla(**kwargs)
        elif is_cuda_available():
            denv = DeviceEnvCuda(**kwargs)

    # CPU fallback
    if denv is None:
        if is_xla_available('CPU'):
            denv = DeviceEnvXla(device_type='CPU', **kwargs)
        else:
            denv = DeviceEnv()

    _logger.info(f'Initialized device {denv.device}. '
                 f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.')
    print(denv)  # FIXME temporary print for debugging

    set_global_device(denv)
    return denv
@@ -0,0 +1,136 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional, Dict

import torch
from torch.distributed import ReduceOp

try:
    import torch_xla.core.xla_model as xm
    import torch_xla
    _HAS_XLA = True
except ImportError as e:
    xm = None
    torch_xla = None
    _HAS_XLA = False

try:
    # only the very latest XLA builds have AMP
    import torch_xla.amp as xa
except ImportError as e:
    xa = None

from .device_env import DeviceEnv, DeviceEnvType, TensorList


_PT_TO_XM_OP = {
    ReduceOp.SUM: 'sum',
    ReduceOp.PRODUCT: 'mul',
    ReduceOp.MIN: 'min',
    ReduceOp.MAX: 'max',
    ReduceOp.BAND: 'and',
    ReduceOp.BOR: 'or',
}


def is_xla_available(xla_device_type=None):
    if not _HAS_XLA:
        return False
    supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
    return len(supported_devs) >= 1


@dataclass
class DeviceEnvXla(DeviceEnv):

    def __post_init__(
            self,
            device_type: Optional[str],
            device_idx: Optional[int],
            channels_last: bool,
    ):
        if device_type is not None:
            device_type = device_type.upper()
            assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
        self.device = xm.xla_device(n=device_idx, devkind=device_type)
        self.world_size = xm.xrt_world_size()
        if self.distributed:
            assert device_idx is None, "device_index is based on local rank for distributed XLA mode"
            self.local_rank = xm.get_local_ordinal()
            self.global_rank = xm.get_ordinal()
        else:
            self.local_rank = 0
            self.global_rank = 0
        if self.amp:
            assert xa is not None, 'XLA AMP is not present on this build'
        if self.autocast is None:
            self.autocast = xa.autocast if self.amp else suppress
        if channels_last:
            self.memory_format = torch.channels_last

    @property
    def type(self) -> DeviceEnvType:
        return DeviceEnvType.XLA

    def wrap_distributed(self, *modules):
        wrapped = [m for m in modules]  # NO-OP
        return wrapped[0] if len(wrapped) == 1 else wrapped

    def wrap_parallel(self, *modules):
        assert False, "Not implemented"

    def mark_step(self):
        xm.mark_step()

    def synchronize(self, tensors: Optional[TensorList] = None):
        torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)

    def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
        assert isinstance(tensor, torch.Tensor)  # unlike in-place variant, lists/tuples not allowed
        op = _PT_TO_XM_OP[op]
        scale = 1.0 / self.world_size if average else 1.0
        return xm.all_reduce(op, tensor, scale=scale)

    def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
        op = _PT_TO_XM_OP[op]
        scale = 1.0 / self.world_size if average else 1.0
        wrapped = False
        if isinstance(tensor, torch.Tensor):
            tensor = [tensor]  # bare tensors are not operated on in-place
            wrapped = True
        xm.all_reduce(op, tensor, scale=scale)
        if wrapped:
            tensor = tensor[0]
        return tensor

    def all_gather(self, tensor: torch.Tensor, cat_dim=0):
        output = xm.all_gather(tensor, cat_dim)
        return output

    def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0):
        output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits)
        return output

    def broadcast(self, tensor: torch.Tensor, src_rank=0):
        # broadcast via zero-fill + sum all_reduce (no native xm broadcast used here)
        if self.global_rank != src_rank:
            tensor = torch.zeros_like(tensor)
        tensor = xm.all_reduce('sum', tensor)
        return tensor

    def broadcast_(self, tensor: torch.Tensor, src_rank=0):
        out_tensor = self.broadcast(tensor, src_rank)
        return tensor.copy_(out_tensor)

    def barrier(self):
        xm.rendezvous('timm.bits.dist_barrier')

    def state_dict_to_cpu(self, state: Dict[str, torch.Tensor]):
        cpu_state = xm._maybe_convert_to_cpu(state, convert=True)
        return cpu_state

    def state_dict_to_device(self, state: Dict[str, torch.Tensor]):
        device_state = xm.send_cpu_data_to_device(state, device=self.device)
        return device_state
@ -0,0 +1,150 @@
|
||||
from typing import Dict, Tuple, List, Union, Any, Callable
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from timm.utils import unwrap_model
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
|
||||
|
||||
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
|
||||
|
||||
|
||||
def _validate_type(tensor: TensorSeq):
|
||||
if isinstance(tensor, (dict, list, tuple)):
|
||||
if not tensor:
|
||||
return
|
||||
else:
|
||||
assert isinstance(tensor, torch.Tensor)
|
||||
|
||||
|
||||
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
# ensure every node has the same running bn stats
|
||||
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
|
||||
if ('running_mean' in bn_name) or ('running_var' in bn_name):
|
||||
if reduce:
|
||||
# average bn stats across whole group
|
||||
dev_env.all_reduce_(bn_buf, average=True)
|
||||
else:
|
||||
# broadcast bn stats from rank 0 to whole group
|
||||
dev_env.broadcast_(bn_buf, 0)
|
||||
|
||||
|
||||
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
|
||||
""" Recursive all gather via DeviceEnv distributed primitives
|
||||
FIXME add group support
|
||||
"""
|
||||
_validate_type(tensor)
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return dev_env.all_gather(tensor, cat_dim=cat_dim)
|
||||
elif isinstance(tensor, dict):
|
||||
return {k: all_gather_recursive(v, dev_env=dev_env) for k, v in tensor.items()}
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(all_gather_recursive(v, dev_env=dev_env) for v in tensor)
|
||||
|
||||
|
||||
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
|
||||
""" Recursive all reduce via DeviceEnv distributed primitives
|
||||
FIXME add group support
|
||||
"""
|
||||
_validate_type(tensor)
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return dev_env.all_reduce_(tensor, op=op, average=average)
|
||||
elif isinstance(tensor, dict):
|
||||
return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()}
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor)
|
||||
|
||||
|
||||
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
|
||||
""" Recursive broadcast via DeviceEnv distributed primitives
|
||||
FIXME add group support
|
||||
"""
|
||||
_validate_type(tensor)
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return dev_env.broadcast_(tensor, src_rank=src_rank)
|
||||
elif isinstance(tensor, dict):
|
||||
return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()}
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor)
|
||||
|
||||
|
||||
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
|
||||
""" All gather a flat Tensor sequence (dict, list, tuple) of same shape
|
||||
|
||||
"""
|
||||
_validate_type(tensor)
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
|
||||
with torch.no_grad():
|
||||
names = None
|
||||
# merge values into one tensor for reduction
|
||||
if isinstance(tensor, dict):
|
||||
names = tensor.keys()
|
||||
gather_values = tuple(tensor.values())
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
gather_values = tensor
|
||||
else:
|
||||
gather_values = (tensor,)
|
||||
|
||||
gather_values = torch.stack(gather_values, dim=0)
|
||||
gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)
|
||||
|
||||
# separate reduced values into original structure
|
||||
if isinstance(tensor, dict):
|
||||
gather_values = {k: v for k, v in zip(names, gather_values)}
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
gather_values = type(tensor)(v for v in gather_values)
|
||||
else:
|
||||
gather_values = gather_values[0]
|
||||
|
||||
return gather_values
|
||||
|
||||
|
||||
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
|
||||
"""
|
||||
All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape
|
||||
|
||||
Args:
|
||||
tensor (dict): inputs to be reduced. All the values must be scalar Tensor.
|
||||
average (bool): whether to do average or sum
|
||||
Returns:
|
||||
a sequence with the same type as input (dict, list, tuple)
|
||||
"""
|
||||
_validate_type(tensor)
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
|
||||
with torch.no_grad():
|
||||
names = None
|
||||
# merge values into one tensor for reduction
|
||||
if isinstance(tensor, dict):
|
||||
names = tensor.keys()
|
||||
reduce_values = tuple(tensor.values())
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
reduce_values = tensor
|
||||
else:
|
||||
reduce_values = (tensor,)
|
||||
|
||||
reduce_values = torch.stack(reduce_values, dim=0)
|
||||
dev_env.all_reduce_(reduce_values, op=op, average=average)
|
||||
reduce_values = reduce_values.unbind(dim=0)
|
||||
# separate reduced values into original structure
|
||||
if isinstance(tensor, dict):
|
||||
reduce_values = {k: v for k, v in zip(names, reduce_values)}
|
||||
elif isinstance(tensor, (tuple, list)):
|
||||
reduce_values = type(tensor)(v for v in reduce_values)
|
||||
else:
|
||||
reduce_values = reduce_values[0]
|
||||
|
||||
return reduce_values
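# Usage sketch (illustrative, not part of this diff): all_reduce_sequence keeps the container
# type, so a dict of same-shape tensors goes in and a dict of reduced tensors comes out.
# The metric names below are hypothetical.
def _example_reduce_metrics(loss_sum: torch.Tensor, correct: torch.Tensor):
    stats = {'loss_sum': loss_sum.detach(), 'correct': correct.detach()}
    # SUM across ranks, then divide by world size
    return all_reduce_sequence(stats, op=ReduceOp.SUM, average=True)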
|
@ -0,0 +1,190 @@
|
||||
""" PyTorch distributed helpers
|
||||
|
||||
Some of this lifted from Detectron2 with other fns added by myself.
|
||||
|
||||
FIXME many functions remain unfinished/untested
|
||||
"""
|
||||
from typing import Dict, Tuple, List, Union, Any, Callable
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
|
||||
|
||||
|
||||
def synchronize_torch():
|
||||
"""
|
||||
Helper function to synchronize (barrier) among all processes when
|
||||
using distributed training
|
||||
"""
|
||||
if not dist.is_available():
|
||||
return
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
world_size = dist.get_world_size()
|
||||
if world_size == 1:
|
||||
return
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
|
||||
"""
|
||||
All reduce the tensors in a sequence (dict, list, tuple)
|
||||
|
||||
Args:
|
||||
values (dict, list, tuple, or Tensor): inputs to be reduced. All values must be Tensors of the same shape.
|
||||
average (bool): whether to do average or sum
|
||||
Returns:
|
||||
a sequence with the same type as input (dict, list, tuple)
|
||||
"""
|
||||
world_size = dist.get_world_size(group)
|
||||
if world_size <= 1:
|
||||
return values
|
||||
|
||||
with torch.no_grad():
|
||||
names = None
|
||||
if isinstance(values, dict):
|
||||
names = values.keys()
|
||||
reduce_values = torch.stack(tuple(values.values()), dim=0)
|
||||
elif isinstance(values, (tuple, list)):
|
||||
reduce_values = torch.stack(values, dim=0)
|
||||
else:
|
||||
reduce_values = values
|
||||
dist.all_reduce(reduce_values, op=op, group=group)
|
||||
if average:
|
||||
reduce_values /= world_size
|
||||
if isinstance(values, dict):
|
||||
reduce_values = {k: v for k, v in zip(names, reduce_values)}
|
||||
elif isinstance(values, (tuple, list)):
|
||||
reduce_values = type(values)(v for v in reduce_values)
|
||||
return reduce_values
|
||||
|
||||
|
||||
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
|
||||
"""
|
||||
Reduce the tensors in a sequence (dict, list, tuple) onto dst_rank
|
||||
|
||||
Args:
|
||||
values (dict, list, tuple, or Tensor): inputs to be reduced. All values must be Tensors of the same shape.
|
||||
average (bool): whether to do average or sum
|
||||
Returns:
|
||||
a sequence with the same type as input (dict, list, tuple)
|
||||
"""
|
||||
world_size = dist.get_world_size(group)
|
||||
this_rank = dist.get_rank()
|
||||
if world_size <= 1:
|
||||
return values
|
||||
|
||||
with torch.no_grad():
|
||||
names = None
|
||||
if isinstance(values, dict):
|
||||
names = values.keys()
|
||||
reduce_values = torch.stack(tuple(values.values()), dim=0)
|
||||
elif isinstance(values, (tuple, list)):
|
||||
reduce_values = torch.stack(values, dim=0)
|
||||
else:
|
||||
reduce_values = values
|
||||
|
||||
dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
|
||||
if average and this_rank == dst_rank:
|
||||
reduce_values /= world_size
|
||||
if isinstance(values, dict):
|
||||
reduce_values = {k: v for k, v in zip(names, reduce_values)}
|
||||
elif isinstance(values, (tuple, list)):
|
||||
reduce_values = type(values)(v for v in reduce_values)
|
||||
return reduce_values
|
||||
|
||||
|
||||
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
|
||||
world_size = dist.get_world_size(group)
|
||||
|
||||
def _do_gather(tensor):
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
dist.all_gather(tensor_list, tensor, group=group)
|
||||
return join_fn(tensor_list, dim=join_dim)
|
||||
|
||||
if isinstance(values, dict):
|
||||
gathered = {k: _do_gather(v) for k, v in values.items()}
|
||||
return gathered
|
||||
elif isinstance(values, (list, tuple)):
|
||||
gathered = type(values)(_do_gather(v) for v in values)
|
||||
return gathered
|
||||
else:
|
||||
# if not a dict, list, tuple, expect a singular tensor
|
||||
assert isinstance(values, torch.Tensor)
|
||||
return _do_gather(values)
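# Usage sketch (illustrative, not part of this diff): gather per-rank eval outputs into
# full-size tensors on every rank; join_fn=torch.cat concatenates along the batch dim.
# The variable names are hypothetical.
def _example_gather_outputs(preds: torch.Tensor, targets: torch.Tensor):
    outputs = {'preds': preds, 'targets': targets}
    return all_gather_sequence_torch(outputs, join_fn=torch.cat, join_dim=0)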
|
||||
|
||||
|
||||
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
|
||||
world_size = dist.get_world_size(group)
|
||||
this_rank = dist.get_rank(group)
|
||||
|
||||
def _do_gather(tensor):
|
||||
tensor_list = None
|
||||
if this_rank == dst_rank:
|
||||
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
|
||||
dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
|
||||
return join_fn(tensor_list, dim=join_dim)
|
||||
|
||||
if isinstance(values, dict):
|
||||
gathered = {k: _do_gather(v) for k, v in values.items()}
|
||||
return gathered
|
||||
elif isinstance(values, (list, tuple)):
|
||||
gathered = type(values)(_do_gather(v) for v in values)
|
||||
return gathered
|
||||
else:
|
||||
# if not a dict, list, tuple, expect a singular tensor
|
||||
assert isinstance(values, torch.Tensor)
|
||||
return _do_gather(values)
|
||||
|
||||
|
||||
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
|
||||
if isinstance(value, torch.Tensor):
|
||||
world_size = dist.get_world_size(group)
|
||||
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
|
||||
dist.all_gather(out_tensors, value, group=group)
|
||||
if join_fn is not None:
|
||||
out_tensors = join_fn(out_tensors, dim=join_dim)
|
||||
return out_tensors
|
||||
elif isinstance(value, dict):
|
||||
return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()}
|
||||
elif isinstance(value, (tuple, list)):
|
||||
return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value)
|
||||
|
||||
|
||||
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
|
||||
if isinstance(value, torch.Tensor):
|
||||
world_size = dist.get_world_size(group)
|
||||
this_rank = dist.get_rank()
|
||||
out_tensors = None
|
||||
if this_rank == dst_rank:
|
||||
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
|
||||
dist.gather(value, out_tensors, dst=dst_rank, group=group)
|
||||
if join_fn is not None:
|
||||
out_tensors = join_fn(out_tensors, dim=join_dim)
|
||||
return out_tensors
|
||||
elif isinstance(value, dict):
|
||||
return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
|
||||
elif isinstance(value, (tuple, list)):
|
||||
return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
|
||||
|
||||
|
||||
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
|
||||
if isinstance(value, torch.Tensor):
|
||||
dist.all_reduce(value, op=op, group=group)
|
||||
if average:
|
||||
value /= dist.get_world_size(group)
|
||||
elif isinstance(value, dict):
|
||||
return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
|
||||
elif isinstance(value, (tuple, list)):
|
||||
return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
|
||||
|
||||
|
||||
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
|
||||
if isinstance(value, torch.Tensor):
|
||||
return dist.broadcast(value, src=src_rank, group=group)
|
||||
elif isinstance(value, dict):
|
||||
return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
|
||||
elif isinstance(value, (tuple, list)):
|
||||
return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)
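# Usage sketch (illustrative, not part of this diff): broadcast_torch recurses through
# dict / list / tuple containers and broadcasts each leaf tensor in-place from src_rank.
# The state dict below is hypothetical.
def _example_broadcast_state(state: Dict[str, torch.Tensor]):
    broadcast_torch(state, src_rank=0)
    return state  # every rank now holds rank 0's values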
|
@ -0,0 +1,26 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
|
||||
from timm.utils.agc import adaptive_clip_grad
|
||||
|
||||
|
||||
def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
|
||||
if mode == 'norm':
|
||||
return partial(torch.nn.utils.clip_grad_norm_, norm_type=norm_type)
|
||||
elif mode == 'value':
|
||||
return torch.nn.utils.clip_grad_value_
|
||||
elif mode == 'agc':
|
||||
return partial(adaptive_clip_grad, norm_type=norm_type)
|
||||
else:
|
||||
assert False, f"Unknown clip mode ({mode})."
|
||||
|
||||
|
||||
def get_clip_parameters(model, skip_last=0):
|
||||
if hasattr(model, 'get_clip_parameters'):
|
||||
return model.get_clip_parameters()
|
||||
else:
|
||||
if skip_last:
|
||||
return list(model.parameters())[:-skip_last]
|
||||
else:
|
||||
return model.parameters()
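# Usage sketch (illustrative, not part of this diff): combining the two helpers for adaptive
# gradient clipping, skipping the last parameter tensors (typically the classifier) as is
# conventional for AGC. The clip_value below is a hypothetical choice.
def _example_clip(model, clip_value: float = 0.01):
    clip_fn = get_clip_grad_fn('agc')
    clip_fn(get_clip_parameters(model, skip_last=2), clip_value)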
|
@ -0,0 +1,145 @@
|
||||
import abc
|
||||
from typing import Callable, Union, Optional, List, Tuple, Dict
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
from .distributed import all_gather_sequence, all_reduce_sequence
|
||||
|
||||
MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValueInfo:
|
||||
initial: Optional[MetricValueT] = 0.
|
||||
dtype: torch.dtype = torch.float32
|
||||
dist_reduce: str = 'sum'
|
||||
dist_average: bool = False
|
||||
|
||||
|
||||
class Metric(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dev_env: DeviceEnv = None
|
||||
):
|
||||
self._infos: Dict[str, ValueInfo] = {}
|
||||
self._values: Dict[str, Optional[MetricValueT]] = {}
|
||||
self._values_dist: Dict[str, Optional[MetricValueT]] = {}
|
||||
if dev_env is None:
|
||||
dev_env = DeviceEnv.instance()
|
||||
self._dev_env = dev_env
|
||||
|
||||
def _register_value(self, name: str, info: Optional[ValueInfo] = None):
|
||||
info = info or ValueInfo()
|
||||
self._infos[name] = info
|
||||
|
||||
# def get_value(self, name: str, use_dist=True):
|
||||
# if use_dist:
|
||||
# return self._values_dist.get(name, self._values.get(name))
|
||||
# else:
|
||||
# return self._values.get(name)
|
||||
|
||||
def __getattr__(self, item):
|
||||
if item not in self._infos:
|
||||
raise AttributeError
|
||||
value = self._values_dist.get(item, self._values.get(item, None))
|
||||
return value
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if '_infos' in self.__dict__ and key in self._infos:
|
||||
self._values[key] = value
|
||||
else:
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def update(
|
||||
self,
|
||||
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
|
||||
self._update(predictions, target)
|
||||
|
||||
def _update(
|
||||
self,
|
||||
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
|
||||
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
self._values = {}
|
||||
self._values_dist = {}
|
||||
for name, info in self._infos.items():
|
||||
# if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
|
||||
if info.initial is not None:
|
||||
if isinstance(info.initial, torch.Tensor):
|
||||
tensor = info.initial.detach().clone()
|
||||
else:
|
||||
tensor = torch.ones([], dtype=info.dtype) * info.initial # scalar
|
||||
self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
|
||||
else:
|
||||
self._values[name] = None
|
||||
self._reset()
|
||||
|
||||
def _reset(self):
|
||||
pass
|
||||
|
||||
def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
|
||||
if self._dev_env.distributed:
|
||||
self._distribute_values()
|
||||
results = self._compute()
|
||||
self._values_dist = {}
|
||||
return results
|
||||
|
||||
@abc.abstractmethod
|
||||
def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
|
||||
pass
|
||||
|
||||
def _distribute_values(self):
|
||||
if not self._infos or not self._values:
|
||||
return
|
||||
|
||||
def _args(op: str):
|
||||
if op == 'cat':
|
||||
return True, dict(cat_dim=0)
|
||||
else:
|
||||
return False, dict(op=ReduceOp.SUM)
|
||||
|
||||
prev_dsr = None
|
||||
same_dsr = True
|
||||
names = []
|
||||
values = []
|
||||
reductions = []
|
||||
for name, value in self._values.items():
|
||||
if value is not None:
|
||||
info = self._infos[name]
|
||||
dsr = (value.dtype, value.shape, info.dist_reduce)
|
||||
if prev_dsr is not None and prev_dsr != dsr:
|
||||
same_dsr = False
|
||||
prev_dsr = dsr
|
||||
names.append(name)
|
||||
values.append(value)
|
||||
reductions.append(_args(info.dist_reduce))
|
||||
|
||||
if same_dsr:
|
||||
do_gather, reduce_kwargs = reductions[0]
|
||||
if do_gather:
|
||||
reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
|
||||
else:
|
||||
reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
|
||||
for name, reduced_value in zip(names, reduced_values):
|
||||
info = self._infos[name]
|
||||
if info.dist_average:
|
||||
reduced_value /= self._dev_env.world_size
|
||||
self._values_dist[name] = reduced_value
|
||||
else:
|
||||
for n, v, r in zip(names, values, reductions):
|
||||
info = self._infos[n]
|
||||
do_gather, reduce_kwargs = r
|
||||
if do_gather:
|
||||
reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
|
||||
else:
|
||||
reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
|
||||
if info.dist_average:
|
||||
reduced_value /= self._dev_env.world_size
|
||||
self._values_dist[n] = reduced_value
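# Usage sketch (illustrative, not part of this diff): a minimal Metric subclass that
# registers two accumulators, updates them per batch, and finalizes in _compute(). The
# class name and the interpretation of `predictions` as a per-sample loss are hypothetical.
class _AvgLossExample(Metric):
    def __init__(self, dev_env: DeviceEnv = None):
        super().__init__(dev_env=dev_env)
        self._register_value('loss_sum', ValueInfo())
        self._register_value('count', ValueInfo())
        self.reset()

    def _update(self, predictions, target):
        self.loss_sum = self.loss_sum + predictions.detach().sum()
        self.count = self.count + predictions.numel()

    def _compute(self):
        return self.loss_sum / self.count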
|
@ -0,0 +1,71 @@
|
||||
import torch
|
||||
from typing import Optional, Tuple, Dict
|
||||
|
||||
from .device_env import DeviceEnv
|
||||
from .metric import Metric, ValueInfo
|
||||
|
||||
|
||||
class Accuracy(Metric):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
threshold=0.5,
|
||||
multi_label=False,
|
||||
accumulate_dtype=torch.float32,
|
||||
dev_env=None,
|
||||
):
|
||||
super().__init__(dev_env=dev_env)
|
||||
self.accumulate_dtype = accumulate_dtype
|
||||
self.threshold = threshold
|
||||
self.eps = 1e-8
|
||||
self.multi_label = multi_label
|
||||
|
||||
# statistics / counts
|
||||
self._register_value('correct', ValueInfo(dtype=accumulate_dtype))
|
||||
self._register_value('total', ValueInfo(dtype=accumulate_dtype))
|
||||
|
||||
def _update(self, predictions, target):
|
||||
raise NotImplementedError()
|
||||
|
||||
def _compute(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class AccuracyTopK(Metric):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
topk=(1, 5),
|
||||
accumulate_dtype=torch.float32,
|
||||
dev_env: DeviceEnv = None
|
||||
):
|
||||
super().__init__(dev_env=dev_env)
|
||||
self.accumulate_dtype = accumulate_dtype
|
||||
self.eps = 1e-8
|
||||
self.topk = topk
|
||||
self.maxk = max(topk)
|
||||
|
||||
# statistics / counts
|
||||
for k in self.topk:
|
||||
self._register_value(f'top{k}', ValueInfo(dtype=accumulate_dtype))
|
||||
self._register_value('total', ValueInfo(dtype=accumulate_dtype))
|
||||
self.reset()
|
||||
|
||||
def _update(self, predictions: torch.Tensor, target: torch.Tensor):
|
||||
batch_size = predictions.shape[0]
|
||||
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
|
||||
target_reshape = target.reshape(-1, 1).expand_as(sorted_indices)
|
||||
correct = sorted_indices.eq(target_reshape).to(dtype=self.accumulate_dtype).sum(0)
|
||||
for k in self.topk:
|
||||
attr_name = f'top{k}'
|
||||
correct_at_k = correct[:k].sum()
|
||||
setattr(self, attr_name, getattr(self, attr_name) + correct_at_k)
|
||||
self.total += batch_size
|
||||
|
||||
def _compute(self) -> Dict[str, torch.Tensor]:
|
||||
assert self.total is not None
|
||||
output = {}
|
||||
for k in self.topk:
|
||||
attr_name = f'top{k}'
|
||||
output[attr_name] = 100 * getattr(self, attr_name) / self.total
|
||||
return output
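# Usage sketch (illustrative, not part of this diff): update per eval batch, call compute()
# once at the end; compute() performs the distributed reduction via the DeviceEnv. The
# eval_batches iterable is hypothetical.
def _example_eval_accuracy(eval_batches, dev_env: DeviceEnv = None):
    acc = AccuracyTopK(topk=(1, 5), dev_env=dev_env)
    for output, target in eval_batches:
        acc.update(output, target)
    return acc.compute()  # e.g. {'top1': ..., 'top5': ...} as percentages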
|
@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class PrecisionRecall:
|
||||
|
||||
def __init__(self, threshold=0.5, multi_label=False, device=None):
|
||||
self.threshold = threshold
|
||||
self.device = device
|
||||
self.multi_label = multi_label
|
||||
|
||||
# statistics
|
||||
|
||||
# the total number of true positive instances under each class
|
||||
# Shape: (num_classes, )
|
||||
self._tp_sum = None
|
||||
|
||||
# the total number of instances
|
||||
# Shape: (num_classes, )
|
||||
self._total_sum = None
|
||||
|
||||
# the total number of instances under each _predicted_ class,
|
||||
# including true positives and false positives
|
||||
# Shape: (num_classes, )
|
||||
self._pred_sum = None
|
||||
|
||||
# the total number of instances under each _true_ class,
|
||||
# including true positives and false negatives
|
||||
# Shape: (num_classes, )
|
||||
self._true_sum = None
|
||||
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self._tp_sum = None
|
||||
self._total_sum = None
|
||||
self._pred_sum = None
|
||||
self._true_sum = None
|
||||
|
||||
def update(self, predictions, target):
|
||||
output_type = predictions.type()
|
||||
num_classes = predictions.size(-1)
|
||||
if self.multi_label:
|
||||
if self.threshold is not None:
|
||||
predictions = (predictions > self.threshold).type(output_type)
|
||||
predictions = predictions.t().reshape(num_classes, -1)
|
||||
target = target.t().reshape(num_classes, -1)
|
||||
else:
|
||||
target = F.one_hot(target.view(-1), num_classes=num_classes)
|
||||
indices = torch.argmax(predictions, dim=1).view(-1)
|
||||
predictions = F.one_hot(indices, num_classes=num_classes)
|
||||
# FIXME make sure binary case works
|
||||
|
||||
target = target.type(output_type)
|
||||
correct = (target * predictions > 0).type(output_type)
|
||||
pred_positives = predictions.sum(dim=0)
|
||||
target_positives = target.sum(dim=0)
|
||||
if correct.sum() == 0:
|
||||
true_positives = torch.zeros_like(pred_positives)
|
||||
else:
|
||||
true_positives = correct.sum(dim=0)
|
||||
|
||||
if self._tp_sum is None:
|
||||
self._tp_sum = torch.zeros(num_classes, device=self.device)
|
||||
self._true_sum = torch.zeros(num_classes, device=self.device)
|
||||
self._pred_sum = torch.zeros(num_classes, device=self.device)
|
||||
self._total_sum = torch.tensor(0, device=self.device)
|
||||
|
||||
self._tp_sum += true_positives
|
||||
self._pred_sum += pred_positives
|
||||
self._true_sum += target_positives
|
||||
self._total_sum += target.shape[0]
|
||||
|
||||
def counts_as_tuple(self, reduce=False):
|
||||
tp_sum = self._tp_sum
|
||||
pred_sum = self._pred_sum
|
||||
true_sum = self._true_sum
|
||||
total_sum = self._total_sum
|
||||
if reduce:
|
||||
tp_sum = reduce_tensor_sum(tp_sum)
|
||||
pred_sum = reduce_tensor_sum(pred_sum)
|
||||
true_sum = reduce_tensor_sum(true_sum)
|
||||
total_sum = reduce_tensor_sum(total_sum)
|
||||
return tp_sum, pred_sum, true_sum, total_sum
|
||||
|
||||
def counts(self, reduce=False):
|
||||
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
|
||||
return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)
|
||||
|
||||
def confusion(self, reduce=False):
|
||||
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
|
||||
fp = pred_sum - tp_sum
|
||||
fn = true_sum - tp_sum
|
||||
tp = tp_sum
|
||||
tn = total_sum - tp - fp - fn
|
||||
return dict(tp=tp, fp=fp, fn=fn, tn=tn)
|
||||
|
||||
def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
|
||||
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
|
||||
if average == 'micro':
|
||||
tp_sum = tp_sum.sum()
|
||||
pred_sum = pred_sum.sum()
|
||||
true_sum = true_sum.sum()
|
||||
|
||||
precision = tp_sum / pred_sum
|
||||
recall = tp_sum / true_sum
|
||||
beta_sq = fscore_beta ** 2
|
||||
f1_denom = beta_sq * precision + recall
|
||||
fscore = (1 + beta_sq) * precision * recall / f1_denom
|
||||
|
||||
if average == 'macro' and not no_reduce:
|
||||
precision = precision.mean()
|
||||
recall = recall.mean()
|
||||
fscore = fscore.mean()
|
||||
return dict(fscore=fscore, precision=precision, recall=recall)
|
||||
|
||||
return dict(fscore=fscore, precision=precision, recall=recall)
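# Usage sketch (illustrative, not part of this diff): single-label multi-class use, logits
# in, argmax handled inside update(), then micro-averaged precision / recall / F-score.
# The eval_batches iterable is hypothetical.
def _example_precision_recall(eval_batches):
    pr = PrecisionRecall(multi_label=False)
    for logits, labels in eval_batches:
        pr.update(logits, labels)
    return pr.compute(average='micro')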
|
@ -0,0 +1,238 @@
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Tuple, Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
HAS_TB = True
|
||||
except ImportError as e:
|
||||
HAS_TB = False
|
||||
|
||||
try:
|
||||
import wandb
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
|
||||
# FIXME old formatting for reference, to remove
|
||||
#
|
||||
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
|
||||
# log_name = 'Test' + log_suffix
|
||||
# logging.info(
|
||||
# f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
|
||||
# f'Time: {batch_time.smooth_val:.3f} ({batch_time.avg:.3f}) '
|
||||
# f'Loss: {loss.smooth_val:>7.4f} ({loss.avg:>6.4f}) '
|
||||
# f'Acc@1: {top1.smooth_val:>7.4f} ({top1.avg:>7.4f}) '
|
||||
# f'Acc@5: {top5.smooth_val:>7.4f} ({top5.avg:>7.4f})'
|
||||
# )
|
||||
#
|
||||
#
|
||||
# def log_train(epoch, step, num_steps, loss, batch_size, batch_time, data_time, lr, world_size=1):
|
||||
# last_step = max(0, num_steps - 1)
|
||||
# progress = 100. * step / last_step if last_step else 0.
|
||||
# log_str = f'Train: {epoch} [{step:>4d}/{num_steps} ({progress:>3.0f}%)]' \
|
||||
# f' Time: {batch_time.smooth_val:.3f}s, {batch_size * world_size / batch_time.smooth_val:>7.2f}/s' \
|
||||
# f' ({batch_time.avg:.3f}s, {batch_size * world_size / batch_time.avg:>7.2f}/s)' \
|
||||
# f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
|
||||
# log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
|
||||
# log_str += f' LR: {lr:.3e} '
|
||||
|
||||
|
||||
def summary_row_dict(results, index=None, index_name='epoch'):
|
||||
assert isinstance(results, dict)
|
||||
row_dict = OrderedDict()
|
||||
if index is not None:
|
||||
row_dict[index_name] = index
|
||||
if not results:
|
||||
return row_dict
|
||||
if isinstance(next(iter(results.values())), dict):
|
||||
# each key in results is a per-phase results dict, flatten by prefixing with phase name
|
||||
for p, pr in results.items():
|
||||
assert isinstance(pr, dict)
|
||||
row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
|
||||
else:
|
||||
row_dict.update(results)
|
||||
return row_dict
|
||||
|
||||
|
||||
class SummaryCsv:
|
||||
def __init__(self, output_dir, filename='summary.csv'):
|
||||
self.output_dir = output_dir
|
||||
self.filename = os.path.join(output_dir, filename)
|
||||
self.needs_header = not os.path.exists(self.filename)
|
||||
|
||||
def update(self, row_dict):
|
||||
with open(self.filename, mode='a') as cf:
|
||||
dw = csv.DictWriter(cf, fieldnames=row_dict.keys())
|
||||
if self.needs_header: # first iteration (epoch == 1 can't be used)
|
||||
dw.writeheader()
|
||||
self.needs_header = False
|
||||
dw.writerow(row_dict)
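# Usage sketch (illustrative, not part of this diff): summary_row_dict flattens per-phase
# result dicts into one row, prefixing keys with the phase name (e.g. 'train_loss',
# 'eval_top1'). The directory and metric values below are hypothetical.
def _example_write_csv(output_dir, epoch):
    writer = SummaryCsv(output_dir)
    results = {'train': {'loss': 2.31}, 'eval': {'loss': 2.05, 'top1': 41.2}}
    writer.update(summary_row_dict(results, index=epoch, index_name='epoch'))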
|
||||
|
||||
|
||||
_sci_keys = {'lr'}
|
||||
|
||||
|
||||
def _add_kwargs(text_update, name_map=None, **kwargs):
|
||||
def _to_str(key, val):
|
||||
if isinstance(val, float):
|
||||
if key.lower() in _sci_keys:
|
||||
return f'{key}: {val:.3e} '
|
||||
else:
|
||||
return f'{key}: {val:.4f}'
|
||||
else:
|
||||
return f'{key}: {val}'
|
||||
|
||||
def _map_name(key, name_map, capitalize=True):
|
||||
if name_map is None:
|
||||
if capitalize:
|
||||
return key.capitalize() if not key.isupper() else key
|
||||
else:
|
||||
return key
|
||||
return name_map.get(key, None)
|
||||
|
||||
for k, v in kwargs.items():
|
||||
if isinstance(v, dict):
|
||||
# log each k, v of a dict kwarg as separate items
|
||||
for kk, vv in v.items():
|
||||
name = _map_name(kk, name_map)
|
||||
if not name:
|
||||
continue
|
||||
text_update += [_to_str(name, vv)]
|
||||
else:
|
||||
name = _map_name(k, name_map, capitalize=True)
|
||||
if not name:
|
||||
continue
|
||||
text_update += [_to_str(name, v)]
|
||||
|
||||
|
||||
class Monitor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name=None,
|
||||
output_dir=None,
|
||||
logger=None,
|
||||
hparams=None,
|
||||
log_wandb=False,
|
||||
output_enabled=True,
|
||||
):
|
||||
self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging
|
||||
self.logger = logger or logging.getLogger('log')
|
||||
hparams = hparams or {}
|
||||
|
||||
# Setup CSV writer(s)
|
||||
if output_dir is not None:
|
||||
self.csv_writer = SummaryCsv(output_dir=output_dir)
|
||||
else:
|
||||
self.csv_writer = None
|
||||
|
||||
# Setup Tensorboard
|
||||
self.summary_writer = None # FIXME tensorboard
|
||||
|
||||
# Setup W&B
|
||||
self.wandb_run = None
|
||||
if log_wandb:
|
||||
if HAS_WANDB:
|
||||
self.wandb_run = wandb.init(project=experiment_name, config=hparams)
|
||||
else:
|
||||
_logger.warning("You've requested to log metrics to wandb but package not found. "
|
||||
"Metrics not being logged to wandb, try `pip install wandb`")
|
||||
|
||||
self.output_enabled = output_enabled
|
||||
# FIXME image save
|
||||
|
||||
def log_step(
|
||||
self,
|
||||
phase: str,
|
||||
step_idx: int,
|
||||
step_end_idx: Optional[int] = None,
|
||||
epoch: Optional[int] = None,
|
||||
loss: Optional[float] = None,
|
||||
rate: Optional[Union[float, Tuple[float, float]]] = None,
|
||||
phase_suffix: str = '',
|
||||
**kwargs,
|
||||
):
|
||||
""" log train/eval step
|
||||
"""
|
||||
if not self.output_enabled:
|
||||
return
|
||||
if 'num_steps' in kwargs:
|
||||
step_end_idx = max(0, kwargs.pop('num_steps') - 1)
|
||||
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
|
||||
progress = 100. * step_idx / step_end_idx if step_end_idx else 0.
|
||||
rate_str = ''
|
||||
if isinstance(rate, (tuple, list)):
|
||||
rate_str = f'Rate: {rate[0]:.2f}/s ({rate[1]:.2f}/s)'
|
||||
elif rate is not None:
|
||||
rate_str = f'Rate: {rate:.2f}/s'
|
||||
text_update = [
|
||||
phase_title,
|
||||
f'{epoch}' if epoch is not None else None,
|
||||
f'[{step_idx}]' if step_end_idx is None else None,
|
||||
f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]' if step_end_idx is not None else None,
|
||||
rate_str,
|
||||
f'Loss: {loss:.5f}' if loss is not None else None,
|
||||
]
|
||||
_add_kwargs(text_update, **kwargs)
|
||||
log_str = ' '.join(item for item in text_update if item)
|
||||
self.logger.info(log_str)
|
||||
if self.summary_writer is not None:
|
||||
# FIXME log step values to tensorboard
|
||||
pass
|
||||
|
||||
def log_phase(
|
||||
self,
|
||||
phase: str = 'eval',
|
||||
epoch: Optional[int] = None,
|
||||
name_map: Optional[dict] = None,
|
||||
**kwargs
|
||||
):
|
||||
"""log completion of evaluation or training phase
|
||||
"""
|
||||
if not self.output_enabled:
|
||||
return
|
||||
|
||||
title = [
|
||||
f'{phase.capitalize()}',
|
||||
f'epoch: {epoch}' if epoch is not None else None,
|
||||
'completed. ',
|
||||
]
|
||||
title_str = ' '.join(i for i in title if i)
|
||||
results = []
|
||||
_add_kwargs(results, name_map=name_map, **kwargs)
|
||||
log_str = title_str + ', '.join(item for item in results if item)
|
||||
self.logger.info(log_str)
|
||||
|
||||
def write_summary(
|
||||
self,
|
||||
results: Dict, # Dict or Dict of Dict where first level keys are treated as per-phase results
|
||||
index: Optional[Union[int, str]] = None,
|
||||
index_name: str = 'epoch',
|
||||
):
|
||||
""" Log complete results for all phases (typically called at end of epoch)
|
||||
|
||||
Args:
|
||||
results (dict or dict[dict]): dict of results to write, or multiple dicts where first level
|
||||
key is the name of results dict for each phase
|
||||
index: value for row index (typically epoch #)
|
||||
index_name: name for row index header (typically 'epoch')
|
||||
"""
|
||||
if not self.output_enabled:
|
||||
return
|
||||
|
||||
row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
|
||||
if self.csv_writer:
|
||||
self.csv_writer.update(row_dict)
|
||||
if self.wandb_run is not None:
|
||||
wandb.log(row_dict)
|
||||
if self.summary_writer:
|
||||
# FIXME log epoch summaries to tensorboard
|
||||
pass
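# Usage sketch (illustrative, not part of this diff): typical Monitor wiring with per-step
# logging, a phase-completion line, and a per-epoch summary row. All values are hypothetical.
def _example_monitor(output_dir):
    monitor = Monitor(experiment_name='demo', output_dir=output_dir)
    monitor.log_step('train', step_idx=10, step_end_idx=99, epoch=0, loss=1.23, rate=512.0)
    monitor.log_phase('train', epoch=0)
    monitor.write_summary({'train': {'loss': 1.23}}, index=0)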
|
@ -0,0 +1,59 @@
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from .avg_scalar import AvgMinMaxScalar
|
||||
|
||||
|
||||
class Tracker:
|
||||
|
||||
def __init__(self):
|
||||
self.data_time = AvgMinMaxScalar() # time for data loader to produce batch of samples
|
||||
self.step_time = AvgMinMaxScalar() # time for model step
|
||||
self.iter_time = AvgMinMaxScalar() # full iteration time incl. data, step, and book-keeping
|
||||
self.epoch_time = AvgMinMaxScalar()
|
||||
|
||||
self.iter_timestamp: Optional[float] = None
|
||||
self.prev_timestamp: Optional[float] = None
|
||||
self.epoch_timestamp: Optional[float] = None
|
||||
|
||||
def _measure_iter(self, ref_timestamp=None):
|
||||
timestamp = time.perf_counter()
|
||||
self.prev_timestamp = timestamp
|
||||
|
||||
def mark_iter(self):
|
||||
timestamp = time.perf_counter()
|
||||
if self.iter_timestamp is not None:
|
||||
iter_time = timestamp - self.iter_timestamp
|
||||
self.iter_time.update(iter_time)
|
||||
self.iter_timestamp = self.prev_timestamp = timestamp
|
||||
|
||||
def mark_iter_data_end(self):
|
||||
assert self.prev_timestamp is not None
|
||||
timestamp = time.perf_counter()
|
||||
data_time = timestamp - self.prev_timestamp
|
||||
self.data_time.update(data_time)
|
||||
self.prev_timestamp = timestamp
|
||||
|
||||
def mark_iter_step_end(self):
|
||||
assert self.prev_timestamp is not None
|
||||
timestamp = time.perf_counter()
|
||||
step_time = timestamp - self.prev_timestamp
|
||||
self.step_time.update(step_time)
|
||||
self.prev_timestamp = timestamp
|
||||
|
||||
def mark_epoch(self):
|
||||
timestamp = time.perf_counter()
|
||||
if self.epoch_timestamp is not None:
|
||||
epoch_time = timestamp - self.epoch_timestamp
|
||||
self.epoch_time.update(epoch_time)
|
||||
self.epoch_timestamp = timestamp
|
||||
|
||||
def get_avg_iter_rate(self, num_per_iter: int):
|
||||
if num_per_iter == 0 or self.iter_time.avg == 0:
|
||||
return 0
|
||||
return num_per_iter / self.iter_time.avg
|
||||
|
||||
def get_last_iter_rate(self, num_per_iter: int):
|
||||
if num_per_iter == 0 or self.iter_time.val == 0:
|
||||
return 0
|
||||
return num_per_iter / self.iter_time.val
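# Usage sketch (illustrative, not part of this diff): intended call order per iteration is
# mark_iter() at the top of the loop, mark_iter_data_end() once the batch is available, and
# mark_iter_step_end() after the optimizer step. The loader / step_fn are hypothetical.
def _example_timing(loader, step_fn, batch_size):
    tracker = Tracker()
    tracker.mark_iter()
    for sample, target in loader:
        tracker.mark_iter_data_end()
        step_fn(sample, target)
        tracker.mark_iter_step_end()
        tracker.mark_iter()
    return tracker.get_avg_iter_rate(batch_size)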
|
@ -0,0 +1,12 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainCfg:
|
||||
""" Train Loop Configuration
|
||||
Dataclass to hold training configuration values
|
||||
"""
|
||||
num_epochs: int = 100
|
||||
log_interval: int = 50
|
||||
recovery_interval: int = 0
|
||||
accumulate_steps: int = 0
|
@ -0,0 +1,13 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
from .monitor import Monitor
|
||||
from .checkpoint_manager import CheckpointManager
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainServices:
|
||||
""" Train Loop Services
|
||||
"""
|
||||
monitor: Monitor = None
|
||||
checkpoint: CheckpointManager = None
|
||||
|
@ -0,0 +1,151 @@
|
||||
import dataclasses
|
||||
from typing import Callable, Union, Optional
|
||||
import logging
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.models.layers import convert_sync_batchnorm
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.utils import ModelEmaV2
|
||||
|
||||
try:
|
||||
import deepspeed as ds
|
||||
except ImportError:
|
||||
ds = None
|
||||
|
||||
from .checkpoint import load_train_state
|
||||
from .device_env import DeviceEnv
|
||||
from .train_cfg import TrainCfg
|
||||
from .train_state import TrainState
|
||||
from .updater_factory import create_updater
|
||||
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_model_and_optimizer(
|
||||
dev_env: DeviceEnv,
|
||||
model: nn.Module,
|
||||
optimizer: Union[Callable, str],
|
||||
optimizer_cfg,
|
||||
clip_fn: Optional[Union[Callable, str]] = None,
|
||||
clip_value: Optional[float] = None,
|
||||
model_ema: bool = False,
|
||||
model_ema_decay: float = 0.9999,
|
||||
use_syncbn: bool = False,
|
||||
resume_path: str = '',
|
||||
resume_opt: bool = True,
|
||||
deepspeed: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
dev_env:
|
||||
model:
|
||||
optimizer:
|
||||
optimizer_cfg:
|
||||
clip_value:
|
||||
clip_fn:
|
||||
model_ema:
|
||||
model_ema_decay:
|
||||
use_syncbn:
|
||||
resume_path:
|
||||
resume_opt:
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if deepspeed:
|
||||
return setup_model_and_optimizer_deepspeed(
|
||||
dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
|
||||
clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
|
||||
resume_path=resume_path, resume_opt=resume_opt,
|
||||
)
|
||||
|
||||
dev_env.to_device(model)
|
||||
|
||||
if use_syncbn and dev_env.distributed:
|
||||
model = convert_sync_batchnorm(model)
|
||||
if dev_env.primary:
|
||||
_logger.info(
|
||||
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
||||
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
||||
|
||||
if isinstance(optimizer, Callable):
|
||||
# FIXME this interface needs to be figured out, model, model and/or parameters, or just parameters?
|
||||
optimizer = optimizer(model, **optimizer_cfg)
|
||||
else:
|
||||
optimizer = create_optimizer_v2(model, **optimizer_cfg)
|
||||
|
||||
updater = create_updater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
clip_fn=clip_fn,
|
||||
clip_value=clip_value,
|
||||
)
|
||||
|
||||
# ema model
|
||||
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
|
||||
|
||||
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
|
||||
|
||||
if resume_path:
|
||||
load_train_state(
|
||||
train_state,
|
||||
resume_path,
|
||||
load_opt=resume_opt,
|
||||
log_info=dev_env.primary)
|
||||
|
||||
if dev_env.distributed:
|
||||
train_state = dataclasses.replace(
|
||||
train_state, model=dev_env.wrap_distributed(train_state.model))
|
||||
|
||||
return train_state
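# Usage sketch (illustrative, not part of this diff): a typical call passing an optimizer
# config dict through to create_optimizer_v2; the returned TrainState holds the (possibly
# distributed-wrapped) model and the Updater. All hyper-parameter values are hypothetical.
def _example_setup(model: nn.Module):
    dev_env = DeviceEnv.instance()
    return setup_model_and_optimizer(
        dev_env=dev_env,
        model=model,
        optimizer='sgd',
        optimizer_cfg=dict(opt='sgd', lr=0.1, momentum=0.9, weight_decay=1e-4),
        clip_fn='norm',
        clip_value=5.0,
    )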
|
||||
|
||||
|
||||
def setup_model_and_optimizer_deepspeed(
|
||||
dev_env: DeviceEnv,
|
||||
model: nn.Module,
|
||||
optimizer: Union[Callable, str],
|
||||
optimizer_cfg,
|
||||
clip_fn: Optional[Union[Callable, str]] = None,
|
||||
clip_value: Optional[float] = None,
|
||||
model_ema: bool = False,
|
||||
model_ema_decay: float = 0.9999,
|
||||
use_syncbn: bool = False,
|
||||
resume_path: str = '',
|
||||
resume_opt: bool = True,
|
||||
):
|
||||
dev_env.to_device(model)
|
||||
|
||||
if isinstance(optimizer, Callable):
|
||||
optimizer = optimizer(model=model, **optimizer_cfg)
|
||||
else:
|
||||
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
|
||||
|
||||
model = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)
|
||||
|
||||
updater = create_updater(
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
clip_fn=clip_fn,
|
||||
clip_value=clip_value,
|
||||
deepspeed=True,
|
||||
)
|
||||
|
||||
# ema model
|
||||
# FIXME how to do EMA w/ deepspeed?
|
||||
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
|
||||
|
||||
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
|
||||
|
||||
if resume_path:
|
||||
# FIXME deepspeed resumes differently
|
||||
assert False
|
||||
|
||||
if dev_env.distributed:
|
||||
train_state = dataclasses.replace(
|
||||
train_state, model=dev_env.wrap_distributed(train_state.model))
|
||||
|
||||
return train_state
|
@ -0,0 +1,64 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
|
||||
from torch import nn as nn
|
||||
|
||||
from timm.scheduler import Scheduler
|
||||
from timm.utils import get_state_dict, unwrap_model
|
||||
|
||||
from .train_cfg import TrainCfg
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainState:
|
||||
model: nn.Module = None
|
||||
train_loss: nn.Module = None
|
||||
eval_loss: nn.Module = None
|
||||
updater: Updater = None
|
||||
lr_scheduler: Scheduler = None
|
||||
model_ema: nn.Module = None
|
||||
|
||||
train_cfg: TrainCfg = TrainCfg()
|
||||
# FIXME collect & include other cfg like data & model so it's in one spot for checkpoints / logging / debugging?
|
||||
|
||||
epoch: int = 0
|
||||
step_count: int = 0
|
||||
step_count_global: int = 0
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.model is not None
|
||||
assert self.updater is not None
|
||||
|
||||
def state_dict(self, unwrap_fn=unwrap_model):
|
||||
state = dict(
|
||||
# train loop state (counters, etc), saved and restored
|
||||
epoch=self.epoch,
|
||||
step_count=self.step_count,
|
||||
step_count_global=self.step_count_global,
|
||||
|
||||
# model params / state, saved and restored
|
||||
model=get_state_dict(self.model, unwrap_fn),
|
||||
model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
|
||||
|
||||
# configuration, saved but currently not restored, determined by args / config file for each run
|
||||
train_cfg=vars(self.train_cfg)
|
||||
)
|
||||
# FIXME include lr_scheduler state?
|
||||
state.update(self.updater.state_dict()) # updater (optimizer, scaler, etc.) state added to state
|
||||
return state
|
||||
|
||||
def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
|
||||
# restore train loop state
|
||||
self.epoch = state_dict['epoch'] + 1
|
||||
self.step_count = 0 # FIXME need more logic to restore part way through epoch
|
||||
self.step_count_global = state_dict['step_count_global']
|
||||
|
||||
# restore model params / state
|
||||
unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
|
||||
if 'model_ema' in state_dict and self.model_ema is not None:
|
||||
unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))
|
||||
|
||||
# restore optimizer state
|
||||
if load_opt:
|
||||
self.updater.load_state_dict(state_dict)
|
@ -0,0 +1,72 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from functools import partial
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .grad_clip import get_clip_grad_fn, get_clip_parameters
|
||||
|
||||
|
||||
@dataclass
|
||||
class Updater:
|
||||
|
||||
model: nn.Module = None
|
||||
optimizer: torch.optim.Optimizer = None # FIXME handle multiple optimizers per-model
|
||||
clip_fn: Optional[Union[Callable, str]] = None
|
||||
clip_value: Optional[float] = None
|
||||
clip_params_fn: Optional[Callable] = None
|
||||
grad_scaler: Optional[Callable] = None
|
||||
create_graph: Optional[bool] = None
|
||||
after_step_closure: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
assert self.model is not None
|
||||
assert self.optimizer is not None
|
||||
if self.clip_fn is not None:
|
||||
if isinstance(self.clip_fn, Callable):
|
||||
skip_last = 0
|
||||
else:
|
||||
assert isinstance(self.clip_fn, str)
|
||||
skip_last = 2 if 'agc' in self.clip_fn else 0
|
||||
self.clip_fn = get_clip_grad_fn(self.clip_fn)
|
||||
assert self.clip_value is not None
|
||||
self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
|
||||
if self.create_graph is None:
|
||||
self.create_graph = getattr(self.optimizer, 'second_order', False)
|
||||
self.after_step_closure = False
|
||||
|
||||
def reset(self):
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def apply(self, loss: torch.Tensor, accumulate=False):
|
||||
loss.backward(create_graph=self.create_graph)
|
||||
if accumulate:
|
||||
return
|
||||
if self.clip_fn is not None:
|
||||
self.clip_fn(self.clip_params_fn(), self.clip_value)
|
||||
self.optimizer.step()
|
||||
self.reset()
|
||||
|
||||
def get_average_lr(self):
|
||||
lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
|
||||
return sum(lrl) / len(lrl)
|
||||
|
||||
def state_dict(self):
|
||||
state_dict = dict(optimizer=self.optimizer.state_dict())
|
||||
if self.grad_scaler is not None:
|
||||
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
if 'optimizer' in state_dict:
|
||||
self.optimizer.load_state_dict(state_dict['optimizer'])
|
||||
if 'grad_scaler' in state_dict and self.grad_scaler is not None:
|
||||
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
|
||||
|
||||
def after_step(self, after_step_fn, *args):
|
||||
after_step_fn(*args)
|
||||
|
||||
@property
|
||||
def deepspeed(self):
|
||||
return False
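# Usage sketch (illustrative, not part of this diff): a minimal train step built around
# Updater.apply(), which runs backward and, unless accumulating, clip + step + zero_grad.
# The loss_fn and batch variables are hypothetical.
def _example_train_step(updater: Updater, model: nn.Module, loss_fn, sample, target):
    loss = loss_fn(model(sample), target)
    updater.apply(loss)
    return loss.detach()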
|
@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from typing import Dict, Any
|
||||
|
||||
import torch
|
||||
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdaterCudaWithScaler(Updater):
|
||||
|
||||
scaler_kwargs: InitVar[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
|
||||
super().__post_init__()
|
||||
scaler_kwargs = scaler_kwargs or {}
|
||||
self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
|
||||
|
||||
def apply(self, loss: torch.Tensor, accumulate=False):
|
||||
self.grad_scaler.scale(loss).backward(create_graph=self.create_graph)
|
||||
if accumulate:
|
||||
# unscale first?
|
||||
return
|
||||
if self.clip_fn is not None:
|
||||
# unscale the gradients of optimizer's assigned params in-place
|
||||
self.grad_scaler.unscale_(self.optimizer)
|
||||
self.clip_fn(self.clip_params_fn(), self.clip_value)
|
||||
self.grad_scaler.step(self.optimizer)
|
||||
self.grad_scaler.update()
|
||||
self.reset()
|
@ -0,0 +1,30 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
|
||||
import torch
|
||||
try:
|
||||
import deepspeed as ds
|
||||
except ImportError as e:
|
||||
ds = None
|
||||
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdaterDeepSpeed(Updater):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
# FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
|
||||
assert isinstance(self.model, ds.DeepSpeedEngine)
|
||||
|
||||
def reset(self):
|
||||
self.model.zero_grad()
|
||||
|
||||
def apply(self, loss: torch.Tensor, accumulate=False):
|
||||
self.model.backward(loss)
|
||||
self.model.step()
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def deepspeed(self):
|
||||
return True
|
@ -0,0 +1,38 @@
|
||||
from typing import Callable, Optional, Union, Any
|
||||
|
||||
import torch
|
||||
|
||||
from .device_env import DeviceEnv, DeviceEnvType
|
||||
from .updater import Updater
|
||||
from .updater_cuda import UpdaterCudaWithScaler
|
||||
from .updater_deepspeed import UpdaterDeepSpeed
|
||||
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
|
||||
|
||||
|
||||
def create_updater(
|
||||
model: torch.nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
clip_fn: Optional[Union[Callable, str]] = None,
|
||||
clip_value: Optional[float] = None,
|
||||
scaler_kwargs: Any = None,
|
||||
dev_env: Optional[DeviceEnv] = None,
|
||||
deepspeed: bool = False,
|
||||
) -> Updater:
|
||||
|
||||
if not dev_env:
|
||||
dev_env = DeviceEnv.instance()
|
||||
|
||||
updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
|
||||
use_scaler = dev_env.amp
|
||||
if use_scaler:
|
||||
updater_kwargs['scaler_kwargs'] = scaler_kwargs
|
||||
updater_cls = Updater
|
||||
if dev_env.type == DeviceEnvType.XLA:
|
||||
updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla
|
||||
elif dev_env.type == DeviceEnvType.CUDA and use_scaler:
|
||||
updater_cls = UpdaterCudaWithScaler
|
||||
elif deepspeed:
|
||||
updater_kwargs.pop('scaler_kwargs', None)
|
||||
updater_cls = UpdaterDeepSpeed
|
||||
|
||||
return updater_cls(**updater_kwargs)
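# Usage sketch (illustrative, not part of this diff): create_updater selects the Updater
# subclass from the active DeviceEnv (CUDA + AMP -> grad scaler variant, XLA -> XLA variants,
# deepspeed flag -> DeepSpeed). The optimizer construction below is hypothetical.
def _example_create_updater(model: torch.nn.Module):
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    return create_updater(model, optimizer, clip_fn='norm', clip_value=5.0)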
|
@ -0,0 +1,68 @@
|
||||
from dataclasses import dataclass, field, InitVar
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
import torch_xla.core.xla_model as xm
|
||||
_HAS_XLA = True
|
||||
except ImportError as e:
|
||||
xm = None
|
||||
_HAS_XLA = False
|
||||
|
||||
try:
|
||||
# only the very latest XLA builds have AMP
|
||||
import torch_xla.amp as xa
|
||||
except ImportError as e:
|
||||
xa = None
|
||||
|
||||
from .updater import Updater
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdaterXla(Updater):
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
self.after_step_closure = True
|
||||
|
||||
def apply(self, loss: torch.Tensor, accumulate: bool = False):
|
||||
loss.backward(create_graph=self.create_graph)
|
||||
if accumulate:
|
||||
return
|
||||
xm.reduce_gradients(self.optimizer)
|
||||
if self.clip_fn is not None:
|
||||
self.clip_fn(self.clip_params_fn(), self.clip_value)
|
||||
self.optimizer.step()
|
||||
xm.mark_step()
|
||||
self.reset()
|
||||
|
||||
def after_step(self, after_step_fn, *args):
|
||||
xm.add_step_closure(after_step_fn, args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdaterXlaWithScaler(UpdaterXla):
|
||||
|
||||
scaler_kwargs: InitVar[Dict[str, Any]] = None
|
||||
|
||||
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
|
||||
super().__post_init__()
|
||||
scaler_kwargs = scaler_kwargs or {}
|
||||
assert xa is not None, 'XLA AMP not present in this build'
|
||||
self.scaler = xa.GradScaler(**scaler_kwargs)
|
||||
|
||||
def apply(self, loss: torch.Tensor, accumulate: bool = False):
|
||||
self.scaler.scale(loss).backward(create_graph=self.create_graph)
|
||||
if accumulate:
|
||||
# unscale first?
|
||||
return
|
||||
xm.reduce_gradients(self.optimizer)
|
||||
if self.clip_fn is not None:
|
||||
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
|
||||
self.clip_fn(self.clip_params_fn(), self.clip_value)
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
xm.mark_step()
|
||||
self.reset()
|
@ -1,13 +1,13 @@
|
||||
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
|
||||
rand_augment_transform, auto_augment_transform
|
||||
from .config import resolve_data_config
|
||||
from .config import resolve_data_config, PreprocessCfg, AugCfg, MixupCfg
|
||||
from .constants import *
|
||||
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
|
||||
from .dataset_factory import create_dataset
|
||||
from .loader import create_loader
|
||||
from .loader import create_loader_v2
|
||||
from .mixup import Mixup, FastCollateMixup
|
||||
from .parsers import create_parser,\
|
||||
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
|
||||
from .real_labels import RealLabelsImagenet
|
||||
from .transforms import *
|
||||
from .transforms_factory import create_transform
|
||||
from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy
|
||||
from .transforms_factory import create_transform_v2, create_transform
|
||||
|
@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def fast_collate(batch):
|
||||
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
|
||||
assert isinstance(batch[0], tuple)
|
||||
batch_size = len(batch)
|
||||
if isinstance(batch[0][0], tuple):
|
||||
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
|
||||
# such that all tuple elements at position n end up in the nth chunk of torch.split(tensor, batch_size)
|
||||
inner_tuple_size = len(batch[0][0])
|
||||
flattened_batch_size = batch_size * inner_tuple_size
|
||||
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
|
||||
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
|
||||
for j in range(inner_tuple_size):
|
||||
targets[i + j * batch_size] = batch[i][1]
|
||||
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], np.ndarray):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
|
||||
for i in range(batch_size):
|
||||
tensor[i] += torch.from_numpy(batch[i][0])
|
||||
return tensor, targets
|
||||
elif isinstance(batch[0][0], torch.Tensor):
|
||||
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
|
||||
assert len(targets) == batch_size
|
||||
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=batch[0][0].dtype)
|
||||
for i in range(batch_size):
|
||||
tensor[i].copy_(batch[i][0])
|
||||
return tensor, targets
|
||||
else:
|
||||
assert False
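# Usage sketch (illustrative, not part of this diff): fast_collate is meant to be passed as
# a DataLoader collate_fn for datasets yielding uint8 image arrays (already CHW) plus int
# labels; normalization happens later, on-device, in the Fetcher. The dataset is hypothetical.
def _example_loader(dataset, batch_size=256):
    from torch.utils.data import DataLoader
    return DataLoader(dataset, batch_size=batch_size, num_workers=4, collate_fn=fast_collate)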
|
@ -0,0 +1,88 @@
|
||||
import torch
|
||||
|
||||
from .constants import *
|
||||
from .random_erasing import RandomErasing
|
||||
from .mixup import FastCollateMixup
|
||||
|
||||
|
||||
class FetcherXla:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
class Fetcher:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
loader,
|
||||
device: torch.device,
|
||||
dtype=torch.float32,
|
||||
normalize=True,
|
||||
normalize_shape=(1, 3, 1, 1),
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
re_prob=0.,
|
||||
re_mode='const',
|
||||
re_count=1,
|
||||
num_aug_splits=0,
|
||||
use_mp_loader=False,
|
||||
):
|
||||
self.loader = loader
|
||||
self.device = torch.device(device)
|
||||
self.dtype = dtype
|
||||
if normalize:
|
||||
self.mean = torch.tensor(
|
||||
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
|
||||
self.std = torch.tensor(
|
||||
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
|
||||
else:
|
||||
self.mean = None
|
||||
self.std = None
|
||||
if re_prob > 0.:
|
||||
# NOTE RandomErasing shouldn't be used here w/ XLA devices
|
||||
self.random_erasing = RandomErasing(
|
||||
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
self.use_mp_loader = use_mp_loader
|
||||
if use_mp_loader:
|
||||
# FIXME testing for TPU use
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
self._loader = pl.MpDeviceLoader(loader, device)
|
||||
else:
|
||||
self._loader = loader
|
||||
|
||||
def __iter__(self):
|
||||
for sample, target in self._loader:
|
||||
if not self.use_mp_loader:
|
||||
sample = sample.to(device=self.device)
|
||||
target = target.to(device=self.device)
|
||||
sample = sample.to(dtype=self.dtype)
|
||||
if self.mean is not None:
|
||||
sample.sub_(self.mean).div_(self.std)
|
||||
if self.random_erasing is not None:
|
||||
sample = self.random_erasing(sample)
|
||||
yield sample, target
|
||||
|
||||
def __len__(self):
|
||||
return len(self.loader)
|
||||
|
||||
@property
|
||||
def sampler(self):
|
||||
return self.loader.sampler
|
||||
|
||||
@property
|
||||
def dataset(self):
|
||||
return self.loader.dataset
|
||||
|
||||
@property
|
||||
def mixup_enabled(self):
|
||||
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
||||
return self.loader.collate_fn.mixup_enabled
|
||||
else:
|
||||
return False
|
||||
|
||||
@mixup_enabled.setter
|
||||
def mixup_enabled(self, x):
|
||||
if isinstance(self.loader.collate_fn, FastCollateMixup):
|
||||
self.loader.collate_fn.mixup_enabled = x
|
@ -0,0 +1,330 @@
|
||||
""" Dataset parser interface for webdataset
|
||||
|
||||
Hacked together by / Copyright 2022 Ross Wightman
|
||||
"""
|
||||
import math
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import yaml
|
||||
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from functools import partial
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
try:
|
||||
import webdataset as wds
|
||||
from webdataset.shardlists import expand_urls
|
||||
except ImportError:
|
||||
wds = None
|
||||
expand_urls = None
|
||||
|
||||
from .parser import Parser
|
||||
from timm.bits import get_global_device, is_global_device
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
SHUFFLE_SIZE = 8192
|
||||
|
||||
|
||||
def _load_info(root, basename='info'):
|
||||
info_json = os.path.join(root, basename + '.json')
|
||||
info_yaml = os.path.join(root, basename + '.yaml')
|
||||
err_str = ''
|
||||
try:
|
||||
with wds.gopen.gopen(info_json) as f:
|
||||
info_dict = json.load(f)
|
||||
return info_dict
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
with wds.gopen.gopen(info_yaml) as f:
|
||||
info_dict = yaml.safe_load(f)
|
||||
return info_dict
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
# FIXME change to log
|
||||
print(f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
|
||||
f'Falling back to provided split and size arg.')
|
||||
return {}
|
||||
|
||||
|
||||
@dataclass
|
||||
class SplitInfo:
|
||||
num_samples: int
|
||||
filenames: Tuple[str]
|
||||
shard_lengths: Tuple[int] = ()
|
||||
alt_label: str = ''
|
||||
name: str = ''
|
||||
|
||||
|
||||
def _parse_split_info(split: str, info: Dict):
|
||||
def _info_convert(dict_info):
|
||||
return SplitInfo(
|
||||
num_samples=dict_info['num_samples'],
|
||||
filenames=tuple(dict_info['filenames']),
|
||||
shard_lengths=tuple(dict_info['shard_lengths']),
|
||||
alt_label=dict_info.get('alt_label', ''),
|
||||
name=dict_info['name'],
|
||||
)
|
||||
|
||||
if 'tar' in split or '..' in split:
|
||||
# split in WDS string braceexpand format, sample count can be included with a | separator
|
||||
# ex: `dataset-split-{0000..9999}.tar|100000` for 10,000 shards, covering 100,000 samples
|
||||
split = split.split('|')
|
||||
num_samples = 0
|
||||
split_name = ''
|
||||
if len(split) > 1:
|
||||
num_samples = int(split[1])
|
||||
split = split[0]
|
||||
if '::' not in split:
|
||||
split_parts = split.split('-', 3)
|
||||
split_idx = len(split_parts) - 1
|
||||
if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
|
||||
split_name = split_parts[split_idx]
|
||||
|
||||
split_filenames = expand_urls(split)
|
||||
if split_name:
|
||||
split_info = info['splits'][split_name]
|
||||
if not num_samples:
|
||||
_fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
|
||||
num_samples = sum(_fc[f] for f in split_filenames)
|
||||
split_info['filenames'] = tuple(_fc.keys())
|
||||
split_info['shard_lengths'] = tuple(_fc.values())
|
||||
split_info['num_samples'] = num_samples
|
||||
split_info = _info_convert(split_info)
|
||||
else:
|
||||
split_info = SplitInfo(
|
||||
name=split_name,
|
||||
num_samples=num_samples,
|
||||
filenames=split_filenames,
|
||||
)
|
||||
else:
|
||||
if split not in info['splits']:
|
||||
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
|
||||
split = split
|
||||
split_info = info['splits'][split]
|
||||
split_info = _info_convert(split_info)
|
||||
|
||||
return split_info
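# Usage sketch (illustrative, not part of this diff): the two accepted split specs are a
# named split present in the info file, or an explicit brace-expand shard pattern with an
# optional '|<num_samples>' suffix. The dataset name and counts below are hypothetical.
#   _parse_split_info('train', info)
#   _parse_split_info('imagenet-train-{000000..001281}.tar|1281167', info)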
|
||||
|
||||
|
||||
def log_and_continue(exn):
|
||||
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
|
||||
_logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
|
||||
return True
|
||||
|
||||
|
||||
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
    """ Custom sample decode
    * decode and convert PIL Image
    * cls byte string label to int
    * pass through JSON byte string (if it exists) without parse
    """
    # decode class label, skip if alternate label not valid
    if alt_label:
        # alternative labels are encoded in json metadata
        meta = json.loads(sample['json'])
        class_label = int(meta[alt_label])
        if class_label < 0:
            # skipped labels currently encoded as -1, may change to a null/None value
            return None
    else:
        class_label = int(sample[target_key])

    # decode image
    with io.BytesIO(sample[image_key]) as b:
        img = Image.open(b)
        img.load()
        if image_format:
            img = img.convert(image_format)

    # json passed through in undecoded state
    decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
    return decoded

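# --- Illustrative sketch, not part of this diff ---
# Calling _decode() on a hand-built sample of the kind wds.tarfile_to_samples() yields
# ({'jpg': bytes, 'cls': bytes, ...}); the image, class id and key below are hypothetical.
def _example_decode():
    with io.BytesIO() as b:
        Image.new('RGB', (8, 8)).save(b, format='JPEG')
        jpeg_bytes = b.getvalue()
    sample = {'jpg': jpeg_bytes, 'cls': b'42', '__key__': 'n01440764_10026'}
    decoded = _decode(sample)
    assert decoded['cls'] == 42 and decoded['jpg'].size == (8, 8)
    return decoded
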
def _decode_samples(
        data,
        image_key='jpg',
        image_format='RGB',
        target_key='cls',
        alt_label='',
        handler=log_and_continue):
    """Decode samples with skip."""
    for sample in data:
        try:
            result = _decode(
                sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
        except Exception as exn:
            if handler(exn):
                continue
            else:
                break

        # null results are skipped
        if result is not None:
            if isinstance(sample, dict) and isinstance(result, dict):
                result["__key__"] = sample.get("__key__")
            yield result

class ParserWebdataset(Parser):
    def __init__(
            self,
            root,
            name,
            split,
            is_training=False,
            batch_size=None,
            repeats=0,
            seed=42,
            input_name='image',
            input_image='RGB',
            target_name=None,
            target_image='',
            prefetch_size=None,
            shuffle_size=None,
    ):
        super().__init__()
        if wds is None:
            raise RuntimeError(
                'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
        self.root = root
        self.is_training = is_training
        self.batch_size = batch_size
        self.repeats = repeats
        self.common_seed = seed  # a seed that's fixed across all worker / distributed instances
        self.shard_shuffle_size = 500
        self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE

        self.image_key = 'jpg'
        self.image_format = input_image
        self.target_key = 'cls'
        self.filename_key = 'filename'
        self.key_ext = '.JPEG'  # extension to add to key for original filenames (DS specific, default ImageNet)

        self.info = _load_info(self.root)
        self.split_info = _parse_split_info(split, self.info)
        self.num_samples = self.split_info.num_samples
        if not self.num_samples:
            raise RuntimeError('Invalid split definition, no samples found.')

        # Distributed world state
        self.dist_rank = 0
        self.dist_num_replicas = 1
        if is_global_device():
            dev_env = get_global_device()
            if dev_env.distributed and dev_env.world_size > 1:
                self.dist_rank = dev_env.global_rank
                self.dist_num_replicas = dev_env.world_size
        else:
            # FIXME warn if we fall back to torch distributed?
            import torch.distributed as dist
            if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
                self.dist_rank = dist.get_rank()
                self.dist_num_replicas = dist.get_world_size()

        # Attributes that are updated in _lazy_init
        self.worker_id = 0
        self.worker_seed = seed  # seed unique to each worker instance
        self.num_workers = 1
        self.global_worker_id = 0
        self.global_num_workers = 1
        self.init_count = 0

        # The DataPipeline is lazily initialized; most of the WDS DataPipeline could be created here, BUT the
        # shuffle seed is not handled in a manner where it can be deterministic per-worker AND set up front.
        self.ds = None

    def _lazy_init(self):
        """ Lazily initialize worker (in worker processes)
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            self.worker_id = worker_info.id
            self.worker_seed = worker_info.seed
            self.num_workers = worker_info.num_workers
        self.global_num_workers = self.dist_num_replicas * self.num_workers
        self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id

        # init data pipeline
        abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
        pipeline = [wds.SimpleShardList(abs_shard_filenames)]
        # at this point we have an iterator over all the shards
        if self.is_training:
            pipeline.extend([
                wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
                wds.shuffle(
                    self.sample_shuffle_size,
                    rng=random.Random(self.worker_seed)),  # this is why we lazy-init whole DataPipeline
            ])
        else:
            pipeline.extend([
                self._split_by_node_and_worker,
                # at this point, we have an iterator over the shards assigned to each worker
                wds.tarfile_to_samples(handler=log_and_continue),
            ])
        pipeline.extend([
            partial(
                _decode_samples,
                image_key=self.image_key,
                image_format=self.image_format,
                alt_label=self.split_info.alt_label)
        ])
        self.ds = wds.DataPipeline(*pipeline)
        self.init_count += 1

    def _split_by_node_and_worker(self, src):
        if self.global_num_workers > 1:
            for s in islice(src, self.global_worker_id, None, self.global_num_workers):
                yield s
        else:
            for s in src:
                yield s

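    # --- Illustrative sketch, not part of this diff ---
    # The split above is a plain round-robin over the flattened (rank, worker) grid. With a
    # hypothetical 2 ranks x 3 dataloader workers, rank 1 / worker 0 has
    #   global_num_workers = 2 * 3 = 6,  global_worker_id = 1 * 3 + 0 = 3,
    # so it takes every 6th shard starting at index 3, e.g.
    #   list(islice(range(12), 3, None, 6)) == [3, 9]
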
    def __iter__(self):
        if not self.init_count:
            self._lazy_init()

        i = 0
        num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
        if self.is_training and self.batch_size is not None:
            num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
        ds = self.ds.with_epoch(num_worker_samples)
        for sample in ds:
            yield sample[self.image_key], sample[self.target_key]
            i += 1
        print('end', i)  # FIXME debug

    def __len__(self):
        return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)

    def _filename(self, index, basename=False, absolute=False):
        assert False, "Not supported"  # no random access to examples

    def filenames(self, basename=False, absolute=False):
        """ Return all filenames in dataset, overrides base"""
        if not self.init_count:
            self._lazy_init()

        names = []
        for sample in self.ds:
            if self.filename_key in sample:
                name = sample[self.filename_key]
            elif '__key__' in sample:
                name = sample['__key__'] + self.key_ext
            else:
                assert False, "No supported name field present"
            names.append(name)
            if len(names) >= self.num_samples:
                break  # safety for ds.repeat() case
        return names

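# --- Illustrative sketch, not part of this diff ---
# Constructing the parser directly (it is normally created via the dataset/parser factory);
# the root path and split string below are hypothetical.
def _example_parser():
    parser = ParserWebdataset(
        root='/data/imagenet-wds',
        name='imagenet',
        split='imagenet-train-{0000..1023}.tar|1281167',
        is_training=True,
        batch_size=256,
    )
    # iterating yields (PIL.Image, int) pairs; len() reflects the per-replica sample count
    return parser
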
@@ -0,0 +1 @@
from .imagenet22k import Imagenet22k, Imagenet12k, imagenet12k_synsets, imagenet22k_synsets
File diff suppressed because it is too large
@@ -0,0 +1,87 @@
import torch.cuda

from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .mixup import FastCollateMixup
from .random_erasing import RandomErasing

class PrefetcherCuda:

    def __init__(
            self,
            loader,
            device: torch.device = torch.device('cuda'),
            dtype=torch.float32,
            normalize=True,
            normalize_shape=(1, 3, 1, 1),
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
            re_prob=0.,
            re_mode='const',
            re_count=1,
            num_aug_splits=0,
    ):
        self.loader = loader
        self.device = device
        self.dtype = dtype
        if normalize:
            self.mean = torch.tensor(
                [x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
            self.std = torch.tensor(
                [x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
        else:
            self.mean = None
            self.std = None
        if re_prob > 0.:
            self.random_erasing = RandomErasing(
                probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
        else:
            self.random_erasing = None

    def __iter__(self):
        stream = torch.cuda.Stream()
        first = True

        for next_input, next_target in self.loader:
            with torch.cuda.stream(stream):
                next_input = next_input.to(device=self.device, non_blocking=True)
                next_input = next_input.to(dtype=self.dtype)
                if self.mean is not None:
                    next_input.sub_(self.mean).div_(self.std)
                next_target = next_target.to(device=self.device, non_blocking=True)
                if self.random_erasing is not None:
                    next_input = self.random_erasing(next_input)

            if not first:
                yield input, target
            else:
                first = False

            torch.cuda.current_stream().wait_stream(stream)
            input = next_input
            target = next_target

        yield input, target

    def __len__(self):
        return len(self.loader)

    @property
    def sampler(self):
        return self.loader.sampler

    @property
    def dataset(self):
        return self.loader.dataset

    @property
    def mixup_enabled(self):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            return self.loader.collate_fn.mixup_enabled
        else:
            return False

    @mixup_enabled.setter
    def mixup_enabled(self, x):
        if isinstance(self.loader.collate_fn, FastCollateMixup):
            self.loader.collate_fn.mixup_enabled = x
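# --- Illustrative sketch, not part of this diff ---
# Wrapping an existing DataLoader with PrefetcherCuda. The loader is expected to yield uint8
# CPU tensors in 0..255 range (e.g. from a fast_collate style collate_fn), since mean/std above
# are scaled by 255; device transfer and normalization then overlap compute on a side stream.
def _example_prefetcher(dataset):
    import torch.utils.data  # local import for the sketch
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=256, num_workers=8, pin_memory=True)
    return PrefetcherCuda(loader, device=torch.device('cuda'), dtype=torch.float32, re_prob=0.25)
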
@@ -1 +1 @@
-__version__ = '0.6.9'
+__version__ = '0.8.2.dev0'