Compare commits

...

101 Commits

Author SHA1 Message Date
Ross Wightman a25bf974a9
Update README.md
2 years ago
Ross Wightman b4ea69c9ce
Merge pull request #1414 from dedeswim/bits_and_tpu
2 years ago
Ross Wightman 38594ef7fd Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 87bfb055c0 Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Edoardo Debenedetti 1dced6066c
Merge branch 'rwightman:bits_and_tpu' into bits_and_tpu
2 years ago
Ross Wightman f07dfe010a Merge remote-tracking branch 'origin/more_vit' into bits_and_tpu
2 years ago
Edoardo Debenedetti 5a40c6a3c4 Fix issue with torchvision's ImageNet
2 years ago
Ross Wightman bd6d377c74 Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 6fe01993ad verions 0.8.x for bits_and_tpu branch
2 years ago
Ross Wightman 1186fc9c73 Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman dff33730b3 Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 9c321be330
Merge pull request #1239 from dedeswim/parser-fix
2 years ago
Edoardo Debenedetti c76d772670 Add support for different TFDS `BuilderConfig`s
2 years ago
Ross Wightman 754e11402a Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 1ba0ec4c18 Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 749856cf25 Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 95739b45d7 Fix partially removed alt_lable impl from TFDS variant of ImageNet22/12k
2 years ago
Ross Wightman 5e1be34a60 Add ImageNet-22k/12k TFDS dataset defs
2 years ago
Ross Wightman 59ffab537c Fix mistake in wds sample slicing
2 years ago
Ross Wightman ef57561d51 Fix some TPU (XLA) issues with swin transformer v2
2 years ago
Ross Wightman ab16a358bb Add log and continue handler for WDS errors, fix args.num_gpu for validation script fallback
2 years ago
Ross Wightman 7eeaf521a0 use gopen in wds to open info file in case it's at a url/gs location
2 years ago
Ross Wightman 229ac6b8d8 Fix alternate label handling in WDS parser to skip invalid alt labels
2 years ago
Ross Wightman a444d4b891 Add alternative label support to WDS for imagenet22k/12k split, add 21k/22k/12k indices filters to results/
2 years ago
Ross Wightman da2796ae82 Add webdataset (WDS) support, update TFDS to make some naming in parsers more similar. Fix workers=0 compatibility. Add ImageNet22k/12k synset defs.
2 years ago
Ross Wightman 3fce010ca8 Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 15cc9eae3e Fix Swin v2 tuple type hint
2 years ago
Ross Wightman bb85b09d2a swin v2 fixup for latest changes on norm_norm_norm / bits_and_tpu branch
2 years ago
Ross Wightman 10fa42b143 Merge branch 'ChristophReich1996-master' into bits_and_tpu
2 years ago
Ross Wightman c639a86c67 Change TFDS default to full re-shuffle (init) each epoch (for now)
2 years ago
Ross Wightman a16ea1e355 Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman fafece230b Allow changing base lr batch size from 256 via arg
2 years ago
Ross Wightman 7148039f9f Tweak base lr log
2 years ago
Ross Wightman f82fb6b608 Add base lr w/ linear and sqrt scaling to train script
2 years ago
Ross Wightman 066e490605 Merge branch 'norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 7eb7e73216 File will not stay deleted
2 years ago
Ross Wightman 4c8bb295ab Remove bn-tf arg
2 years ago
Ross Wightman 0012bf7fb5 Merge branch 'norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman cbc4f33220 Merge branch 'norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 40f4745366 Merge branch 'norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 1c21cac8f9 Add drop args to benchmark.py
2 years ago
Ross Wightman d829858550 Significant norm update
2 years ago
Ross Wightman 57fca2b5b2 Fix c16_evos stem / first conv setup
2 years ago
Ross Wightman 1f54a1fff7 Add C16 and E8 EvoNormS0 configs for RegNetZ BYOB nets
2 years ago
Ross Wightman 4d7a5544f7 Remove inplace sigmoid for consistency with other impl
2 years ago
Ross Wightman 88a5b54802 A few small evonorm tweaks for convergence comparisons
2 years ago
Ross Wightman 66daee4f31 Last change wasn't complete, missed adding full evo_norm changeset
2 years ago
Ross Wightman 7bbbd5ef1b EvoNorm and GroupNormAct options for debugging TPU / XLA concerns
2 years ago
Ross Wightman ff0f709c20 Testing TFDS shuffle across epochs
2 years ago
Ross Wightman 69e90dcd8c Merge branch 'norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 820ae9925e Fix load_state_dict to handle None ema entries
2 years ago
Ross Wightman 0e212e8fe5 Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman cad170e494 Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu
2 years ago
Ross Wightman 809c7bb1ec Merge remote-tracking branch 'origin/master' into bits_and_tpu
2 years ago
Ross Wightman 871cef4198 version 0.5.1
2 years ago
Ross Wightman 4f338556d8 Fixes and improvements for metrics, tfds parser, loader / transform handling
2 years ago
Ross Wightman d9b0b3d60f device arg wasn't removed from PrefetcherCuda instantiation of RE
2 years ago
Ross Wightman 80ca078aed Fix a few bugs and formatting/naming issues
2 years ago
Ross Wightman 406c486ba2 Merge remote-tracking branch 'origin/more_datasets' into bits_and_tpu
2 years ago
Ross Wightman 07693f81b0 Validation fix since we don't have multi-GPU DataParallel support yet
2 years ago
Ross Wightman 59a3409182
Update README.md
3 years ago
Ross Wightman a45186a6e8 Merge remote-tracking branch 'origin/master' into bits_and_tpu
3 years ago
Ross Wightman 690f31d02d Post merge cleanup, restore previous unwrap fn
3 years ago
Ross Wightman 3b6ba76126 Merge remote-tracking branch 'origin/master' into bits_and_tpu
3 years ago
Ross Wightman 1fdc7af8fd Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
3 years ago
Ross Wightman 52c481ea8e Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
3 years ago
Ross Wightman 25d52ea71d Merge remote-tracking branch 'origin/fixes_bce_regnet' into bits_and_tpu
3 years ago
Ross Wightman 3581affb77 Update train.py with some flags related to scheduler tweaks, fix best checkpoint bug.
3 years ago
Ross Wightman c2f02b08b8 Merge remote-tracking branch 'origin/attn_update' into bits_and_tpu
3 years ago
Ross Wightman f2e14685a8 Add force-cpu flag for train/validate, fix CPU fallback for device init, remove old force cpu flag for EMA model weights
3 years ago
Ross Wightman 2ee398d501 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman f4fb068b11 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman b0265ef8a6 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman 0d82876132 Add comment for reference re PyTorch XLA 'race' issue
3 years ago
Ross Wightman b76b48e8e9 Update optimizer creation for master optimizer changes
3 years ago
Ross Wightman f98662b9c9 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman cb621e0f00 Remove print, arg order
3 years ago
Ross Wightman b974d85026 Merge branch 'bits_and_tpu' of github.com:rwightman/pytorch-image-models into bits_and_tpu
3 years ago
Ross Wightman c06c739901 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman 40457e5691 Transforms, augmentation work for bits, add RandomErasing support for XLA (pushing into transforms), revamp of transform/preproc config, etc ongoing...
3 years ago
Ross Wightman 5e95ced5a7 timm bits checkpoint support for avg_checkpoints.py
3 years ago
Ross Wightman 56ed0a0b63 Merge branch 'vit_and_bit_test_fixes' into bits_and_tpu
3 years ago
Ross Wightman 847b4af144
Update README.md
3 years ago
Ross Wightman 5c5cadfe4c
Update README.md
3 years ago
Ross Wightman ee2b8f49ee
Update README.md
3 years ago
Ross Wightman cc870df7b8
Update README.md
3 years ago
Ross Wightman 6b2d9c2660 Another bits/README.md update
3 years ago
Ross Wightman c3db5f5801 Worker hack for TFDS eval, add TPU env var setting.
3 years ago
Ross Wightman f411724de4 Fix checkpoint delete issue. Add README about bits and initial Pytorch XLA usage on TPU-VM. Add some FIXMEs and fold train_cfg into train_state by default.
3 years ago
Ross Wightman b57a03bd0d Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman 91ab0b6ce5 Add proper TrainState checkpoint save/load. Some reorg/refactoring and other cleanup. More to go...
3 years ago
Ross Wightman 5b9c69e80a Add basic training resume based on legacy code
3 years ago
Ross Wightman 4210d922d2 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman 72ca831dd4 Back to using strings for the enum translation, forgot about import dep
3 years ago
Ross Wightman cbd4ee737f Fix model init for XLA, remove some prints.
3 years ago
Ross Wightman 6d90fcf282 Fix distribute_bn and model_ema
3 years ago
Ross Wightman 74d2829341 Merge branch 'master' into bits_and_tpu
3 years ago
Ross Wightman aa92d7b1c5 Major timm.bits update. Updater and DeviceEnv now dataclasses, after_step closure used, metrics base impl w/ distributed reduce, many tweaks/fixes.
3 years ago
Ross Wightman 938716c753 Fix import issue, use devenv for dist info in parser_tfds
3 years ago
Ross Wightman 76de984a5f Fix some bugs with XLA support, logger, add hacky xla dist launch script since torch.dist.launch doesn't work
3 years ago
Ross Wightman 12d9a6d4d2 First timm.bits commit, add initial abstractions, WIP updates to train, val... some of it working
3 years ago

@ -15,6 +15,7 @@
Thanks to the following for hardware support:
* TPU Research Cloud (TRC) (https://sites.research.google/trc/about/)
* TPU support can be found on the [`bits_and_tpu`](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/) branch, w/ some setup help [here](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)
* Nvidia (https://www.nvidia.com/en-us/)
And a big thanks to all GitHub sponsors who helped with some of my costs before I joined Hugging Face.

@ -13,14 +13,16 @@ import os
import hashlib
import shutil
from collections import OrderedDict
from timm.models.helpers import load_state_dict
from timm.utils import setup_default_logging
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
@ -30,28 +32,25 @@ _TEMP_NAME = './_checkpoint.pth'
def main():
args = parser.parse_args()
setup_default_logging()
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint))
state_dict = load_state_dict(checkpoint, use_ema=use_ema)
new_state_dict = {}
if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> Loading checkpoint '{}'".format(args.checkpoint))
state_dict = load_state_dict(args.checkpoint, use_ema=args.use_ema)
new_state_dict = OrderedDict()
for k, v in state_dict.items():
if clean_aux_bn and 'aux_bn' in k:
if args.clean_aux_bn and 'aux_bn' in k:
# If all aux_bn keys are removed, the SplitBN layers will end up as normal and
# load with the unmodified model using BatchNorm2d.
continue
name = k[7:] if k.startswith('module.') else k
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
@ -61,19 +60,17 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
if args.output:
checkpoint_root, checkpoint_base = os.path.split(args.output)
checkpoint_base = os.path.splitext(checkpoint_base)[0]
else:
checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0]
checkpoint_base = os.path.splitext(args.checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename
else:
print("Error: Checkpoint ({}) doesn't exist".format(checkpoint))
return ''
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
if __name__ == '__main__':

@ -13,7 +13,7 @@ import numpy as np
import torch
from timm.models import create_model, apply_test_time_pool
from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.data import ImageDataset, create_loader_v2, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
torch.backends.cudnn.benchmark = True
@ -82,7 +82,7 @@ def main():
else:
model = model.cuda()
loader = create_loader(
loader = create_loader_v2(
ImageDataset(args.data),
input_size=config['input_size'],
batch_size=args.batch_size,

@ -0,0 +1,66 @@
"""
Adapatation of (pre-elastic) torch.distributed.launch for pytorch xla.
`torch.distributed.launch` is a module that spawns up multiple distributed
training processes on each of the training nodes.
"""
import sys
import subprocess
import importlib
import os
from argparse import ArgumentParser, REMAINDER
from typing import Optional, IO
import torch_xla.distributed.xla_multiprocessing as xmp
def parse_args():
"""
Helper function parsing the command line options
@retval ArgumentParser
"""
parser = ArgumentParser(
description="PyTorch distributed training launch helper utility"
"that will spawn up multiple distributed processes")
# Optional arguments for the launch helper
parser.add_argument("--num-devices", type=int, default=1,
help="The number of XLA devices to use for distributed training")
# positional
parser.add_argument(
"script", type=str,
help="The full path to the single device training script to be launched"
"in parallel, followed by all the arguments for the training script")
# rest from the training program
parser.add_argument('script_args', nargs=REMAINDER)
return parser.parse_args()
def main():
args = parse_args()
# set PyTorch distributed related environmental variables
# current_env = os.environ.copy()
# current_env["MASTER_ADDR"] = args.master_addr
# current_env["MASTER_PORT"] = str(args.master_port)
# current_env["WORLD_SIZE"] = str(dist_world_size)
# if 'OMP_NUM_THREADS' not in os.environ and args.nproc_per_node > 1:
# current_env["OMP_NUM_THREADS"] = str(1)
script_abs = os.path.abspath(args.script)
script_base, script_rel = os.path.split(script_abs)
sys.path.append(script_base)
mod = importlib.import_module(os.path.splitext(script_rel)[0])
sys.argv = [args.script] + args.script_args
xmp.spawn(mod._mp_entry, args=(), nprocs=args.num_devices)
if __name__ == "__main__":
main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,133 @@
# Timm Bits
## Intro
A collection of reusable components and lightweight abstractions for training and evaluating NN with PyTorch and PyTorch XLA.
This is an early WIP (consider it pre-alpha) with the primary goal to get up and running on TPUs w/ PyTorch XLA as the first priority. Expect significant changes, rewrites, additions, and of course bugs.
The current train.py and validate.py scripts are evolving to use the timm.bits components, they will also change significantly.
## Bits Design Brief
`bits` is designed to be a lightweight and modular set of training abstractions. It shares concepts with other libraries (fastai, ignite, lightning, keras, etc, etc) but is not modeled after any specific one. It is supposed to be a 'bit different', hackable, and is purposely not trying to serve every use case or be everything to everyone.
`timm` models will always be useable in pure PyTorch w/o `bits` or anything dependencies besides the model utils and helpers for pretrained models, feature extraction, default data config.
I may breakout bits into a diff project if there is any interest besides my own use for timm image and video model training.
The layers:
* DeviceEnv - DeviceEnv dataclass abstraction handles PyTorch CPU, GPU and XLA device differences, incl distributed functions, parallel wrappers, etc. There is more than a passing similarity to HuggingFace Accelerate, but developed in parallel and with some difference in the detail and separation of concerns.
* Updater - A dataclass that combines the backward pass, optimizer step, grad scaling, grad accumulation in device specific abstraction.
* Currently, basic single optimizer, single forward/backward Updaters are included for GPU, XLA.
* Deepseed will need its own Updater(s) since its Engine is a monolith of epic proportions that breaks all separations of concern in PyTorch (UGH!). NOTE Deepspeed not working yet nor is it a priority.
* Monitor - pull together all console logging, csv summaries, tensorboard, and WandB summaries into one module for monitoring your training.
* Checkpoint Manager - keeps track of your checkpoints
* Metrics - yet another set of metrics, although this may be replaced w/ an external set of classes. Uses same update / reset / compute interface as Ignite and Lightning (in theory interchangeable w/ a thin adapter). Metrics keep state on GPU / TPU to avoid device -> cpu transfers (esp for XLA).
* Task (not implemented yet) - combine your model(s) w/ losses in a task specific module, will also allow task factory for easy build of appripriate metrics
* TrainState - dataclasses to hold your tasks (models), updater state, etc
* Train Loop Functions (still in train.py script, not refined) - set of functions for train step, 'after step', evaluate using all of the components mentioned
How is this different than other options?
* I'm very much trying to avoid a monolithic trainer / learner / model wrapping type class with numerous hooks and overrides to keep track of (avoiding granular inversion of control!).
* The goal is to provide reusable modules that can (hopefully) be mixed and matched w/ other code.
* Many of the components are based on Python dataclasses to reduce boilerplate.
* The train loop components are (will be) functional with easy to follow flow control, and are intended to be replaced when something different is needed, not augmented with hooks via callbacks or inheritence at every conceivable touch point.
## Quick Start for PyTorch XLA on TPU-VM
Most initial users will likely be interested in training timm models w/ PyTorch XLA on TPU-VM instances, this quick start will get you moving.
If you haven't noticed, this code is on a branch, make sure you checkout the `bits_and_tpu` branch on `timm` before doing this. You can test locally on your GPU too, in either XLA + GPU in a container or the usual PyTorch w/ GPU.
## Setup Python environment
This setup assumes you've SSH'd into your TPU-VM after setting it up (https://cloud.google.com/tpu/docs/users-guide-tpu-vm). Don't forget to do this in a TMUX session or you'll be sad if you lose your connection!
The TPU-VM instances I've been using have a usable version of PyTorch XLA 1.8.1 installed in the python3 environment, we will be using that.
I've found that leveraging TFDS w/ datasets in TFRecord format, streamed from Google Storage buckets is the most practical / cost-effective solution. I've written a PyTorch IterabeDataset wrapper around TFDS so we will install Tensorflow datasets and use that. Traditional PyTorch datasets on local disks do work w/ bits for all of TPU-VM, GPU cloud instances, and your local machine. Setting up persistent disks wasn't the easiest thing to do on TPU-VMs so TFDS was my default in that context.
One thing to watch, be very careful that you don't use a GS based dataset in a different continent from you TPU-VM instances. I burned through a few thousand USD leaving some wires crossed for 1 day. Otherwise the cost of training w/ buckets in same region are quite low.
### Install TFDS (if using GS buckets)
```
pip3 install tensorflow-datasets
```
In some tpu-vm instances may have tensorflow version pre-installed that conflict with tensorflow-datasets, especially the bucket reading support. If training crashes with errors about inability to ready from buckets, tensorflow symbol errors, tensorflow datasets missing functions, etc, you should try removing the pre-installed tensorflow and installing one from pypi.
```
sudo pip3 uninstall tf-nightly
pip3 install tensorflow-cpu
```
You may run into some numpy / pytorch version dependency issues here, try capping the version of tensorflow at `==2.4.1` in above command.
### Get your dataset into buckets
You will need to host your dataset in buckets. I have tried creating custom datasets for this setup, but have used a number of TFDS datasets such as ImageNet, Flowers, caltech Birds, Oxford Pets that are available in TFDS.
The TFDS dataset pages (https://www.tensorflow.org/datasets/catalog/imagenet2012) have directions for the various datasets, I recommend building them in a different VM or local machine and then uploading to your training bucket. Many of them will auto-download and build the tfrecord shards for you. ImageNet needs to be downloaded manually.
### Use a custom allocator
With PyTorch XLA on a TPU-VM and TFDS you'll end up with a lot of processes and buffering. The instance memory will be used up quickly. I highly recommend using a custom allocator via `LD_PRELOAD`. tcmalloc may now be a default in the tpu-vm instanecs (check first). jemalloc also worked well for me. If LD_PRELOAD is not set in your env, do the following
```
sudo apt update
sudo apt install google-perftools
export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
```
# Train, train, train
With all the above done, you should be ready to train... below is one particular train command I've just recently been using for some trials on vision MLP models...
Make sure the TPU config for PyTorch XLA on TPU-VM is set:
```
export XRT_TPU_CONFIG="localservice;0;localhost:51011"
```
Then, launch fighters!
```
python3 launch_xla.py --num-devices 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
NOTE: build my TFDS dataset at ver 5.0.0 and it defaults to a newer version now. Change accordingly.
# Quick Start w/ GPU
`timm bits` should work great on your multi-GPU setups just like the old `timm` training script with either TFDS based datasets or a local folder.
The equivalent training command of the XLA setup above if you were on an 8-GPU machine and using TFDS would be,
```
./distrbuted_train.sh 8 train.py gs://my-imagenet-bucket --dataset tfds/imagenet2012:5.0.0 --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
Or this for imagenet in a local folder,
```
./distrbuted_train.sh 8 train.py /path/to/imagenet --model resmlp_24_224 --opt adamw --opt-eps 1e-6 --clip-grad 1.0 --drop-path 0.1 --mixup 0.5 --cutmix 1.0 --aa rand-m6-n4-mstd1.0-inc1 --weight-decay .08 --model-ema --model-ema-decay 0.99993 --sched cosine -j 4 --warmup-lr 1e-6 --warmup-epochs 20 --epochs 500 --lr 8.8e-4 -b 256
```
# Gotchas and Known Issues
* When PyTorch XLA crashes, you hit a TPU OOM etc, lots of processes get orphaned. Get in the habit of killing all python processes before starting a new train run.
* `alias fml='pkill -f python3'`
* For TFDS use, due to the way PyTorch IterableDatasets work at the loader level, each worker process builds batches independently -- they are not dequeued and collated across workers. For validation especially, getting all the samples evenly divided across BOTH the distributed processes AND the dataset workers is a bit annoying. For now keeping the num_workers arg (j) low is advisable, especially for very small validation sets. This can limit your throughput though.
* Random erasing works with PyTorch XLA but it must be done on the images before they are moved into tensors on the XLA device. This changes the dataloader pipelien a bit and increases the size of the data being moved to device (float instead of int8) so has an impact on dataloading speed.
* There are a number of models using ops that aren't lowered to XLA, this will REALLY slow things down to the point of being unusable. There are flags you can set to debug this, see PyTorch XLA troubleshooting page (https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md)
* This code doesn't currently work when float16 is forced via `XLA_USE_BF16=1` env arg, it will mess up metrics tensors that overflow in bfloat16. Better controlling model activation vs weight precision vs other tensors is a TODO.
* I haven't tested this code with pre TPU-VM (2-VM) setups, but it should work w/ correct config. I intend to make it work with Colab and Kaggle TPU notebooks soon.
* Your first batch, and generally first epoch will be slow with Pytorch XLA, after that things pick up and move along quickly. Be patient.
# Bugs and Discussion
If you find bugs (there are likely many), feel free to file an issue with `[BITS]` as the title prefix. Open a discussion if you have design ideas, again use `[BITS]` in the title.
# Acknowledgements
The TPU-VMs I've used for creating and testing this code, and that I hope to use for many future `timm` models were made available by the TPU Research Cloud (https://sites.research.google/trc/).

@ -0,0 +1,26 @@
from .avg_scalar import AvgMinMaxScalar
from .avg_tensor import AvgTensor
from .checkpoint_manager import CheckpointManager
from .device_env import DeviceEnv, DeviceEnvType, get_global_device, set_global_device, is_global_device
from .device_env_cuda import DeviceEnvCuda
from .device_env_factory import initialize_device
from .device_env_xla import DeviceEnvXla
from .distributed import distribute_bn, all_gather_recursive, all_reduce_recursive, broadcast_recursive,\
all_reduce_sequence, all_gather_sequence
# from .evaluate import evaluate, eval_step
from .monitor import Monitor
from .metric import Metric, MetricValueT
from .metric_accuracy import AccuracyTopK
from .tracker import Tracker
# from .task_metrics import TaskMetrics, TaskMetricsClassify
from .train_cfg import TrainCfg
from .train_services import TrainServices
from .train_setup import setup_model_and_optimizer
from .train_state import TrainState
# from .task import TaskClassify
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_factory import create_updater
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
# from .train import train_one_epoch, Experiment

@ -0,0 +1,24 @@
class AvgMinMaxScalar:
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.min = None
self.max = None
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.min = val if self.min is None else min(self.min, val)
self.max = val if self.max is None else max(self.max, val)
self.sum += val * n
self.count += n
self.avg = self.sum / self.count

@ -0,0 +1,54 @@
import torch
class AvgTensor:
"""Computes and stores the average and current value"""
def __init__(self, accumulate_dtype=torch.float32):
self.accumulate_dtype = accumulate_dtype
self.sum = None
self.count = None
self.reset()
# FIXME handle distributed operation
def reset(self):
self.sum = None
self.count = None
def update(self, val: torch.Tensor, n=1):
if self.sum is None:
self.sum = torch.zeros_like(val, dtype=self.accumulate_dtype)
self.count = torch.tensor(0, dtype=torch.long, device=val.device)
self.sum += (val * n)
self.count += n
def compute(self):
return self.sum / self.count
class TensorEma:
"""Computes and stores the average and current value"""
def __init__(
self,
smoothing_factor=0.9,
init_zero=False,
accumulate_dtype=torch.float32
):
self.accumulate_dtype = accumulate_dtype
self.smoothing_factor = smoothing_factor
self.init_zero = init_zero
self.val = None
self.reset()
# FIXME handle distributed operation
def reset(self):
self.val = None
def update(self, val):
if self.val is None:
if self.init_zero:
self.val = torch.zeros_like(val, dtype=self.accumulate_dtype)
else:
self.val = val.clone().to(dtype=self.accumulate_dtype)
self.val = (1. - self.smoothing_factor) * val + self.smoothing_factor * self.val

@ -0,0 +1,109 @@
import logging
import os
from collections import OrderedDict
from typing import Dict, Any, Callable
import torch
from timm.utils import unwrap_model
from .device_env import DeviceEnv
from .train_state import TrainState
_logger = logging.getLogger(__name__)
def save_train_state(
checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS?
train_state: TrainState,
extra_state: Dict[str, Any] = None,
unwrap_fn: Callable = unwrap_model,
dev_env: DeviceEnv = None,
log_info: bool = True):
assert not train_state.updater.deepspeed
# DeepSpeed has a fully custom checkpoint saving setup, it is not possible
# specify a filename, checkpoints needed to be saved from all ranks, etc
# if train_state.updater.deepspeed:
# save_train_state_deepspeed(train_state, checkpoint_path)
dev_env = dev_env or DeviceEnv.instance()
state_dict = train_state.state_dict(unwrap_fn=unwrap_fn)
if extra_state:
state_dict.update(extra_state)
if dev_env.type_xla:
# XLA state dict needs to be moved to CPU before save, this is normally done by xm.save
state_dict = dev_env.state_dict_to_cpu(state_dict)
torch.save(state_dict, checkpoint_path)
def load_train_state(
train_state: TrainState,
checkpoint_path: str, # FIXME pass base path + file pattern + epoch / step separately for DS
unwrap_fn: Callable = None,
load_opt: bool = True,
dev_env: DeviceEnv = None,
log_info: bool = True
):
unwrap_fn = unwrap_fn or unwrap_model
if not os.path.isfile(checkpoint_path):
_logger.error("No valid resume checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
if log_info:
_logger.info('Restoring training state from checkpoint...')
checkpoint = torch.load(checkpoint_path, map_location='cpu')
assert isinstance(checkpoint, dict)
if not checkpoint.get('version', 0) > 2:
load_legacy_checkpoint(train_state, checkpoint=checkpoint, load_opt=load_opt, log_info=log_info)
if log_info:
_logger.info("Loaded legacy checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
return
train_state.load_state_dict(checkpoint, unwrap_fn=unwrap_fn, load_opt=load_opt)
if log_info:
_logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, train_state.epoch))
def _get_state_dict(checkpoint, state_dict_key='state_dict'):
new_state_dict = OrderedDict()
for k, v in checkpoint[state_dict_key].items():
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
return new_state_dict
def load_legacy_checkpoint(
train_state: TrainState,
checkpoint,
load_opt=True,
log_info=True):
assert isinstance(checkpoint, dict) and 'state_dict' in checkpoint
train_state.model.load_state_dict(_get_state_dict(checkpoint))
if train_state.model_ema is not None and 'state_dict_ema' in checkpoint:
if log_info:
_logger.info('Restoring model (EMA) state from checkpoint...')
unwrap_model(train_state.model_ema).load_state_dict(_get_state_dict(checkpoint, 'state_dict_ema'))
if load_opt:
if train_state.updater.optimizer is not None and 'optimizer' in checkpoint:
if log_info:
_logger.info('Restoring optimizer state from checkpoint...')
train_state.updater.optimizer.load_state_dict(checkpoint['optimizer'])
scaler_state_dict_key = 'amp_scaler'
if train_state.updater.grad_scaler is not None and scaler_state_dict_key in checkpoint:
if log_info:
_logger.info('Restoring AMP loss scaler state from checkpoint...')
train_state.updater.grad_scaler.load_state_dict(checkpoint[scaler_state_dict_key])
if 'epoch' in checkpoint:
resume_epoch = checkpoint['epoch']
if 'version' in checkpoint and checkpoint['version'] > 1:
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
train_state.epoch = resume_epoch # FIXME use replace if we make train_state read-only

@ -0,0 +1,219 @@
""" Checkpoint Manager
Track top-n training checkpoints and maintain recovery checkpoints on specified intervals.
Hacked together by / Copyright 2021 Ross Wightman
"""
import glob
import logging
import operator
import os
import shutil
from typing import Optional, Dict, Callable, List
from dataclasses import dataclass, replace
from .checkpoint import save_train_state
from .train_state import TrainState
_logger = logging.getLogger(__name__)
@dataclass
class CheckpointInfo:
path: str = ''
metrics: Dict[str, float] = None # all metrics at time of checkpoint save
metric_name: str = 'loss'
metric_decreasing: bool = True
epoch: int = 0
global_step: int = 0
@property
def valid_key(self):
return self.metric_name and self.metrics and self.metric_name in self.metrics
@property
def sort_key(self):
return self.metrics[self.metric_name] if self.valid_key else self.epoch
@property
def decreasing_key(self):
return self.metric_decreasing if self.valid_key else False
class CheckpointManager:
def __init__(
self,
hparams=None,
save_state_fn=None,
checkpoint_dir='',
recovery_dir='',
checkpoint_tmpl=None,
recovery_tmpl=None,
metric_name='loss',
metric_decreasing=True,
max_history=10):
# extra items to include in checkpoint
self.hparams = hparams # train arguments (config / hparams) # FIXME this will change with new config system
# state
self.checkpoint_files: List[CheckpointInfo] = [] # (filename, metric) tuples in order of decreasing betterness
self.best_checkpoint = None
self.curr_recovery_file = ''
self.prev_recovery_file = ''
self.can_hardlink = True
# util / helper fn
self.save_state_fn = save_state_fn or save_train_state
# file / folder config
self.extension = '.pth.tar'
self.checkpoint_dir = checkpoint_dir
self.recovery_dir = recovery_dir
self.checkpoint_tmpl = (checkpoint_tmpl or 'checkpoint-{index}') + self.extension
self.recovery_tmpl = (recovery_tmpl or 'recovery-{index}') + self.extension
# ordering / history config
self.metric_name = metric_name
self.metric_decreasing = metric_decreasing
self.metric_cmp_fn = operator.lt if metric_decreasing else operator.gt
self.max_history = max_history
assert self.max_history >= 1
def _replace(self, src, dst):
if self.can_hardlink:
try:
if os.path.exists(dst):
os.unlink(dst) # required for Windows support.
except (OSError, NotImplementedError) as e:
self.can_hardlink = False
os.replace(src, dst)
def _duplicate(self, src, dst):
if self.can_hardlink:
try:
if os.path.exists(dst):
# for Windows
os.unlink(dst)
os.link(src, dst)
return
except (OSError, NotImplementedError) as e:
self.can_hardlink = False
shutil.copy2(src, dst)
def _save(self, save_path, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
extra_state = dict(
# version < 2 increments epoch before save
# version < 3, pre timm bits
# version 3, first timm bits checkpoitns
version=3,
)
if self.hparams is not None:
extra_state.update(dict(arch=self.hparams['model'], hparams=self.hparams))
else:
arch = getattr(train_state.model, 'default_cfg', dict()).get('architecture', None)
if arch is None:
arch = type(train_state.model).__name__.lower()
extra_state.update(dict(arch=arch))
if metrics is not None:
# save the metrics and how we originally sorted them in the checkpoint for future comparisons
extra_state.update(dict(
metrics=metrics,
metric_name=self.metric_name,
metric_decreasing=self.metric_decreasing
))
self.save_state_fn(save_path, train_state, extra_state)
checkpoint_info = CheckpointInfo(
path=save_path,
metrics=metrics,
metric_name=self.metric_name,
metric_decreasing=self.metric_decreasing,
epoch=train_state.epoch,
global_step=train_state.step_count_global,
)
return checkpoint_info
def _udpate_checkpoints(self, info: CheckpointInfo):
self.checkpoint_files.append(info)
self.checkpoint_files = sorted(
self.checkpoint_files,
key=lambda x: x.sort_key,
reverse=not info.decreasing_key, # sort in descending order if a lower metric is not better
)
def _cleanup_checkpoints(self, trim=0):
trim = min(len(self.checkpoint_files), trim)
delete_index = self.max_history - trim
if delete_index < 0 or len(self.checkpoint_files) <= delete_index:
return
to_delete = self.checkpoint_files[delete_index:]
for d in to_delete:
try:
_logger.debug("Cleaning checkpoint: {}".format(d))
os.remove(d.path)
except OSError as e:
_logger.error("Exception '{}' while deleting checkpoint".format(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def _compare_metric(self, lhs: CheckpointInfo, rhs: CheckpointInfo):
# compare metrics against an existing checkpoint
if not lhs or not lhs.valid_key or not rhs or not rhs.valid_key:
# always assume lhs metrics are better if there are no usable metrics to compare
return True
return self.metric_cmp_fn(lhs.sort_key, rhs.sort_key)
def save_checkpoint(self, train_state: TrainState, metrics: Optional[Dict[str, float]] = None):
assert train_state.epoch >= 0
tmp_save_path = os.path.join(self.checkpoint_dir, 'tmp' + self.extension)
last_save_path = os.path.join(self.checkpoint_dir, 'last' + self.extension)
curr_checkpoint = self._save(tmp_save_path, train_state, metrics)
self._replace(tmp_save_path, last_save_path)
worst_checkpoint = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or self._compare_metric(curr_checkpoint, worst_checkpoint):
if len(self.checkpoint_files) >= self.max_history:
self._cleanup_checkpoints(1)
filename = self.checkpoint_tmpl.format(index=train_state.epoch)
save_path = os.path.join(self.checkpoint_dir, filename)
curr_checkpoint = replace(curr_checkpoint, path=save_path)
self._duplicate(last_save_path, save_path)
self._udpate_checkpoints(curr_checkpoint)
checkpoints_str = "Current checkpoints:\n"
for c in self.checkpoint_files:
checkpoints_str += f' {c.path}, {c.sort_key}\n'.format(c)
_logger.info(checkpoints_str)
if curr_checkpoint.valid_key and self._compare_metric(curr_checkpoint, self.best_checkpoint):
self.best_checkpoint = curr_checkpoint
best_save_path = os.path.join(self.checkpoint_dir, 'best' + self.extension)
self._duplicate(last_save_path, best_save_path)
return curr_checkpoint if self.best_checkpoint is None else self.best_checkpoint
def save_recovery(self, train_state: TrainState):
tmp_save_path = os.path.join(self.recovery_dir, 'recovery_tmp' + self.extension)
self._save(tmp_save_path, train_state)
filename = self.recovery_tmpl.format(index=train_state.step_count_global)
save_path = os.path.join(self.recovery_dir, filename)
self._replace(tmp_save_path, save_path)
if os.path.exists(self.prev_recovery_file):
try:
_logger.debug("Cleaning recovery: {}".format(self.prev_recovery_file))
os.remove(self.prev_recovery_file)
except Exception as e:
_logger.error("Exception '{}' while removing {}".format(e, self.prev_recovery_file))
self.prev_recovery_file = self.curr_recovery_file
self.curr_recovery_file = save_path
def find_recovery(self):
recovery_path = os.path.join(self.recovery_dir, self.recovery_prefix)
files = glob.glob(recovery_path + '*' + self.extension)
files = sorted(files)
return files[0] if len(files) else ''

@ -0,0 +1,191 @@
import abc
from contextlib import suppress
from enum import Enum
from typing import Callable, Union, Optional, List, Tuple, Dict, Any
from dataclasses import dataclass, field, InitVar
import torch
import torch.distributed as dist
TensorList = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
class DeviceEnvType(Enum):
""" Device Environment Types
"""
CPU = "cpu"
CUDA = "cuda"
XLA = "xla"
def state_dict_apply(state_dict: Dict[str, Any], apply_fn, select_fn=lambda x: x.isinstance(torch.Tensor)):
out_dict = {}
for k, v in state_dict.items():
if isinstance(v, dict):
out_dict[k] = state_dict_apply(v, apply_fn, select_fn)
else:
out_dict[k] = apply_fn(v) if select_fn(v) else v
return out_dict
@dataclass
class DeviceEnv:
device_type: InitVar[Optional[str]] = None
device_index: InitVar[Optional[int]] = None
channels_last: InitVar[bool] = False
device: torch.device = field(init=False) # set from device_type + device_index or post_init logic
world_size: Optional[int] = None # set by post_init from env when None
local_rank: Optional[int] = None # set by post_init from env when None
global_rank: Optional[int] = None # set by post_init from env when None
amp: bool = False
autocast: Optional[Callable] = None # set by post_init from env when None
memory_format: Optional[torch.memory_format] = None
dtype: Optional[torch.dtype] = None
def __post_init__(
self,
device_type: Optional[str],
device_index: Optional[int],
channels_last: bool,
):
device_type = device_type or 'cpu'
self.device = torch.device(device_type) if device_index is None \
else torch.device(device_type, device_index)
self.world_size = 1 if self.world_size is None else self.world_size
self.local_rank = 0 if self.local_rank is None else self.local_rank
self.global_rank = 0 if self.global_rank is None else self.global_rank
if self.autocast is None:
self.autocast = suppress
if channels_last:
self.memory_format = torch.channels_last
@staticmethod
def is_instance():
return is_global_device()
@staticmethod
def instance():
# throws if called before global device is set / initialized
return get_global_device()
@property
def type(self) -> DeviceEnvType:
if self.device.type == 'cpu':
return DeviceEnvType.CPU
elif self.device.type == 'cuda':
return DeviceEnvType.CUDA
elif self.device.type == 'xla':
return DeviceEnvType.XLA
else:
assert False, "Unexpected device type for base DevEnv impl."
@property
def type_cuda(self):
# shortcut for common cuda device type
return self.type == DeviceEnvType.CUDA
@property
def type_xla(self):
# shortcut for common xla device type
return self.type == DeviceEnvType.XLA
@property
def distributed(self):
return self.world_size > 1
@property
def primary(self):
return self.local_rank == 0
@property
def global_primary(self):
return self.global_rank == 0
def wrap_distributed(self, *modules):
pass
def wrap_parallel(self, *modules):
pass
def to_cpu(self, *modules: torch.nn.Module):
moved = [m.cpu() for m in modules]
return moved[0] if len(moved) == 1 else moved
def to_device(self, *modules: torch.nn.Module):
# FIXME handling dtype? Do we want separate dtype for data vs model?
moved = [m.to(device=self.device, memory_format=self.memory_format) for m in modules]
return moved[0] if len(moved) == 1 else moved
def state_dict_to_cpu(self, state: Dict[str, Any]):
cpu_state = state_dict_apply(state, apply_fn=lambda x: x.cpu())
return cpu_state
def state_dict_to_device(self, state: Dict[str, Any]):
cpu_state = state_dict_apply(state, apply_fn=lambda x: x.to(self.device))
return cpu_state
def mark_step(self):
pass # NO-OP for non-XLA devices
def synchronize(self, tensors: Optional[TensorList] = None):
pass
def all_reduce_(self, tensor: TensorList, op=dist.ReduceOp.SUM, average=False):
dist.all_reduce(tensor, op=op)
if average:
tensor.div_(self.world_size)
return tensor
def all_reduce(self, tensor: torch.Tensor, op=dist.ReduceOp.SUM, average=False):
reduce_tensor = tensor.clone()
dist.all_reduce(reduce_tensor, op=op)
if average:
reduce_tensor = reduce_tensor / self.world_size
return reduce_tensor
def all_gather(self, tensor: torch.Tensor, cat_dim=0):
output_tensors = [torch.empty_like(tensor) for _ in range(self.world_size)]
dist.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, cat_dim)
def all_to_all(self, tensor: torch.Tensor, num_splits, split_dim, cat_dim=0):
input_tensors = torch.chunk(tensor, num_splits, split_dim)
output_tensors = [torch.empty_like(input_tensors[0]) for _ in range(self.world_size)]
dist.all_to_all(output_tensors, input_tensors)
return torch.cat(output_tensors, cat_dim)
def broadcast_(self, tensor: torch.Tensor, src_rank=0):
dist.broadcast(tensor, src=src_rank)
return tensor
def broadcast(self, tensor: Optional[torch.Tensor] = None, src_rank=0):
if self.global_rank != src_rank:
tensor = torch.empty_like(tensor)
assert tensor is not None
dist.broadcast(tensor, src=src_rank)
return tensor
def barrier(self):
dist.barrier()
# Global device environment singleton instance
_global_device_env: Optional[DeviceEnv] = None
def is_global_device():
return _global_device_env is not None
def get_global_device() -> DeviceEnv:
if not is_global_device():
raise RuntimeError('Please initialize device environment by calling initialize_device / set_global_device.')
return _global_device_env
def set_global_device(device: DeviceEnv):
global _global_device_env
if _global_device_env is not None:
raise RuntimeError('Global device is already set, it should NOT be set again.')
_global_device_env = device

@ -0,0 +1,68 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional
import torch
from torch.nn.parallel import DistributedDataParallel, DataParallel
from .device_env import DeviceEnv, DeviceEnvType, TensorList
def is_cuda_available():
return torch.cuda.is_available()
@dataclass
class DeviceEnvCuda(DeviceEnv):
def __post_init__(
self,
device_type: Optional[str],
device_index: Optional[int],
channels_last: bool,
):
assert torch.cuda.device_count()
torch.backends.cudnn.benchmark = True
setup_world_size = self.world_size or int(os.environ.get('WORLD_SIZE', 1))
assert setup_world_size
if setup_world_size > 1:
# setup distributed
assert device_index is None
if self.local_rank is None:
lr = os.environ.get('LOCAL_RANK', None)
if lr is None:
raise RuntimeError(
'At least one of LOCAL_RANK env variable or local_rank arg must be set to valid integer.')
self.local_rank = int(lr)
self.device = torch.device('cuda:%d' % self.local_rank)
torch.cuda.set_device(self.local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
self.world_size = torch.distributed.get_world_size()
assert self.world_size == setup_world_size
self.global_rank = torch.distributed.get_rank()
else:
self.device = torch.device('cuda' if device_index is None else f'cuda:{device_index}')
self.local_rank = 0
self.world_size = 1
self.global_rank = 0
if self.autocast is None:
self.autocast = torch.cuda.amp.autocast if self.amp else suppress
if channels_last:
self.memory_format = torch.channels_last
@property
def type(self) -> DeviceEnvType:
return DeviceEnvType.CUDA
def wrap_distributed(self, *modules, **kwargs):
wrapped = [DistributedDataParallel(m, device_ids=[self.local_rank], **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def wrap_parallel(self, *modules, **kwargs):
assert not self.distributed
wrapped = [DataParallel(m, **kwargs) for m in modules]
return wrapped[0] if len(wrapped) == 1 else wrapped
def synchronize(self, tensors: Optional[TensorList] = None):
torch.cuda.synchronize(self.device)

@ -0,0 +1,36 @@
import logging
from .device_env import DeviceEnv, is_global_device, get_global_device, set_global_device
from .device_env_cuda import DeviceEnvCuda, is_cuda_available
from .device_env_xla import DeviceEnvXla, is_xla_available
_logger = logging.getLogger(__name__)
def initialize_device(force_cpu: bool = False, **kwargs) -> DeviceEnv:
if is_global_device():
return get_global_device()
denv = None
if not force_cpu:
xla_device_type = kwargs.get('xla_device_type', None)
if is_xla_available(xla_device_type):
# XLA supports more than just TPU, will search in order TPU, GPU, CPU
denv = DeviceEnvXla(**kwargs)
elif is_cuda_available():
denv = DeviceEnvCuda(**kwargs)
# CPU fallback
if denv is None:
if is_xla_available('CPU'):
denv = DeviceEnvXla(device_type='CPU', **kwargs)
else:
denv = DeviceEnv()
_logger.info(f'Initialized device {denv.device}. '
f'Rank: {denv.global_rank} ({denv.local_rank}) of {denv.world_size}.')
print(denv) # FIXME temporary print for debugging
set_global_device(denv)
return denv

@ -0,0 +1,136 @@
import os
from contextlib import suppress
from dataclasses import dataclass, field, InitVar
from typing import Optional, Dict
import torch
from torch.distributed import ReduceOp
try:
import torch_xla.core.xla_model as xm
import torch_xla
_HAS_XLA = True
except ImportError as e:
xm = None
torch_xla = None
_HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .device_env import DeviceEnv, DeviceEnvType, TensorList
_PT_TO_XM_OP = {
ReduceOp.SUM: 'sum',
ReduceOp.PRODUCT: 'mul',
ReduceOp.MIN: 'min',
ReduceOp.MAX: 'max',
ReduceOp.BAND: 'and',
ReduceOp.BOR: 'or',
}
def is_xla_available(xla_device_type=None):
if not _HAS_XLA:
return False
supported_devs = xm.get_xla_supported_devices(devkind=xla_device_type)
return len(supported_devs) >= 1
@dataclass
class DeviceEnvXla(DeviceEnv):
def __post_init__(
self,
device_type: Optional[str],
device_idx: Optional[int],
channels_last: bool,
):
if device_type is not None:
device_type = device_type.upper()
assert device_type in ('TPU', 'GPU', 'CPU'), "XLA device type must be one of ('TPU', 'GPU', 'CPU')"
self.device = xm.xla_device(n=device_idx, devkind=device_type)
self.world_size = xm.xrt_world_size()
if self.distributed:
assert device_idx is None, "device_index is based on local rank for distributed XLA mode"
self.local_rank = xm.get_local_ordinal()
self.global_rank = xm.get_ordinal()
else:
self.local_rank = 0
self.global_rank = 0
if self.amp:
assert xa is not None, 'XLA AMP is not present on this build'
if self.autocast is None:
self.autocast = xa.autocast if self.amp else suppress
if channels_last:
self.memory_format = torch.channels_last
@property
def type(self) -> DeviceEnvType:
return DeviceEnvType.XLA
def wrap_distributed(self, *modules):
wrapped = [m for m in modules] # NO-OP
return wrapped[0] if len(wrapped) == 1 else wrapped
def wrap_parallel(self, *modules):
assert False, "Not implemented"
def mark_step(self):
xm.mark_step()
def synchronize(self, tensors: Optional[TensorList] = None):
torch_xla._XLAC._xla_sync_multi(tensors, devices=[], wait=True, sync_xla_data=True)
def all_reduce(self, tensor: torch.Tensor, op=ReduceOp.SUM, average=False):
assert isinstance(tensor, torch.Tensor) # unlike in-place variant, lists/tuples not allowed
op = _PT_TO_XM_OP[op]
scale = 1.0 / self.world_size if average else 1.0
return xm.all_reduce(op, tensor, scale=scale)
def all_reduce_(self, tensor: TensorList, op=ReduceOp.SUM, average=False):
op = _PT_TO_XM_OP[op]
scale = 1.0 / self.world_size if average else 1.0
wrapped = False
if isinstance(tensor, torch.Tensor):
tensor = [tensor] # bare tensors are not operated on in-place
wrapped = True
xm.all_reduce(op, tensor, scale=scale)
if wrapped:
tensor = tensor[0]
return tensor
def all_gather(self, tensor: torch.Tensor, cat_dim=0):
output = xm.all_gather(tensor, cat_dim)
return output
def all_to_all(self, tensor, num_splits, split_dim, cat_dim=0):
output = xm.all_to_all(tensor, split_dim, cat_dim, num_splits)
return output
def broadcast(self, tensor: torch.Tensor, src_rank=0):
if self.global_rank != src_rank:
reduce_tensor = torch.zeros_like(tensor)
xm.all_reduce('sum', reduce_tensor)
else:
xm.all_reduce('sum', tensor)
return tensor
def broadcast_(self, tensor: torch.Tensor, src_rank=0):
out_tensor = self.broadcast(tensor, src_rank)
return tensor.copy_(out_tensor)
def barrier(self):
xm.rendezvous('timm.bits.dist_barrier')
def state_dict_to_cpu(self, state: Dict[str, torch.Tensor]):
cpu_state = xm._maybe_convert_to_cpu(state, convert=True)
return cpu_state
def state_dict_to_device(self, state: Dict[str, torch.Tensor]):
device_state = xm.send_cpu_data_to_device(state, device=self.device)
return device_state

@ -0,0 +1,150 @@
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
from torch.distributed import ReduceOp
from timm.utils import unwrap_model
from .device_env import DeviceEnv
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def _validate_type(tensor: TensorSeq):
if isinstance(tensor, (dict, list, tuple)):
if not tensor:
return
else:
assert isinstance(tensor, torch.Tensor)
def distribute_bn(model: torch.nn.Module, reduce: bool = False, dev_env: DeviceEnv = None):
if dev_env is None:
dev_env = DeviceEnv.instance()
# ensure every node has the same running bn stats
for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
dev_env.all_reduce_(bn_buf, average=True)
else:
# broadcast bn stats from rank 0 to whole group
dev_env.broadcast_(bn_buf, 0)
def all_gather_recursive(tensor: TensorSeq, cat_dim=0, dev_env: DeviceEnv = None):
""" Recursive all gather via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.all_gather(tensor, cat_dim=cat_dim)
elif isinstance(tensor, dict):
return {k: all_gather_recursive(v, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_gather_recursive(v, dev_env=dev_env) for v in tensor)
def all_reduce_recursive(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
""" Recursive all reduce via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.all_reduce_(tensor, op=op, average=average)
elif isinstance(tensor, dict):
return {k: all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(all_reduce_recursive(v, op=op, average=average, dev_env=dev_env) for v in tensor)
def broadcast_recursive(tensor: TensorSeq, src_rank: int, dev_env: DeviceEnv = None):
""" Recursive broadcast via DeviceEnv distributed primitives
FIXME add group support
"""
_validate_type(tensor)
if dev_env is None:
dev_env = DeviceEnv.instance()
if isinstance(tensor, torch.Tensor):
return dev_env.broadcast_(tensor, src_rank=src_rank)
elif isinstance(tensor, dict):
return {k: broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for k, v in tensor.items()}
elif isinstance(tensor, (tuple, list)):
return type(tensor)(broadcast_recursive(v, src_rank=src_rank, dev_env=dev_env) for v in tensor)
def all_gather_sequence(tensor: TensorSeq, cat_dim: int = 0, dev_env: DeviceEnv = None):
""" All gather a flat Tensor sequence (dict, list, tuple) of same shape
"""
_validate_type(tensor)
if dev_env is None:
dev_env = DeviceEnv.instance()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
gather_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
gather_values = tensor
else:
gather_values = (tensor,)
gather_values = torch.stack(gather_values, dim=0)
gather_values = dev_env.all_gather(gather_values, cat_dim=cat_dim + 1).unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
gather_values = {k: v for k, v in zip(names, gather_values)}
elif isinstance(tensor, (tuple, list)):
gather_values = type(tensor)(v for v in gather_values)
else:
gather_values = gather_values[0]
return gather_values
def all_reduce_sequence(tensor: TensorSeq, op=ReduceOp.SUM, average=False, dev_env: DeviceEnv = None):
"""
All reduce the tensors in a flat Tensor sequence (dict, list, tuple) of same tensor shape
Args:
tensor (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
_validate_type(tensor)
if dev_env is None:
dev_env = DeviceEnv.instance()
with torch.no_grad():
names = None
# merge values into one tensor for reduction
if isinstance(tensor, dict):
names = tensor.keys()
reduce_values = tuple(tensor.values())
elif isinstance(tensor, (tuple, list)):
reduce_values = tensor
else:
reduce_values = (tensor,)
reduce_values = torch.stack(reduce_values, dim=0)
dev_env.all_reduce_(reduce_values, op=op, average=average)
reduce_values = reduce_values.unbind(dim=0)
# separate reduced values into original structure
if isinstance(tensor, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(tensor, (tuple, list)):
reduce_values = type(tensor)(v for v in reduce_values)
else:
reduce_values = reduce_values[0]
return reduce_values

@ -0,0 +1,190 @@
""" PyTorch distributed helpers
Some of this lifted from Detectron2 with other fns added by myself.
FIXME many functions remain unfinished/untested
"""
from typing import Dict, Tuple, List, Union, Any, Callable
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
TensorSeq = Union[torch.Tensor, Tuple[torch.Tensor, ...], List[torch.Tensor], Dict[Any, torch.Tensor]]
def synchronize_torch():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
def all_reduce_sequence_torch(values: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
dist.all_reduce(reduce_values, op=op, group=group)
if average:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def reduce_sequence_torch(values: TensorSeq, dst_rank=0, op=ReduceOp.SUM, average=False, group=None):
"""
All reduce the tensors in a sequence (dict, list, tuple)
Args:
values (dict): inputs to be reduced. All the values must be scalar Tensor.
average (bool): whether to do average or sum
Returns:
a sequence with the same type as input (dict, list, tuple)
"""
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
if world_size <= 1:
return values
with torch.no_grad():
names = None
if isinstance(values, dict):
names = values.keys()
reduce_values = torch.stack(tuple(values.values()), dim=0)
elif isinstance(values, (tuple, list)):
reduce_values = torch.stack(values, dim=0)
else:
reduce_values = values
reduce_values = torch.stack(reduce_values, dim=0)
dist.reduce(reduce_values, dst=dst_rank, op=op, group=group)
if average and this_rank == dst_rank:
reduce_values /= world_size
if isinstance(values, dict):
reduce_values = {k: v for k, v in zip(names, reduce_values)}
elif isinstance(values, (tuple, list)):
reduce_values = type(values)(v for v in reduce_values)
return reduce_values
def all_gather_sequence_torch(values: TensorSeq, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
def _do_gather(tensor):
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.all_gather(tensor_list, tensor, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def gather_sequence_torch(values: TensorSeq, dst_rank, group=None, join_fn=torch.cat, join_dim=0):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank(group)
def _do_gather(tensor):
tensor_list = None
if this_rank == dst_rank:
tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
dist.gather(tensor, tensor_list, dst=dst_rank, group=group)
return join_fn(tensor_list, dim=join_dim)
if isinstance(values, dict):
gathered = {k: _do_gather(v) for k, v in values.items()}
return gathered
elif isinstance(values, (list, tuple)):
gathered = type(values)(_do_gather(v) for v in values)
return gathered
else:
# if not a dict, list, tuple, expect a singular tensor
assert isinstance(values, torch.Tensor)
return _do_gather(values)
def all_gather_torch(value: TensorSeq, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.all_gather(out_tensors, value, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: all_gather_torch(v, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_gather_torch(v, group, join_fn, join_dim) for v in value)
def gather_torch(value: TensorSeq, dst_rank=0, group=None, join_fn: Callable = None, join_dim=0):
if isinstance(value, torch.Tensor):
world_size = dist.get_world_size(group)
this_rank = dist.get_rank()
out_tensors = None
if this_rank == dst_rank:
out_tensors = [torch.empty_like(value) for _ in range(world_size)]
dist.gather(value, out_tensors, dst=dst_rank, group=group)
if join_fn is not None:
out_tensors = join_fn(out_tensors, dim=join_dim)
return out_tensors
elif isinstance(value, dict):
return {k: gather_torch(v, dst_rank, group, join_fn, join_dim) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(gather_torch(v, dst_rank, group, join_fn, join_dim) for v in value)
def all_reduce_torch(value: TensorSeq, op=ReduceOp.SUM, average=False, group=None):
if isinstance(value, torch.Tensor):
dist.all_reduce(value, op=op, group=group)
if average:
value /= dist.get_world_size(group)
elif isinstance(value, dict):
return {k: all_reduce_torch(v, op=op, average=average, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(all_reduce_torch(v, op=op, average=average, group=group) for v in value)
def broadcast_torch(value: TensorSeq, src_rank: int = 0, group=None):
if isinstance(value, torch.Tensor):
return dist.broadcast(value, src=src_rank, group=group)
elif isinstance(value, dict):
return {k: broadcast_torch(v, src_rank=src_rank, group=group) for k, v in value.items()}
elif isinstance(value, (tuple, list)):
return type(value)(broadcast_torch(v, src_rank=src_rank, group=group) for v in value)

@ -0,0 +1,26 @@
from functools import partial
import torch
from timm.utils.agc import adaptive_clip_grad
def get_clip_grad_fn(mode: str = 'norm', norm_type: float = 2.0):
if mode == 'norm':
return partial(torch.nn.utils.clip_grad_norm_, norm_type=norm_type)
elif mode == 'value':
return torch.nn.utils.clip_grad_value_
elif mode == 'agc':
return partial(adaptive_clip_grad, norm_type=norm_type)
else:
assert False, f"Unknown clip mode ({mode})."
def get_clip_parameters(model, skip_last=0):
if hasattr(model, 'get_clip_parameters'):
return model.get_clip_parameters()
else:
if skip_last:
return list(model.parameters())[::-skip_last]
else:
return model.parameters()

@ -0,0 +1,145 @@
import abc
from typing import Callable, Union, Optional, List, Tuple, Dict
from dataclasses import dataclass
import torch
from torch.distributed import ReduceOp
from .device_env import DeviceEnv
from .distributed import all_gather_sequence, all_reduce_sequence
MetricValueT = Union[float, torch.Tensor, List[float], List[torch.Tensor]]
@dataclass
class ValueInfo:
initial: Optional[MetricValueT] = 0.
dtype: torch.dtype = torch.float32
dist_reduce: str = 'sum'
dist_average: bool = False
class Metric(abc.ABC):
def __init__(
self,
dev_env: DeviceEnv = None
):
self._infos: Dict[str, ValueInfo] = {}
self._values: Dict[str, Optional[MetricValueT]] = {}
self._values_dist: Dict[str, Optional[MetricValueT]] = {}
if dev_env is None:
dev_env = DeviceEnv.instance()
self._dev_env = dev_env
def _register_value(self, name: str, info: Optional[ValueInfo] = None):
info = info or ValueInfo()
self._infos[name] = info
# def get_value(self, name: str, use_dist=True):
# if use_dist:
# return self._values_dist.get(name, self._values.get(name))
# else:
# return self._values.get(name)
def __getattr__(self, item):
if item not in self._infos:
raise AttributeError
value = self._values_dist.get(item, self._values.get(item, None))
return value
def __setattr__(self, key, value):
if '_infos' in self.__dict__ and key in self._infos:
self._values[key] = value
else:
super().__setattr__(key, value)
def update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
self._update(predictions, target)
def _update(
self,
predictions: Union[torch.Tensor, Dict[str, torch.Tensor]],
target: Union[torch.Tensor, Dict[str, torch.Tensor]]):
pass
def reset(self):
self._values = {}
self._values_dist = {}
for name, info in self._infos.items():
# if info specifies an initial value, we reset here, otherwise set to None and leave it to child class
if info.initial is not None:
if isinstance(info.initial, torch.Tensor):
tensor = info.initial.detach().clone()
else:
tensor = torch.ones([], dtype=info.dtype) * info.initial # scalar
self._values[name] = tensor.to(device=self._dev_env.device, dtype=info.dtype)
else:
self._values[name] = None
self._reset()
def _reset(self):
pass
def compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
if self._dev_env.distributed:
self._distribute_values()
results = self._compute()
self._values_dist = {}
return results
@abc.abstractmethod
def _compute(self) -> Union[torch.Tensor, Tuple[torch.Tensor, ...], Dict[str, torch.Tensor]]:
pass
def _distribute_values(self):
if not self._infos or not self._values:
return
def _args(op: str):
if op == 'cat':
return True, dict(cat_dim=0)
else:
return False, dict(op=ReduceOp.SUM)
prev_dsr = None
same_dsr = True
names = []
values = []
reductions = []
for name, value in self._values.items():
if value is not None:
info = self._infos[name]
dsr = (value.dtype, value.shape, info.dist_reduce)
if prev_dsr is not None and prev_dsr != dsr:
same_dsr = False
prev_dsr = dsr
names.append(name)
values.append(value)
reductions.append(_args(info.dist_reduce))
if same_dsr:
do_gather, reduce_kwargs = reductions[0]
if do_gather:
reduced_values = all_gather_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
else:
reduced_values = all_reduce_sequence(values, dev_env=self._dev_env, **reduce_kwargs)
for name, reduced_value in zip(names, reduced_values):
info = self._infos[name]
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[name] = reduced_value
else:
for n, v, r in zip(names, values, reductions):
info = self._infos[n]
do_gather, reduce_kwargs = r
if do_gather:
reduced_value = self._dev_env.all_gather(v, **reduce_kwargs)
else:
reduced_value = self._dev_env.all_reduce(v, **reduce_kwargs)
if info.dist_average:
reduced_value /= self._dev_env.world_size
self._values_dist[n] = reduced_value

@ -0,0 +1,71 @@
import torch
from typing import Optional, Tuple, Dict
from .device_env import DeviceEnv
from .metric import Metric, ValueInfo
class Accuracy(Metric):
def __init__(
self,
threshold=0.5,
multi_label=False,
accumulate_dtype=torch.float32,
dev_env=None,
):
super().__init__(dev_env=dev_env)
self.accumulate_dtype = accumulate_dtype
self.threshold = threshold
self.eps = 1e-8
self.multi_label = multi_label
# statistics / counts
self._register_value('correct', ValueInfo(dtype=accumulate_dtype))
self._register_value('total', ValueInfo(dtype=accumulate_dtype))
def _update(self, predictions, target):
raise NotImplemented()
def _compute(self):
raise NotImplemented()
class AccuracyTopK(Metric):
def __init__(
self,
topk=(1, 5),
accumulate_dtype=torch.float32,
dev_env: DeviceEnv = None
):
super().__init__(dev_env=dev_env)
self.accumulate_dtype = accumulate_dtype
self.eps = 1e-8
self.topk = topk
self.maxk = max(topk)
# statistics / counts
for k in self.topk:
self._register_value(f'top{k}', ValueInfo(dtype=accumulate_dtype))
self._register_value('total', ValueInfo(dtype=accumulate_dtype))
self.reset()
def _update(self, predictions: torch.Tensor, target: torch.Tensor):
batch_size = predictions.shape[0]
sorted_indices = predictions.topk(self.maxk, dim=1)[1]
target_reshape = target.reshape(-1, 1).expand_as(sorted_indices)
correct = sorted_indices.eq(target_reshape).to(dtype=self.accumulate_dtype).sum(0)
for k in self.topk:
attr_name = f'top{k}'
correct_at_k = correct[:k].sum()
setattr(self, attr_name, getattr(self, attr_name) + correct_at_k)
self.total += batch_size
def _compute(self) -> Dict[str, torch.Tensor]:
assert self.total is not None
output = {}
for k in self.topk:
attr_name = f'top{k}'
output[attr_name] = 100 * getattr(self, attr_name) / self.total
return output

@ -0,0 +1,117 @@
import torch
import torch.nn.functional as F
class PrecisionRecall:
def __init__(self, threshold=0.5, multi_label=False, device=None):
self.threshold = threshold
self.device = device
self.multi_label = multi_label
# statistics
# the total number of true positive instances under each class
# Shape: (num_classes, )
self._tp_sum = None
# the total number of instances
# Shape: (num_classes, )
self._total_sum = None
# the total number of instances under each _predicted_ class,
# including true positives and false positives
# Shape: (num_classes, )
self._pred_sum = None
# the total number of instances under each _true_ class,
# including true positives and false negatives
# Shape: (num_classes, )
self._true_sum = None
self.reset()
def reset(self):
self._tp_sum = None
self._total_sum = None
self._pred_sum = None
self._true_sum = None
def update(self, predictions, target):
output_type = predictions.type()
num_classes = predictions.size(-1)
if self.multi_label:
if self.threshold is not None:
predictions = (predictions > self.threshold).type(output_type)
predictions = predictions.t().reshape(num_classes, -1)
target = target.t().reshape(num_classes, -1)
else:
target = F.one_hot(target.view(-1), num_classes=num_classes)
indices = torch.argmax(predictions, dim=1).view(-1)
predictions = F.one_hot(indices, num_classes=num_classes)
# FIXME make sure binary case works
target = target.type(output_type)
correct = (target * predictions > 0).type(output_type)
pred_positives = predictions.sum(dim=0)
target_positives = target.sum(dim=0)
if correct.sum() == 0:
true_positives = torch.zeros_like(pred_positives)
else:
true_positives = correct.sum(dim=0)
if self._tp_sum is None:
self._tp_sum = torch.zeros(num_classes, device=self.device)
self._true_sum = torch.zeros(num_classes, device=self.device)
self._pred_sum = torch.zeros(num_classes, device=self.device)
self._total_sum = torch.tensor(0, device=self.device)
self._tp_sum += true_positives
self._pred_sum += pred_positives
self._true_sum += target_positives
self._total_sum += target.shape[0]
def counts_as_tuple(self, reduce=False):
tp_sum = self._tp_sum
pred_sum = self._pred_sum
true_sum = self._true_sum
total_sum = self._total_sum
if reduce:
tp_sum = reduce_tensor_sum(tp_sum)
pred_sum = reduce_tensor_sum(pred_sum)
true_sum = reduce_tensor_sum(true_sum)
total_sum = reduce_tensor_sum(total_sum)
return tp_sum, pred_sum, true_sum, total_sum
def counts(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
return dict(tp_sum=tp_sum, pred_sum=pred_sum, true_sum=true_sum, total_sum=total_sum)
def confusion(self, reduce=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=reduce)
fp = pred_sum - tp_sum
fn = true_sum - tp_sum
tp = tp_sum
tn = total_sum - tp - fp - fn
return dict(tp=tp, fp=fp, fn=fn, tn=tn)
def compute(self, fscore_beta=1, average='micro', no_reduce=False, distributed=False):
tp_sum, pred_sum, true_sum, total_sum = self.counts_as_tuple(reduce=distributed)
if average == 'micro':
tp_sum = tp_sum.sum()
pred_sum = pred_sum.sum()
true_sum = true_sum.sum()
precision = tp_sum / pred_sum
recall = tp_sum / true_sum
beta_sq = fscore_beta ** 2
f1_denom = beta_sq * precision + recall
fscore = (1 + beta_sq) * precision * recall / f1_denom
if average == 'macro' and not no_reduce:
precision = precision.mean()
recall = recall.mean()
fscore = fscore.mean()
return dict(fscore=fscore, precision=precision, recall=recall)
return dict(fscore=fscore, precision=precision, recall=recall)

@ -0,0 +1,238 @@
import csv
import logging
import os
from collections import OrderedDict
from typing import Optional, Tuple, Dict, Union
import torch
_logger = logging.getLogger(__name__)
try:
from torch.utils.tensorboard import SummaryWriter
HAS_TB = True
except ImportError as e:
HAS_TB = False
try:
import wandb
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
# FIXME old formatting for reference, to remove
#
# def log_eval(batch_idx, last_idx, batch_time, loss, top1, top5, log_suffix=''):
# log_name = 'Test' + log_suffix
# logging.info(
# f'{log_name}: [{batch_idx:>4d}/{last_idx}] '
# f'Time: {batch_time.smooth_val:.3f} ({batch_time.avg:.3f}) '
# f'Loss: {loss.smooth_val:>7.4f} ({loss.avg:>6.4f}) '
# f'Acc@1: {top1.smooth_val:>7.4f} ({top1.avg:>7.4f}) '
# f'Acc@5: {top5.smooth_val:>7.4f} ({top5.avg:>7.4f})'
# )
#
#
# def log_train(epoch, step, num_steps, loss, batch_size, batch_time, data_time, lr, world_size=1):
# last_step = max(0, num_steps - 1)
# progress = 100. * step / last_step if last_step else 0.
# log_str = f'Train: {epoch} [{step:>4d}/{num_steps} ({progress:>3.0f}%)]' \
# f' Time: {batch_time.smooth_val:.3f}s, {batch_size * world_size / batch_time.smooth_val:>7.2f}/s' \
# f' ({batch_time.avg:.3f}s, {batch_size * world_size / batch_time.avg:>7.2f}/s)' \
# f' Data: {data_time.smooth_val:.3f} ({data_time.avg:.3f})'
# log_str += f' Loss: {loss.smooth_val:>9.6f} ({loss.avg:>6.4f}) '
# log_str += f' LR: {lr:.3e} '
def summary_row_dict(results, index=None, index_name='epoch'):
assert isinstance(results, dict)
row_dict = OrderedDict()
if index is not None:
row_dict[index_name] = index
if not results:
return row_dict
if isinstance(next(iter(results.values())), dict):
# each key in results is a per-phase results dict, flatten by prefixing with phase name
for p, pr in results.items():
assert isinstance(pr, dict)
row_dict.update([('_'.join([p, k]), v) for k, v in pr.items()])
else:
row_dict.update(results)
return row_dict
class SummaryCsv:
def __init__(self, output_dir, filename='summary.csv'):
self.output_dir = output_dir
self.filename = os.path.join(output_dir, filename)
self.needs_header = not os.path.exists(self.filename)
def update(self, row_dict):
with open(self.filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=row_dict.keys())
if self.needs_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
self.needs_header = False
dw.writerow(row_dict)
_sci_keys = {'lr'}
def _add_kwargs(text_update, name_map=None, **kwargs):
def _to_str(key, val):
if isinstance(val, float):
if key.lower() in _sci_keys:
return f'{key}: {val:.3e} '
else:
return f'{key}: {val:.4f}'
else:
return f'{key}: {val}'
def _map_name(key, name_map, capitalize=True):
if name_map is None:
if capitalize:
return key.capitalize() if not key.isupper() else key
else:
return key
return name_map.get(key, None)
for k, v in kwargs.items():
if isinstance(v, dict):
# log each k, v of a dict kwarg as separate items
for kk, vv in v.items():
name = _map_name(kk, name_map)
if not name:
continue
text_update += [_to_str(kk, vv)]
else:
name = _map_name(k, name_map, capitalize=True)
if not name:
continue
text_update += [_to_str(name, v)]
class Monitor:
def __init__(
self,
experiment_name=None,
output_dir=None,
logger=None,
hparams=None,
log_wandb=False,
output_enabled=True,
):
self.output_dir = output_dir # for tensorboard, csv, text file (TODO) logging
self.logger = logger or logging.getLogger('log')
hparams = hparams or {}
# Setup CSV writer(s)
if output_dir is not None:
self.csv_writer = SummaryCsv(output_dir=output_dir)
else:
self.csv_writer = None
# Setup Tensorboard
self.summary_writer = None # FIXME tensorboard
# Setup W&B
self.wandb_run = None
if log_wandb:
if HAS_WANDB:
self.wandb_run = wandb.init(project=experiment_name, config=hparams)
else:
_logger.warning("You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
self.output_enabled = output_enabled
# FIXME image save
def log_step(
self,
phase: str,
step_idx: int,
step_end_idx: Optional[int] = None,
epoch: Optional[int] = None,
loss: Optional[float] = None,
rate: Optional[Union[float, Tuple[float, float]]] = None,
phase_suffix: str = '',
**kwargs,
):
""" log train/eval step
"""
if not self.output_enabled:
return
if 'num_steps' in kwargs:
step_end_idx = max(0, kwargs.pop('num_steps') - 1)
phase_title = f'{phase.capitalize()} ({phase_suffix})' if phase_suffix else f'{phase.capitalize()}:'
progress = 100. * step_idx / step_end_idx if step_end_idx else 0.
rate_str = ''
if isinstance(rate, (tuple, list)):
rate_str = f'Rate: {rate[0]:.2f}/s ({rate[1]:.2f}/s)'
elif rate is not None:
rate_str = f'Rate: {rate:.2f}/s'
text_update = [
phase_title,
f'{epoch}' if epoch is not None else None,
f'[{step_idx}]' if step_end_idx is None else None,
f'[{step_idx}/{step_end_idx} ({progress:>3.0f}%)]' if step_end_idx is not None else None,
rate_str,
f'Loss: {loss:.5f}' if loss is not None else None,
]
_add_kwargs(text_update, **kwargs)
log_str = ' '.join(item for item in text_update if item)
self.logger.info(log_str)
if self.summary_writer is not None:
# FIXME log step values to tensorboard
pass
def log_phase(
self,
phase: str = 'eval',
epoch: Optional[int] = None,
name_map: Optional[dict] = None,
**kwargs
):
"""log completion of evaluation or training phase
"""
if not self.output_enabled:
return
title = [
f'{phase.capitalize()}',
f'epoch: {epoch}' if epoch is not None else None,
'completed. ',
]
title_str = ' '.join(i for i in title if i)
results = []
_add_kwargs(results, name_map=name_map, **kwargs)
log_str = title_str + ', '.join(item for item in results if item)
self.logger.info(log_str)
def write_summary(
self,
results: Dict, # Dict or Dict of Dict where first level keys are treated as per-phase results
index: Optional[Union[int, str]] = None,
index_name: str = 'epoch',
):
""" Log complete results for all phases (typically called at end of epoch)
Args:
results (dict or dict[dict]): dict of results to write, or multiple dicts where first level
key is the name of results dict for each phase
index: value for row index (typically epoch #)
index_name: name for row index header (typically 'epoch')
"""
if not self.output_enabled:
return
row_dict = summary_row_dict(index=index, index_name=index_name, results=results)
if self.csv_writer:
self.csv_writer.update(row_dict)
if self.wandb_run is not None:
wandb.log(row_dict)
if self.summary_writer:
# FIXME log epoch summaries to tensorboard
pass

@ -0,0 +1,59 @@
import time
from typing import Optional
from .avg_scalar import AvgMinMaxScalar
class Tracker:
def __init__(self):
self.data_time = AvgMinMaxScalar() # time for data loader to produce batch of samples
self.step_time = AvgMinMaxScalar() # time for model step
self.iter_time = AvgMinMaxScalar() # full iteration time incl. data, step, and book-keeping
self.epoch_time = AvgMinMaxScalar()
self.iter_timestamp: Optional[float] = None
self.prev_timestamp: Optional[float] = None
self.epoch_timestamp: Optional[float] = None
def _measure_iter(self, ref_timestamp=None):
timestamp = time.perf_counter()
self.prev_timestamp = timestamp
def mark_iter(self):
timestamp = time.perf_counter()
if self.iter_timestamp is not None:
iter_time = timestamp - self.iter_timestamp
self.iter_time.update(iter_time)
self.iter_timestamp = self.prev_timestamp = timestamp
def mark_iter_data_end(self):
assert self.prev_timestamp is not None
timestamp = time.perf_counter()
data_time = timestamp - self.prev_timestamp
self.data_time.update(data_time)
self.prev_timestamp = timestamp
def mark_iter_step_end(self):
assert self.prev_timestamp is not None
timestamp = time.perf_counter()
step_time = timestamp - self.prev_timestamp
self.step_time.update(step_time)
self.prev_timestamp = timestamp
def mark_epoch(self):
timestamp = time.perf_counter()
if self.epoch_timestamp is not None:
epoch_time = timestamp - self.epoch_timestamp
self.epoch_time.update(epoch_time)
self.epoch_timestamp = timestamp
def get_avg_iter_rate(self, num_per_iter: int):
if num_per_iter == 0 or self.iter_time.avg == 0:
return 0
return num_per_iter / self.iter_time.avg
def get_last_iter_rate(self, num_per_iter: int):
if num_per_iter == 0 or self.iter_time.val == 0:
return 0
return num_per_iter / self.iter_time.val

@ -0,0 +1,12 @@
from dataclasses import dataclass
@dataclass
class TrainCfg:
""" Train Loop Configuration
Dataclass to hold training configuration values
"""
num_epochs: int = 100
log_interval: int = 50
recovery_interval: int = 0
accumulate_steps: int = 0

@ -0,0 +1,13 @@
from dataclasses import dataclass
from .monitor import Monitor
from .checkpoint_manager import CheckpointManager
@dataclass
class TrainServices:
""" Train Loop Services
"""
monitor: Monitor = None
checkpoint: CheckpointManager = None

@ -0,0 +1,151 @@
import dataclasses
from typing import Callable, Union, Optional
import logging
import torch
import torch.nn as nn
from timm.models.layers import convert_sync_batchnorm
from timm.optim import create_optimizer_v2
from timm.utils import ModelEmaV2
try:
import deepspeed as ds
except ImportError:
ds = None
from .checkpoint import load_train_state
from .device_env import DeviceEnv
from .train_cfg import TrainCfg
from .train_state import TrainState
from .updater_factory import create_updater
_logger = logging.getLogger(__name__)
def setup_model_and_optimizer(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
deepspeed: bool = False,
):
"""
Args:
dev_env:
model:
optimizer:
optimizer_cfg:
clip_value:
clip_fn:
model_ema:
model_ema_decay:
use_syncbn:
resume_path:
resume_opt:
Returns:
"""
if deepspeed:
return setup_model_and_optimizer_deepspeed(
dev_env=dev_env, model=model, optimizer=optimizer, optimizer_cfg=optimizer_cfg,
clip_fn=clip_fn, clip_value=clip_value, model_ema=model_ema, model_ema_decay=model_ema_decay,
resume_path=resume_path, resume_opt=resume_opt,
)
dev_env.to_device(model)
if use_syncbn and dev_env.distributed:
model = convert_sync_batchnorm(model)
if dev_env.primary:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
if isinstance(optimizer, Callable):
# FIXME this interface needs to be figured out, model, model and/or parameters, or just parameters?
optimizer = optimizer(model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model, **optimizer_cfg)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
)
# ema model
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
load_train_state(
train_state,
resume_path,
load_opt=resume_opt,
log_info=dev_env.primary)
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state
def setup_model_and_optimizer_deepspeed(
dev_env: DeviceEnv,
model: nn.Module,
optimizer: Union[Callable, str],
optimizer_cfg,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
model_ema: bool = False,
model_ema_decay: float = 0.9999,
use_syncbn: bool = False,
resume_path: str = '',
resume_opt: bool = True,
):
dev_env.to_device(model)
if isinstance(optimizer, Callable):
optimizer = optimizer(model=model, **optimizer_cfg)
else:
optimizer = create_optimizer_v2(model=model, **optimizer_cfg)
model = ds.initialize(model=model, optimizer=optimizer, dist_init_required=False)
updater = create_updater(
model=model,
optimizer=optimizer,
clip_fn=clip_fn,
clip_value=clip_value,
deepspeed=True,
)
# ema model
# FIXME how to do EMA w/ deepspeed?
model_ema = ModelEmaV2(model, decay=model_ema_decay) if model_ema else None
train_state = TrainState(model=model, updater=updater, model_ema=model_ema)
if resume_path:
# FIXME deepspeed resumes differently
assert False
if dev_env.distributed:
train_state = dataclasses.replace(
train_state, model=dev_env.wrap_distributed(train_state.model))
return train_state

@ -0,0 +1,64 @@
from dataclasses import dataclass
from typing import Dict, Any
from torch import nn as nn
from timm.scheduler import Scheduler
from timm.utils import get_state_dict, unwrap_model
from .train_cfg import TrainCfg
from .updater import Updater
@dataclass
class TrainState:
model: nn.Module = None
train_loss: nn.Module = None
eval_loss: nn.Module = None
updater: Updater = None
lr_scheduler: Scheduler = None
model_ema: nn.Module = None
train_cfg: TrainCfg = TrainCfg()
# FIXME collect & include other cfg like data & model so it's in one spot for checkpoints / logging / debugging?
epoch: int = 0
step_count: int = 0
step_count_global: int = 0
def __post_init__(self):
assert self.model is not None
assert self.updater is not None
def state_dict(self, unwrap_fn=unwrap_model):
state = dict(
# train loop state (counters, etc), saved and restored
epoch=self.epoch,
step_count=self.step_count,
step_count_global=self.step_count_global,
# model params / state, saved and restored
model=get_state_dict(self.model, unwrap_fn),
model_ema=None if self.model_ema is None else get_state_dict(self.model_ema, unwrap_fn),
# configuration, saved but currently not restored, determined by args / config file for each run
train_cfg=vars(self.train_cfg)
)
# FIXME include lr_scheduler state?
state.update(self.updater.state_dict()) # updater (optimizer, scaler, etc.) state added to state
return state
def load_state_dict(self, state_dict, unwrap_fn=unwrap_model, load_opt=True):
# restore train loop state
self.epoch = state_dict['epoch'] + 1
self.step_count = 0 # FIXME need more logic to restore part way through epoch
self.step_count_global = state_dict['step_count_global']
# restore model params / state
unwrap_fn(self.model).load_state_dict(state_dict.get('model'))
if 'model_ema' in state_dict and self.model_ema is not None:
unwrap_fn(self.model_ema).load_state_dict(state_dict.get('model_ema'))
# restore optimizer state
if load_opt:
self.updater.load_state_dict(state_dict)

@ -0,0 +1,72 @@
from dataclasses import dataclass, field, InitVar
from functools import partial
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from .grad_clip import get_clip_grad_fn, get_clip_parameters
@dataclass
class Updater:
model: nn.Module = None
optimizer: torch.optim.Optimizer = None # FIXME handle multiple optimizers per-model
clip_fn: Optional[Union[Callable, str]] = None
clip_value: Optional[float] = None
clip_params_fn: Optional[Callable] = None
grad_scaler: Optional[Callable] = None
create_graph: Optional[bool] = None
after_step_closure: bool = False
def __post_init__(self):
assert self.model is not None
assert self.optimizer is not None
if self.clip_fn is not None:
if isinstance(self.clip_fn, Callable):
skip_last = 0
else:
assert isinstance(self.clip_fn, str)
skip_last = 2 if 'agc' in self.clip_fn else 0
self.clip_fn = get_clip_grad_fn(self.clip_fn)
assert self.clip_value is not None
self.clip_params_fn = partial(get_clip_parameters, model=self.model, skip_last=skip_last)
if self.create_graph is None:
self.create_graph = getattr(self.optimizer, 'second_order', False)
self.after_step_closure = False
def reset(self):
self.optimizer.zero_grad()
def apply(self, loss: torch.Tensor, accumulate=False):
loss.backward(create_graph=self.create_graph)
if accumulate:
return
if self.clip_fn is not None:
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.optimizer.step()
self.reset()
def get_average_lr(self):
lrl = [param_group['lr'] for param_group in self.optimizer.param_groups if param_group['lr'] > 0]
return sum(lrl) / len(lrl)
def state_dict(self):
state_dict = dict(optimizer=self.optimizer.state_dict())
if self.grad_scaler is not None:
state_dict['grad_scaler'] = self.grad_scaler.state_dict()
return state_dict
def load_state_dict(self, state_dict):
if 'optimizer' in state_dict:
self.optimizer.load_state_dict(state_dict['optimizer'])
if 'grad_scaler' in state_dict and self.grad_scaler is not None:
self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
def after_step(self, after_step_fn, *args):
after_step_fn(*args)
@property
def deepspeed(self):
return False

@ -0,0 +1,30 @@
from dataclasses import dataclass, field, InitVar
from typing import Dict, Any
import torch
from .updater import Updater
@dataclass
class UpdaterCudaWithScaler(Updater):
scaler_kwargs: InitVar[Dict[str, Any]] = None
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
super().__post_init__()
scaler_kwargs = scaler_kwargs or {}
self.grad_scaler = torch.cuda.amp.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate=False):
self.grad_scaler.scale(loss).backward(create_graph=self.create_graph)
if accumulate:
# unscale first?
return
if self.clip_fn is not None:
# unscale the gradients of optimizer's assigned params in-place
self.grad_scaler.unscale_(self.optimizer)
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
self.reset()

@ -0,0 +1,30 @@
from dataclasses import dataclass, field, InitVar
import torch
try:
import deepspeed as ds
except ImportError as e:
ds = None
from .updater import Updater
@dataclass
class UpdaterDeepSpeed(Updater):
def __post_init__(self):
super().__post_init__()
# FIXME not sure how to deal with model.module / grad clipping w/ DS engine interface
assert isinstance(self.model, ds.DeepSpeedEngine)
def reset(self):
self.model.zero_grad()
def apply(self, loss: torch.Tensor, accumulate=False):
self.model.backward(loss)
self.model.step()
self.reset()
@property
def deepspeed(self):
return True

@ -0,0 +1,38 @@
from typing import Callable, Optional, Union, Any
import torch
from .device_env import DeviceEnv, DeviceEnvType
from .updater import Updater
from .updater_cuda import UpdaterCudaWithScaler
from .updater_deepspeed import UpdaterDeepSpeed
from .updater_xla import UpdaterXla, UpdaterXlaWithScaler
def create_updater(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
clip_fn: Optional[Union[Callable, str]] = None,
clip_value: Optional[float] = None,
scaler_kwargs: Any = None,
dev_env: Optional[DeviceEnv] = None,
deepspeed: bool = False,
) -> Updater:
if not dev_env:
dev_env = DeviceEnv.instance()
updater_kwargs = dict(model=model, optimizer=optimizer, clip_fn=clip_fn, clip_value=clip_value)
use_scaler = dev_env.amp
if use_scaler:
updater_kwargs['scaler_kwargs'] = scaler_kwargs
updater_cls = Updater
if dev_env.type == DeviceEnvType.XLA:
updater_cls = UpdaterXlaWithScaler if use_scaler else UpdaterXla
elif dev_env.type == DeviceEnvType.CUDA and use_scaler:
updater_cls = UpdaterCudaWithScaler
elif deepspeed:
del updater_kwargs['scaler_kwargs']
updater_cls = UpdaterDeepSpeed
return updater_cls(**updater_kwargs)

@ -0,0 +1,68 @@
from dataclasses import dataclass, field, InitVar
from typing import Any, Dict
import torch
import torch.nn as nn
try:
import torch_xla.core.xla_model as xm
_HAS_XLA = True
except ImportError as e:
xm = None
_HAS_XLA = False
try:
# only the very latest XLA builds have AMP
import torch_xla.amp as xa
except ImportError as e:
xa = None
from .updater import Updater
@dataclass
class UpdaterXla(Updater):
def __post_init__(self):
super().__post_init__()
self.after_step_closure = True
def apply(self, loss: torch.Tensor, accumulate: bool = False):
loss.backward(create_graph=self.create_graph)
if accumulate:
return
xm.reduce_gradients(self.optimizer)
if self.clip_fn is not None:
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.optimizer.step()
xm.mark_step()
self.reset()
def after_step(self, after_step_fn, *args):
xm.add_step_closure(after_step_fn, args)
@dataclass
class UpdaterXlaWithScaler(UpdaterXla):
scaler_kwargs: InitVar[Dict[str, Any]] = None
def __post_init__(self, scaler_kwargs: Dict[str, Any]):
super().__post_init__()
scaler_kwargs = scaler_kwargs or {}
assert xa is not None, 'XLA AMP not present in this build'
self.scaler = xa.GradScaler(**scaler_kwargs)
def apply(self, loss: torch.Tensor, accumulate: bool = False):
self.scaler.scale(loss).backward(create_graph=self.create_graph)
if accumulate:
# unscale first?
return
xm.reduce_gradients(self.optimizer)
if self.clip_fn is not None:
self.scaler.unscale_(self.optimizer) # unscale the gradients of optimizer's assigned params in-place
self.clip_fn(self.clip_params_fn(), self.clip_value)
self.scaler.step(self.optimizer)
self.scaler.update()
xm.mark_step()
self.reset()

@ -1,13 +1,13 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .config import resolve_data_config, PreprocessCfg, AugCfg, MixupCfg
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset
from .loader import create_loader
from .loader import create_loader_v2
from .mixup import Mixup, FastCollateMixup
from .parsers import create_parser,\
get_img_extensions, is_img_extension, set_img_extensions, add_img_extensions, del_img_extensions
from .real_labels import RealLabelsImagenet
from .transforms import *
from .transforms_factory import create_transform
from .transforms import RandomResizedCropAndInterpolation, ToTensor, ToNumpy
from .transforms_factory import create_transform_v2, create_transform

@ -36,13 +36,43 @@ _HPARAMS_DEFAULT = dict(
img_mean=_FILL,
)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
if hasattr(Image, "Resampling"):
_RANDOM_INTERPOLATION = (Image.Resampling.BILINEAR, Image.Resampling.BICUBIC)
_DEFAULT_INTERPOLATION = Image.Resampling.BICUBIC
_pil_interpolation_to_str = {
Image.Resampling.NEAREST: 'nearest',
Image.Resampling.BILINEAR: 'bilinear',
Image.Resampling.BICUBIC: 'bicubic',
Image.Resampling.BOX: 'box',
Image.Resampling.HAMMING: 'hamming',
Image.Resampling.LANCZOS: 'lanczos',
}
else:
_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
_DEFAULT_INTERPOLATION = Image.BICUBIC
_pil_interpolation_to_str = {
Image.NEAREST: 'nearest',
Image.BILINEAR: 'bilinear',
Image.BICUBIC: 'bicubic',
Image.BOX: 'box',
Image.HAMMING: 'hamming',
Image.LANCZOS: 'lanczos',
}
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
def _pil_interp(method):
if isinstance(method, (list, tuple)):
return [_str_to_pil_interpolation(m) if isinstance(m, str) else m for m in method]
else:
return _str_to_pil_interpolation(method) if isinstance(method, str) else method
def _interpolation(kwargs):
interpolation = kwargs.pop('resample', _DEFAULT_INTERPOLATION)
@ -329,7 +359,7 @@ class AugmentOp:
self.hparams = hparams.copy()
self.kwargs = dict(
fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL,
resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
resample=_pil_interp(hparams['interpolation']) if 'interpolation' in hparams else _RANDOM_INTERPOLATION,
)
# If magnitude_std is > 0, we introduce some randomness

@ -0,0 +1,38 @@
import numpy as np
import torch
def fast_collate(batch):
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=batch[0][0].dtype)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False

@ -1,10 +1,53 @@
import logging
from dataclasses import dataclass
from typing import Tuple, Optional, Union
from .constants import *
_logger = logging.getLogger(__name__)
@dataclass
class AugCfg:
scale_range: Tuple[float, float] = (0.08, 1.0)
ratio_range: Tuple[float, float] = (3 / 4, 4 / 3)
hflip_prob: float = 0.5
vflip_prob: float = 0.
color_jitter: float = 0.4
auto_augment: Optional[str] = None
re_prob: float = 0.
re_mode: str = 'const'
re_count: int = 1
num_aug_splits: int = 0
@dataclass
class PreprocessCfg:
input_size: Tuple[int, int, int] = (3, 224, 224)
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD
interpolation: str = 'bilinear'
crop_pct: float = 0.875
aug: AugCfg = None
@dataclass
class MixupCfg:
prob: float = 1.0
switch_prob: float = 0.5
mixup_alpha: float = 1.
cutmix_alpha: float = 0.
cutmix_minmax: Optional[Tuple[float, float]] = None
mode: str = 'batch'
correct_lam: bool = True
label_smoothing: float = 0.1
num_classes: int = 0
def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
new_config = {}
default_cfg = default_cfg
@ -74,6 +117,14 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct
if getattr(args, 'mixup', 0) > 0 \
or getattr(args, 'cutmix', 0) > 0. \
or getattr(args, 'cutmix_minmax', None) is not None:
new_config['mixup'] = dict(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.num_classes)
if verbose:
_logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():

@ -69,6 +69,7 @@ def create_dataset(
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
* WDS - Webdataset
* all - any of the above
Args:
@ -121,12 +122,14 @@ def create_dataset(
elif name == 'imagenet':
if split in _EVAL_SYNONYM:
split = 'val'
torch_kwargs.pop("download")
ds = ImageNet(split=split, **torch_kwargs)
elif name == 'image_folder' or name == 'folder':
# in case torchvision ImageFolder is preferred over timm ImageDataset for some reason
if search_split and os.path.isdir(root):
# look for split specific sub-folder in root
root = _search_split(root, split)
torch_kwargs.pop("download")
ds = ImageFolder(root, **kwargs)
else:
assert False, f"Unknown torchvision dataset {name}"
@ -134,6 +137,10 @@ def create_dataset(
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training,
download=download, batch_size=batch_size, repeats=repeats, **kwargs)
elif name.startswith('wds/'):
ds = IterableImageDataset(
root, parser=name, split=split, is_training=is_training,
batch_size=batch_size, repeats=repeats, **kwargs)
else:
# FIXME support more advance split cfg for ImageFolder/Tar datasets in the future
if search_split and os.path.isdir(root):

@ -0,0 +1,88 @@
import torch
from .constants import *
from .random_erasing import RandomErasing
from .mixup import FastCollateMixup
class FetcherXla:
def __init__(self):
pass
class Fetcher:
def __init__(
self,
loader,
device: torch.device,
dtype=torch.float32,
normalize=True,
normalize_shape=(1, 3, 1, 1),
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
num_aug_splits=0,
use_mp_loader=False,
):
self.loader = loader
self.device = torch.device(device)
self.dtype = dtype
if normalize:
self.mean = torch.tensor(
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
self.std = torch.tensor(
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
else:
self.mean = None
self.std = None
if re_prob > 0.:
# NOTE RandomErasing shouldn't be used here w/ XLA devices
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
else:
self.random_erasing = None
self.use_mp_loader = use_mp_loader
if use_mp_loader:
# FIXME testing for TPU use
import torch_xla.distributed.parallel_loader as pl
self._loader = pl.MpDeviceLoader(loader, device)
else:
self._loader = loader
def __iter__(self):
for sample, target in self._loader:
if not self.use_mp_loader:
sample = sample.to(device=self.device)
target = target.to(device=self.device)
sample = sample.to(dtype=self.dtype)
if self.mean is not None:
sample.sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
sample = self.random_erasing(sample)
yield sample, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x

@ -3,241 +3,104 @@
Prefetcher and Fast Collate inspired by NVIDIA APEX example at
https://github.com/NVIDIA/apex/commit/d5e2bb4bdeedd27b1dfaf5bb2b24d6c000dee9be#diff-cf86c282ff7fba81fad27a559379d5bf
Hacked together by / Copyright 2019, Ross Wightman
Hacked together by / Copyright 2019 Ross Wightman
"""
import random
from functools import partial
from itertools import repeat
from typing import Callable
import torch.utils.data
from typing import Optional, Callable
import numpy as np
import torch.utils.data
from .transforms_factory import create_transform
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .distributed_sampler import OrderedDistributedSampler, RepeatAugSampler
from .random_erasing import RandomErasing
from timm.bits import DeviceEnv
from .collate import fast_collate
from .config import PreprocessCfg, MixupCfg
from .distributed_sampler import OrderedDistributedSampler
from .fetcher import Fetcher
from .mixup import FastCollateMixup
from .prefetcher_cuda import PrefetcherCuda
from .transforms_factory import create_transform_v2
def fast_collate(batch):
""" A fast collation function optimized for uint8 images (np array or torch) and int64 targets (labels)"""
assert isinstance(batch[0], tuple)
batch_size = len(batch)
if isinstance(batch[0][0], tuple):
# This branch 'deinterleaves' and flattens tuples of input tensors into one tensor ordered by position
# such that all tuple of position n will end up in a torch.split(tensor, batch_size) in nth position
inner_tuple_size = len(batch[0][0])
flattened_batch_size = batch_size * inner_tuple_size
targets = torch.zeros(flattened_batch_size, dtype=torch.int64)
tensor = torch.zeros((flattened_batch_size, *batch[0][0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
assert len(batch[i][0]) == inner_tuple_size # all input tensor tuples must be same length
for j in range(inner_tuple_size):
targets[i + j * batch_size] = batch[i][1]
tensor[i + j * batch_size] += torch.from_numpy(batch[i][0][j])
return tensor, targets
elif isinstance(batch[0][0], np.ndarray):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i] += torch.from_numpy(batch[i][0])
return tensor, targets
elif isinstance(batch[0][0], torch.Tensor):
targets = torch.tensor([b[1] for b in batch], dtype=torch.int64)
assert len(targets) == batch_size
tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
for i in range(batch_size):
tensor[i].copy_(batch[i][0])
return tensor, targets
else:
assert False
def expand_to_chs(x, n):
if not isinstance(x, (tuple, list)):
x = tuple(repeat(x, n))
elif len(x) == 1:
x = x * n
else:
assert len(x) == n, 'normalization stats must match image channels'
return x
class PrefetchLoader:
def __init__(
self,
loader,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
channels=3,
fp16=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0):
mean = expand_to_chs(mean, channels)
std = expand_to_chs(std, channels)
normalization_shape = (1, channels, 1, 1)
self.loader = loader
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(normalization_shape)
self.std = torch.tensor([x * 255 for x in std]).cuda().view(normalization_shape)
self.fp16 = fp16
if fp16:
self.mean = self.mean.half()
self.std = self.std.half()
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits)
else:
self.random_erasing = None
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.cuda(non_blocking=True)
next_target = next_target.cuda(non_blocking=True)
if self.fp16:
next_input = next_input.half().sub_(self.mean).div_(self.std)
else:
next_input = next_input.float().sub_(self.mean).div_(self.std)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x
def _worker_init(worker_id, worker_seeding='all'):
def _worker_init(worker_id):
worker_info = torch.utils.data.get_worker_info()
assert worker_info.id == worker_id
if isinstance(worker_seeding, Callable):
seed = worker_seeding(worker_info)
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed % (2 ** 32 - 1))
else:
assert worker_seeding in ('all', 'part')
# random / torch seed already called in dataloader iter class w/ worker_info.seed
# to reproduce some old results (same seed + hparam combo), partial seeding is required (skip numpy re-seed)
if worker_seeding == 'all':
np.random.seed(worker_info.seed % (2 ** 32 - 1))
def create_loader(
dataset,
input_size,
batch_size,
is_training=False,
use_prefetcher=True,
no_aug=False,
re_prob=0.,
re_mode='const',
re_count=1,
re_split=False,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
num_aug_repeats=0,
num_aug_splits=0,
interpolation='bilinear',
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
num_workers=1,
distributed=False,
crop_pct=None,
collate_fn=None,
pin_memory=False,
fp16=False,
tf_preprocessing=False,
use_multi_epochs_loader=False,
persistent_workers=True,
worker_seeding='all',
np.random.seed(worker_info.seed % (2**32-1))
def create_loader_v2(
dataset: torch.utils.data.Dataset,
batch_size: int,
is_training: bool = False,
dev_env: Optional[DeviceEnv] = None,
pp_cfg: PreprocessCfg = PreprocessCfg(),
mix_cfg: MixupCfg = None,
create_transform: bool = True,
normalize_in_transform: bool = True,
separate_transform: bool = False,
num_workers: int = 1,
collate_fn: Optional[Callable] = None,
pin_memory: bool = False,
use_multi_epochs_loader: bool = False,
persistent_workers: bool = True,
):
re_num_splits = 0
if re_split:
# apply RE to second half of batch if no aug split otherwise line up with aug split
re_num_splits = num_aug_splits or 2
dataset.transform = create_transform(
input_size,
is_training=is_training,
use_prefetcher=use_prefetcher,
no_aug=no_aug,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
mean=mean,
std=std,
crop_pct=crop_pct,
tf_preprocessing=tf_preprocessing,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=num_aug_splits > 0,
)
"""
Args:
dataset:
batch_size:
is_training:
dev_env:
pp_cfg:
mix_cfg:
create_transform:
normalize_in_transform:
separate_transform:
num_workers:
collate_fn:
pin_memory:
use_multi_epochs_loader:
persistent_workers:
Returns:
"""
if dev_env is None:
dev_env = DeviceEnv.instance()
if create_transform:
dataset.transform = create_transform_v2(
cfg=pp_cfg,
is_training=is_training,
normalize=normalize_in_transform,
separate=separate_transform,
)
sampler = None
if distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if dev_env.distributed and not isinstance(dataset, torch.utils.data.IterableDataset):
if is_training:
if num_aug_repeats:
sampler = RepeatAugSampler(dataset, num_repeats=num_aug_repeats)
else:
sampler = torch.utils.data.distributed.DistributedSampler(dataset)
sampler = torch.utils.data.distributed.DistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
else:
# This will add extra duplicate entries to result in equal num
# of samples per-process, will slightly alter validation results
sampler = OrderedDistributedSampler(dataset)
else:
assert num_aug_repeats == 0, "RepeatAugment not currently supported in non-distributed or IterableDataset use"
sampler = OrderedDistributedSampler(
dataset, num_replicas=dev_env.world_size, rank=dev_env.global_rank)
if collate_fn is None:
collate_fn = fast_collate if use_prefetcher else torch.utils.data.dataloader.default_collate
if mix_cfg is not None and mix_cfg.prob > 0:
collate_fn = FastCollateMixup(
mixup_alpha=mix_cfg.mixup_alpha,
cutmix_alpha=mix_cfg.cutmix_alpha,
cutmix_minmax=mix_cfg.cutmix_minmax,
prob=mix_cfg.prob,
switch_prob=mix_cfg.switch_prob,
mode=mix_cfg.mode,
correct_lam=mix_cfg.correct_lam,
label_smoothing=mix_cfg.label_smoothing,
num_classes=mix_cfg.num_classes,
)
else:
collate_fn = fast_collate
loader_class = torch.utils.data.DataLoader
if use_multi_epochs_loader:
@ -251,27 +114,32 @@ def create_loader(
collate_fn=collate_fn,
pin_memory=pin_memory,
drop_last=is_training,
worker_init_fn=partial(_worker_init, worker_seeding=worker_seeding),
persistent_workers=persistent_workers
)
worker_init_fn=_worker_init,
persistent_workers=persistent_workers)
try:
loader = loader_class(dataset, **loader_args)
except TypeError as e:
loader_args.pop('persistent_workers') # only in Pytorch 1.7+
loader = loader_class(dataset, **loader_args)
if use_prefetcher:
prefetch_re_prob = re_prob if is_training and not no_aug else 0.
loader = PrefetchLoader(
loader,
mean=mean,
std=std,
channels=input_size[0],
fp16=fp16,
re_prob=prefetch_re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits
)
fetcher_kwargs = dict(
normalize=not normalize_in_transform,
mean=pp_cfg.mean,
std=pp_cfg.std,
)
if not normalize_in_transform and is_training and pp_cfg.aug is not None:
# If normalization can be done in the prefetcher, random erasing is done there too
# NOTE RandomErasing does not work well in XLA so normalize_in_transform will be True
fetcher_kwargs.update(dict(
re_prob=pp_cfg.aug.re_prob,
re_mode=pp_cfg.aug.re_mode,
re_count=pp_cfg.aug.re_count,
num_aug_splits=pp_cfg.aug.num_aug_splits,
))
if dev_env.type_cuda:
loader = PrefetcherCuda(loader, **fetcher_kwargs)
else:
loader = Fetcher(loader, device=dev_env.device, **fetcher_kwargs)
return loader

@ -102,7 +102,8 @@ class Mixup:
num_classes (int): number of classes for target
"""
def __init__(self, mixup_alpha=1., cutmix_alpha=0., cutmix_minmax=None, prob=1.0, switch_prob=0.5,
mode='batch', correct_lam=True, label_smoothing=0.1, num_classes=1000):
mode='batch', correct_lam=True, label_smoothing=0., num_classes=0):
assert num_classes > 0, 'num_classes must be set for target generation'
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.cutmix_minmax = cutmix_minmax
@ -218,17 +219,30 @@ class Mixup:
return x, target
def blend(a, b, lam, is_tensor=False, round_output=True):
if is_tensor:
blend = a.to(dtype=torch.float32) * lam + b.to(dtype=torch.float32) * (1 - lam)
if round_output:
torch.round(blend, out=blend)
else:
blend = a.astype(np.float32) * lam + b.astype(np.float32) * (1 - lam)
if round_output:
np.rint(blend, out=blend)
return blend
class FastCollateMixup(Mixup):
""" Fast Collate w/ Mixup/Cutmix that applies different params to each element or whole batch
A Mixup impl that's performed while collating the batches.
"""
def _mix_elem_collate(self, output, batch, half=False):
def _mix_elem_collate(self, output, batch, half=False, is_tensor=False):
batch_size = len(batch)
num_elem = batch_size // 2 if half else batch_size
assert len(output) == num_elem
lam_batch, use_cutmix = self._params_per_elem(num_elem)
round_output = output.dtype == torch.uint8
for i in range(num_elem):
j = batch_size - i - 1
lam = lam_batch[i]
@ -236,22 +250,23 @@ class FastCollateMixup(Mixup):
if lam != 1.:
if use_cutmix[i]:
if not half:
mixed = mixed.copy()
mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
lam_batch[i] = lam
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
output[i].copy_(mixed)
if half:
lam_batch = np.concatenate((lam_batch, np.ones(num_elem)))
return torch.tensor(lam_batch).unsqueeze(1)
def _mix_pair_collate(self, output, batch):
def _mix_pair_collate(self, output, batch, is_tensor=False):
batch_size = len(batch)
lam_batch, use_cutmix = self._params_per_elem(batch_size // 2)
round_output = output.dtype == torch.uint8
for i in range(batch_size // 2):
j = batch_size - i - 1
lam = lam_batch[i]
@ -262,24 +277,30 @@ class FastCollateMixup(Mixup):
if use_cutmix[i]:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
patch_i = mixed_i[:, yl:yh, xl:xh].copy()
patch_i = mixed_i[:, yl:yh, xl:xh]
patch_i = patch_i.clone() if is_tensor else patch_i.copy() # don't want to modify while iterating
mixed_i[:, yl:yh, xl:xh] = mixed_j[:, yl:yh, xl:xh]
mixed_j[:, yl:yh, xl:xh] = patch_i
lam_batch[i] = lam
else:
mixed_temp = mixed_i.astype(np.float32) * lam + mixed_j.astype(np.float32) * (1 - lam)
mixed_j = mixed_j.astype(np.float32) * lam + mixed_i.astype(np.float32) * (1 - lam)
mixed_temp = blend(mixed_i, mixed_j, lam, is_tensor, round_output)
mixed_j = blend(mixed_j, mixed_i, lam, is_tensor, round_output)
mixed_i = mixed_temp
np.rint(mixed_j, out=mixed_j)
np.rint(mixed_i, out=mixed_i)
output[i] += torch.from_numpy(mixed_i.astype(np.uint8))
output[j] += torch.from_numpy(mixed_j.astype(np.uint8))
if is_tensor:
mixed_i = mixed_i.to(dtype=output.dtype)
mixed_j = mixed_j.to(dtype=output.dtype)
else:
mixed_i = torch.from_numpy(mixed_i.astype(np.uint8))
mixed_j = torch.from_numpy(mixed_j.astype(np.uint8))
output[i].copy_(mixed_i)
output[j].copy_(mixed_j)
lam_batch = np.concatenate((lam_batch, lam_batch[::-1]))
return torch.tensor(lam_batch).unsqueeze(1)
def _mix_batch_collate(self, output, batch):
def _mix_batch_collate(self, output, batch, is_tensor=False):
batch_size = len(batch)
lam, use_cutmix = self._params_per_batch()
round_output = output.dtype == torch.uint8
if use_cutmix:
(yl, yh, xl, xh), lam = cutmix_bbox_and_lam(
output.shape, lam, ratio_minmax=self.cutmix_minmax, correct_lam=self.correct_lam)
@ -288,12 +309,12 @@ class FastCollateMixup(Mixup):
mixed = batch[i][0]
if lam != 1.:
if use_cutmix:
mixed = mixed.copy() # don't want to modify the original while iterating
mixed = mixed.clone() if is_tensor else mixed.copy() # don't want to modify while iterating
mixed[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh]
else:
mixed = mixed.astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
np.rint(mixed, out=mixed)
output[i] += torch.from_numpy(mixed.astype(np.uint8))
mixed = blend(mixed, batch[j][0], lam, is_tensor, round_output)
mixed = mixed.to(dtype=output.dtype) if is_tensor else torch.from_numpy(mixed.astype(np.uint8))
output[i].copy_(mixed)
return lam
def __call__(self, batch, _=None):
@ -302,13 +323,15 @@ class FastCollateMixup(Mixup):
half = 'half' in self.mode
if half:
batch_size //= 2
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
is_tensor = isinstance(batch[0][0], torch.Tensor)
output_dtype = batch[0][0].dtype if is_tensor else torch.uint8 # always uint8 if numpy src
output = torch.zeros((batch_size, *batch[0][0].shape), dtype=output_dtype)
if self.mode == 'elem' or self.mode == 'half':
lam = self._mix_elem_collate(output, batch, half=half)
lam = self._mix_elem_collate(output, batch, half=half, is_tensor=is_tensor)
elif self.mode == 'pair':
lam = self._mix_pair_collate(output, batch)
lam = self._mix_pair_collate(output, batch, is_tensor=is_tensor)
else:
lam = self._mix_batch_collate(output, batch)
lam = self._mix_batch_collate(output, batch, is_tensor=is_tensor)
target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
target = mixup_target(target, self.num_classes, lam, self.label_smoothing, device='cpu')
target = target[:batch_size]

@ -10,13 +10,17 @@ def create_parser(name, root, split='train', **kwargs):
prefix = ''
if len(name) > 1:
prefix = name[0]
name = name[-1]
name = "/".join(name[1:])
# FIXME improve the selection right now just tfds prefix or fallback path, will need options to
# explicitly select other options shortly
if prefix == 'tfds':
from .parser_tfds import ParserTfds # defer tensorflow import
parser = ParserTfds(root, name, split=split, **kwargs)
elif prefix == 'wds':
from .parser_wds import ParserWebdataset
kwargs.pop('download', False)
parser = ParserWebdataset(root, name, split=split, **kwargs)
else:
assert os.path.exists(root)
# default fallback path (backwards compat), use image tar if root is a .tar file, otherwise image folder

@ -8,7 +8,6 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
import torch.distributed as dist
from PIL import Image
try:
@ -32,14 +31,15 @@ except ImportError as e:
exit(1)
from .parser import Parser
from timm.bits import get_global_device, is_global_device
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 8192 # examples to shuffle in DS queue
PREFETCH_SIZE = 2048 # examples to prefetch
SHUFFLE_SIZE = 8192 # number of samples to shuffle in DS queue
PREFETCH_SIZE = 2048 # number of samples to prefetch
def even_split_indices(split, n, num_examples):
partitions = [round(i * num_examples / n) for i in range(n + 1)]
def even_split_indices(split, n, num_samples):
partitions = [round(i * num_samples / n) for i in range(n + 1)]
return [f"{split}[{partitions[i]}:{partitions[i + 1]}]" for i in range(n)]
@ -55,20 +55,20 @@ class ParserTfds(Parser):
""" Wrap Tensorflow Datasets for use in PyTorch
There several things to be aware of:
* To prevent excessive examples being dropped per epoch w/ distributed training or multiplicity of
* To prevent excessive samples being dropped per epoch w/ distributed training or multiplicity of
dataloader workers, the train iterator wraps to avoid returning partial batches that trigger drop_last
https://github.com/pytorch/pytorch/issues/33413
* With PyTorch IterableDatasets, each worker in each replica operates in isolation, the final batch
from each worker could be a different size. For training this is worked around by option above, for
validation extra examples are inserted iff distributed mode is enabled so that the batches being reduced
validation extra samples are inserted iff distributed mode is enabled so that the batches being reduced
across replicas are of same size. This will slightly alter the results, distributed validation will not be
100% correct. This is similar to common handling in DistributedSampler for normal Datasets but a bit worse
since there are up to N * J extra examples with IterableDatasets.
since there are up to N * J extra samples with IterableDatasets.
* The sharding (splitting of dataset into TFRecord) files imposes limitations on the number of
replicas and dataloader workers you can use. For really small datasets that only contain a few shards
you may have to train non-distributed w/ 1-2 dataloader workers. This is likely not a huge concern as the
benefit of distributed training or fast dataloading should be much less for small datasets.
* This wrapper is currently configured to return individual, decompressed image examples from the TFDS
* This wrapper is currently configured to return individual, decompressed image samples from the TFDS
dataset. The augmentation (transforms) and batching is still done in PyTorch. It would be possible
to specify TF augmentation fn and return augmented batches w/ some modifications to other downstream
components.
@ -100,7 +100,7 @@ class ParserTfds(Parser):
name: tfds dataset name (eg `imagenet2012`)
split: tfds dataset split (can use all TFDS split strings eg `train[:10%]`)
is_training: training mode, shuffle enabled, dataset len rounded by batch_size
batch_size: batch_size to use to unsure total examples % batch_size == 0 in training across all dis nodes
batch_size: batch_size to use to unsure total samples % batch_size == 0 in training across all dis nodes
download: download and build TFDS dataset if set, otherwise must use tfds CLI
repeats: iterate through (repeat) the dataset this many times per iteration (once if 0 or 1)
seed: common seed for shard shuffle across all distributed/worker instances
@ -123,7 +123,7 @@ class ParserTfds(Parser):
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
# performance settings
# Performance settings
self.prefetch_size = prefetch_size or PREFETCH_SIZE
self.shuffle_size = shuffle_size or SHUFFLE_SIZE
self.max_threadpool_size = max_threadpool_size or MAX_TP_SIZE
@ -139,21 +139,36 @@ class ParserTfds(Parser):
self.builder.download_and_prepare()
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_examples = self.split_info.num_examples
self.num_samples = self.split_info.num_examples
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
if is_global_device():
dev_env = get_global_device()
if dev_env.distributed and dev_env.world_size > 1:
self.dist_rank = dev_env.global_rank
self.dist_num_replicas = dev_env.world_size
else:
# FIXME warn if we fallback to torch distributed?
import torch.distributed as dist
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init, including the tf.data pipeline itself
self.global_num_workers = 1
self.worker_info = None
self.worker_init = False # worker info initialized
self.worker_id = 0
self.worker_seed = 0 # seed unique to each work instance
self.num_workers = 1
self.global_worker_id = 0
self.global_num_workers = 1
self.subsplit = None # set when data is distributed across workers using sub-splits
self.ds = None # initialized lazily on each dataloader worker process
self.init_count = 0 # number of ds TF data pipeline initializations
# FIXME need to determine if reinit_each_iter is necessary. I'm don't completely trust behaviour
# of `shuffle_reshuffle_each_iteration` when there are multiple workers / nodes across epochs
self.reinit_each_iter = self.is_training
def _lazy_init(self):
""" Lazily initialize the dataset.
@ -166,17 +181,16 @@ class ParserTfds(Parser):
instances once it has been initialized. Do not call any dataset methods that can call _lazy_init
before it is passed to dataloader.
"""
worker_info = torch.utils.data.get_worker_info()
# setup input context to split dataset across distributed processes
num_workers = 1
global_worker_id = 0
if worker_info is not None:
self.worker_info = worker_info
self.worker_seed = worker_info.seed
num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * num_workers
global_worker_id = self.dist_rank * num_workers + worker_info.id
if not self.worker_init:
# worker init done once, even if data-pipeline is re-initialized
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.worker_id = worker_info.id
self.worker_seed = worker_info.seed
self.num_workers = worker_info.num_workers
self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
self.global_num_workers = self.dist_num_replicas * self.num_workers
""" Data sharding
InputContext will assign subset of underlying TFRecord files to each 'pipeline' if used.
@ -186,61 +200,72 @@ class ParserTfds(Parser):
I am currently using a mix of InputContext shard assignment and fine-grained sub-splits for distributing
the data across workers. For training InputContext is used to assign shards to nodes unless num_shards
in dataset < total number of workers. Otherwise sub-split API is used for datasets without enough shards or
for validation where we can't drop examples and need to avoid minimize uneven splits to avoid padding.
for validation where we can't drop samples and need to avoid minimize uneven splits to avoid padding.
"""
should_subsplit = self.global_num_workers > 1 and (
self.split_info.num_shards < self.global_num_workers or not self.is_training)
if should_subsplit:
# split the dataset w/o using sharding for more even examples / worker, can result in less optimal
# split the dataset w/o using sharding for more even samples / worker, can result in less optimal
# read patterns for distributed training (overlap across shards) so better to use InputContext there
if has_buggy_even_splits:
# my even_split workaround doesn't work on subsplits, upgrade tfds!
if not isinstance(self.split_info, tfds.core.splits.SubSplitInfo):
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_examples)
self.subsplit = subsplits[global_worker_id]
subsplits = even_split_indices(self.split, self.global_num_workers, self.num_samples)
self.subsplit = subsplits[self.global_worker_id]
else:
subsplits = tfds.even_splits(self.split, self.global_num_workers)
self.subsplit = subsplits[global_worker_id]
self.subsplit = subsplits[self.global_worker_id]
self.worker_init = True
# initialize TF data pipeline
input_context = None
if self.global_num_workers > 1 and self.subsplit is None:
# set input context to divide shards among distributed replicas
input_context = tf.distribute.InputContext(
num_input_pipelines=self.global_num_workers,
input_pipeline_id=global_worker_id,
input_pipeline_id=self.global_worker_id,
num_replicas_in_sync=self.dist_num_replicas # FIXME does this arg have any impact?
)
read_config = tfds.ReadConfig(
shuffle_seed=self.common_seed,
shuffle_reshuffle_each_iteration=True,
shuffle_seed=self.common_seed + self.init_count, # shard shuffling seed
shuffle_reshuffle_each_iteration=not self.reinit_each_iter, # re-shuffle shards per iteration
input_context=input_context)
ds = self.builder.as_dataset(
split=self.subsplit or self.split, shuffle_files=self.is_training, read_config=read_config)
split=self.subsplit or self.split,
shuffle_files=self.is_training, # enable shard shuffling
read_config=read_config)
# avoid overloading threading w/ combo of TF ds threads + PyTorch workers
options = tf.data.Options()
thread_member = 'threading' if hasattr(options, 'threading') else 'experimental_threading'
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // num_workers)
getattr(options, thread_member).private_threadpool_size = max(1, self.max_threadpool_size // self.num_workers)
getattr(options, thread_member).max_intra_op_parallelism = 1
ds = ds.with_options(options)
if self.is_training or self.repeats > 1:
# to prevent excessive drop_last batch behaviour w/ IterableDatasets
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.is_training:
ds = ds.shuffle(min(self.num_examples, self.shuffle_size) // self.global_num_workers, seed=self.worker_seed)
ds = ds.prefetch(min(self.num_examples // self.global_num_workers, self.prefetch_size))
# shuffle samples
ds = ds.shuffle(
min(self.num_samples, self.shuffle_size) // self.global_num_workers,
seed=self.worker_seed + self.init_count)
ds = ds.prefetch(min(self.num_samples // self.global_num_workers, self.prefetch_size))
self.ds = tfds.as_numpy(ds)
self.init_count += 1
def __iter__(self):
if self.ds is None:
if self.ds is None or self.reinit_each_iter:
self._lazy_init()
# Compute a rounded up sample count that is used to:
# 1. make batches even cross workers & replicas in distributed validation.
# This adds extra examples and will slightly alter validation results.
# This adds extra samples and will slightly alter validation results.
# 2. determine loop ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (underlying tfds iter wraps around)
target_example_count = math.ceil(max(1, self.repeats) * self.num_examples / self.global_num_workers)
target_example_count = math.ceil(max(1, self.repeats) * self.num_samples / self.global_num_workers)
if self.is_training:
# round up to nearest batch_size per worker-replica
target_example_count = math.ceil(target_example_count / self.batch_size) * self.batch_size
@ -258,11 +283,11 @@ class ParserTfds(Parser):
example_count += 1
if self.is_training and example_count >= target_example_count:
# Need to break out of loop when repeat() is enabled for training w/ oversampling
# this results in extra examples per epoch but seems more desirable than dropping
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes, and J = num worker processes)
break
# Pad across distributed nodes (make counts equal by adding examples)
# Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < example_count < target_example_count:
# Validation batch padding only done for distributed training where results are reduced across nodes.
@ -274,12 +299,12 @@ class ParserTfds(Parser):
example_count += 1
def __len__(self):
# this is just an estimate and does not factor in extra examples added to pad batches based on
# this is just an estimate and does not factor in extra samples added to pad batches based on
# complete worker & replica info (not available until init in dataloader).
return math.ceil(max(1, self.repeats) * self.num_examples / self.dist_num_replicas)
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples
assert False, "Not supported" # no random access to samples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
@ -287,7 +312,7 @@ class ParserTfds(Parser):
self._lazy_init()
names = []
for sample in self.ds:
if len(names) > self.num_examples:
if len(names) >= self.num_samples:
break # safety for ds.repeat() case
if 'file_name' in sample:
name = sample['file_name']

@ -0,0 +1,330 @@
""" Dataset parser interface for webdataset
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
import os
import io
import json
import logging
import random
import yaml
from dataclasses import dataclass
from itertools import islice
from functools import partial
from typing import Dict, Tuple
import torch
from PIL import Image
try:
import webdataset as wds
from webdataset.shardlists import expand_urls
except ImportError:
wds = None
expand_urls = None
from .parser import Parser
from timm.bits import get_global_device, is_global_device
_logger = logging.getLogger(__name__)
SHUFFLE_SIZE = 8192
def _load_info(root, basename='info'):
info_json = os.path.join(root, basename + '.json')
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception:
pass
try:
with wds.gopen.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception as e:
err_str = str(e)
# FIXME change to log
print(f'Dataset info file not found at {info_json} or {info_yaml}. Error: {err_str}. '
f'Falling back to provided split and size arg.')
return {}
@dataclass
class SplitInfo:
num_samples: int
filenames: Tuple[str]
shard_lengths: Tuple[int] = ()
alt_label: str = ''
name: str = ''
def _parse_split_info(split: str, info: Dict):
def _info_convert(dict_info):
return SplitInfo(
num_samples=dict_info['num_samples'],
filenames=tuple(dict_info['filenames']),
shard_lengths=tuple(dict_info['shard_lengths']),
alt_label=dict_info.get('alt_label', ''),
name=dict_info['name'],
)
if 'tar' in split or '..' in split:
# split in WDS string braceexpand format, sample count can be included with a | separator
# ex: `dataset-split-{0000..9999}.tar|100000` for 9999 shards, covering 100,000 samples
split = split.split('|')
num_samples = 0
split_name = ''
if len(split) > 1:
num_samples = int(split[1])
split = split[0]
if '::' not in split:
split_parts = split.split('-', 3)
split_idx = len(split_parts) - 1
if split_idx and 'splits' in info and split_parts[split_idx] in info['splits']:
split_name = split_parts[split_idx]
split_filenames = expand_urls(split)
if split_name:
split_info = info['splits'][split_name]
if not num_samples:
_fc = {f: c for f, c in zip(split_info['filenames'], split_info['shard_lengths'])}
num_samples = sum(_fc[f] for f in split_filenames)
split_info['filenames'] = tuple(_fc.keys())
split_info['shard_lengths'] = tuple(_fc.values())
split_info['num_samples'] = num_samples
split_info = _info_convert(split_info)
else:
split_info = SplitInfo(
name=split_name,
num_samples=num_samples,
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
return split_info
def log_and_continue(exn):
"""Call in an exception handler to ignore any exception, isssue a warning, and continue."""
_logger.warning(f'Handling webdataset error ({repr(exn)}). Ignoring.')
return True
def _decode(sample, image_key='jpg', image_format='RGB', target_key='cls', alt_label=''):
""" Custom sample decode
* decode and convert PIL Image
* cls byte string label to int
* pass through JSON byte string (if it exists) without parse
"""
# decode class label, skip if alternate label not valid
if alt_label:
# alternative labels are encoded in json metadata
meta = json.loads(sample['json'])
class_label = int(meta[alt_label])
if class_label < 0:
# skipped labels currently encoded as -1, may change to a null/None value
return None
else:
class_label = int(sample[target_key])
# decode image
with io.BytesIO(sample[image_key]) as b:
img = Image.open(b)
img.load()
if image_format:
img = img.convert(image_format)
# json passed through in undecoded state
decoded = dict(jpg=img, cls=class_label, json=sample.get('json', None))
return decoded
def _decode_samples(
data,
image_key='jpg',
image_format='RGB',
target_key='cls',
alt_label='',
handler=log_and_continue):
"""Decode samples with skip."""
for sample in data:
try:
result = _decode(
sample, image_key=image_key, image_format=image_format, target_key=target_key, alt_label=alt_label)
except Exception as exn:
if handler(exn):
continue
else:
break
# null results are skipped
if result is not None:
if isinstance(sample, dict) and isinstance(result, dict):
result["__key__"] = sample.get("__key__")
yield result
class ParserWebdataset(Parser):
def __init__(
self,
root,
name,
split,
is_training=False,
batch_size=None,
repeats=0,
seed=42,
input_name='image',
input_image='RGB',
target_name=None,
target_image='',
prefetch_size=None,
shuffle_size=None,
):
super().__init__()
if wds is None:
raise RuntimeError(
'Please install webdataset 0.2.x package `pip install git+https://github.com/webdataset/webdataset`.')
self.root = root
self.is_training = is_training
self.batch_size = batch_size
self.repeats = repeats
self.common_seed = seed # a seed that's fixed across all worker / distributed instances
self.shard_shuffle_size = 500
self.sample_shuffle_size = shuffle_size or SHUFFLE_SIZE
self.image_key = 'jpg'
self.image_format = input_image
self.target_key = 'cls'
self.filename_key = 'filename'
self.key_ext = '.JPEG' # extension to add to key for original filenames (DS specific, default ImageNet)
self.info = _load_info(self.root)
self.split_info = _parse_split_info(split, self.info)
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
# Distributed world state
self.dist_rank = 0
self.dist_num_replicas = 1
if is_global_device():
dev_env = get_global_device()
if dev_env.distributed and dev_env.world_size > 1:
self.dist_rank = dev_env.global_rank
self.dist_num_replicas = dev_env.world_size
else:
# FIXME warn if we fallback to torch distributed?
import torch.distributed as dist
if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
self.dist_rank = dist.get_rank()
self.dist_num_replicas = dist.get_world_size()
# Attributes that are updated in _lazy_init
self.worker_id = 0
self.worker_seed = seed # seed unique to each worker instance
self.num_workers = 1
self.global_worker_id = 0
self.global_num_workers = 1
self.init_count = 0
# DataPipeline is lazy init, majority of WDS DataPipeline could be init here, BUT, shuffle seed
# is not handled in manner where it can be deterministic for each worker AND initialized up front
self.ds = None
def _lazy_init(self):
""" Lazily initialize worker (in worker processes)
"""
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None:
self.worker_id = worker_info.id
self.worker_seed = worker_info.seed
self.num_workers = worker_info.num_workers
self.global_num_workers = self.dist_num_replicas * self.num_workers
self.global_worker_id = self.dist_rank * self.num_workers + self.worker_id
# init data pipeline
abs_shard_filenames = [os.path.join(self.root, f) for f in self.split_info.filenames]
pipeline = [wds.SimpleShardList(abs_shard_filenames)]
# at this point we have an iterator over all the shards
if self.is_training:
pipeline.extend([
wds.detshuffle(self.shard_shuffle_size, seed=self.common_seed),
self._split_by_node_and_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
wds.shuffle(
self.sample_shuffle_size,
rng=random.Random(self.worker_seed)), # this is why we lazy-init whole DataPipeline
])
else:
pipeline.extend([
self._split_by_node_and_worker,
# at this point, we have an iterator over the shards assigned to each worker
wds.tarfile_to_samples(handler=log_and_continue),
])
pipeline.extend([
partial(
_decode_samples,
image_key=self.image_key,
image_format=self.image_format,
alt_label=self.split_info.alt_label)
])
self.ds = wds.DataPipeline(*pipeline)
self.init_count += 1
def _split_by_node_and_worker(self, src):
if self.global_num_workers > 1:
for s in islice(src, self.global_worker_id, None, self.global_num_workers):
yield s
else:
for s in src:
yield s
def __iter__(self):
if not self.init_count:
self._lazy_init()
i = 0
num_worker_samples = math.ceil(self.num_samples / self.global_num_workers)
if self.is_training and self.batch_size is not None:
num_worker_samples = (num_worker_samples // self.batch_size) * self.batch_size
ds = self.ds.with_epoch(num_worker_samples)
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
i += 1
print('end', i) # FIXME debug
def __len__(self):
return math.ceil(max(1, self.repeats) * self.num_samples / self.dist_num_replicas)
def _filename(self, index, basename=False, absolute=False):
assert False, "Not supported" # no random access to examples
def filenames(self, basename=False, absolute=False):
""" Return all filenames in dataset, overrides base"""
if not self.init_count:
self._lazy_init()
names = []
for sample in self.ds:
if self.filename_key in sample:
name = sample[self.filename_key]
elif '__key__' in sample:
name = sample['__key__'] + self.key_ext
else:
assert False, "No supported name field present"
names.append(name)
if len(names) >= self.num_samples:
break # safety for ds.repeat() case
return names

@ -0,0 +1 @@
from .imagenet22k import Imagenet22k, Imagenet12k, imagenet12k_synsets, imagenet22k_synsets

File diff suppressed because it is too large Load Diff

@ -0,0 +1,87 @@
import torch.cuda
from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .mixup import FastCollateMixup
from .random_erasing import RandomErasing
class PrefetcherCuda:
def __init__(
self,
loader,
device: torch.device = torch.device('cuda'),
dtype=torch.float32,
normalize=True,
normalize_shape=(1, 3, 1, 1),
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
num_aug_splits=0,
):
self.loader = loader
self.device = device
self.dtype = dtype
if normalize:
self.mean = torch.tensor(
[x * 255 for x in mean], dtype=self.dtype, device=self.device).view(normalize_shape)
self.std = torch.tensor(
[x * 255 for x in std], dtype=self.dtype, device=self.device).view(normalize_shape)
else:
self.mean = None
self.std = None
if re_prob > 0.:
self.random_erasing = RandomErasing(
probability=re_prob, mode=re_mode, count=re_count, num_splits=num_aug_splits)
else:
self.random_erasing = None
def __iter__(self):
stream = torch.cuda.Stream()
first = True
for next_input, next_target in self.loader:
with torch.cuda.stream(stream):
next_input = next_input.to(device=self.device, non_blocking=True)
next_input = next_input.to(dtype=self.dtype)
if self.mean is not None:
next_input.sub_(self.mean).div_(self.std)
next_target = next_target.to(device=self.device, non_blocking=True)
if self.random_erasing is not None:
next_input = self.random_erasing(next_input)
if not first:
yield input, target
else:
first = False
torch.cuda.current_stream().wait_stream(stream)
input = next_input
target = next_target
yield input, target
def __len__(self):
return len(self.loader)
@property
def sampler(self):
return self.loader.sampler
@property
def dataset(self):
return self.loader.dataset
@property
def mixup_enabled(self):
if isinstance(self.loader.collate_fn, FastCollateMixup):
return self.loader.collate_fn.mixup_enabled
else:
return False
@mixup_enabled.setter
def mixup_enabled(self, x):
if isinstance(self.loader.collate_fn, FastCollateMixup):
self.loader.collate_fn.mixup_enabled = x

@ -38,39 +38,37 @@ class RandomErasing:
'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-channel random (normal) color
'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
"""
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', min_count=1, max_count=None, num_splits=0, device='cuda'):
mode='const', count=1, num_splits=0):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.min_count = min_count
self.max_count = max_count or min_count
self.count = count
self.num_splits = num_splits
self.mode = mode.lower()
mode = mode.lower()
self.rand_color = False
self.per_pixel = False
if self.mode == 'rand':
if mode == 'rand':
self.rand_color = True # per block random normal
elif self.mode == 'pixel':
elif mode == 'pixel':
self.per_pixel = True # per pixel random normal
else:
assert not self.mode or self.mode == 'const'
self.device = device
assert not mode or mode == 'const'
def _erase(self, img, chan, img_h, img_w, dtype):
device = img.device
if random.random() > self.probability:
return
area = img_h * img_w
count = self.min_count if self.min_count == self.max_count else \
random.randint(self.min_count, self.max_count)
count = random.randint(1, self.count) if self.count > 1 else self.count
for _ in range(count):
for attempt in range(10):
target_area = random.uniform(self.min_area, self.max_area) * area / count
@ -81,23 +79,76 @@ class RandomErasing:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel, self.rand_color, (chan, h, w),
dtype=dtype, device=self.device)
self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=device)
break
def __call__(self, input):
if len(input.size()) == 3:
self._erase(input, *input.size(), input.dtype)
def __call__(self, x):
if len(x.size()) == 3:
self._erase(x, *x.shape, x.dtype)
else:
batch_size, chan, img_h, img_w = input.size()
batch_size, chan, img_h, img_w = x.shape
# skip first slice of batch if num_splits is set (for clean portion of samples)
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
for i in range(batch_start, batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype)
return input
def __repr__(self):
# NOTE simplified state for repr
fs = self.__class__.__name__ + f'(p={self.probability}, mode={self.mode}'
fs += f', count=({self.min_count}, {self.max_count}))'
return fs
self._erase(x[i], chan, img_h, img_w, x.dtype)
return x
class RandomErasingMasked:
""" Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
This variant of RandomErasing is intended to be applied to either a batch
or single image tensor after it has been normalized by dataset mean and std.
Args:
probability: Probability that the Random Erasing operation will be performed for each box (count)
min_area: Minimum percentage of erased area wrt input image area.
max_area: Maximum percentage of erased area wrt input image area.
min_aspect: Minimum aspect ratio of erased area.
count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is between 0 and this value.
"""
def __init__(
self,
probability=0.5, min_area=0.02, max_area=1/3, min_aspect=0.3, max_aspect=None,
mode='const', count=1, num_splits=0):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.mode = mode # FIXME currently ignored, add back options besides normal mean=0, std=1 noise?
self.count = count
self.num_splits = num_splits
@torch.no_grad()
def __call__(self, x: torch.Tensor) -> torch.Tensor:
device = x.device
batch_size, _, img_h, img_w = x.shape
batch_start = batch_size // self.num_splits if self.num_splits > 1 else 0
# NOTE simplified from v1 with with one count value and same prob applied for all
enable = (torch.empty((batch_size, self.count), device=device).uniform_() < self.probability).float()
enable = enable / enable.sum(dim=1, keepdim=True).clamp(min=1)
target_area = torch.empty(
(batch_size, self.count), device=device).uniform_(self.min_area, self.max_area) * enable
aspect_ratio = torch.empty((batch_size, self.count), device=device).uniform_(*self.log_aspect_ratio).exp()
h_coord = torch.arange(0, img_h, device=device).unsqueeze(-1).expand(-1, self.count).float()
w_coord = torch.arange(0, img_w, device=device).unsqueeze(-1).expand(-1, self.count).float()
h_mid = torch.rand((batch_size, self.count), device=device) * img_h
w_mid = torch.rand((batch_size, self.count), device=device) * img_w
noise = torch.empty_like(x[0]).normal_()
for i in range(batch_start, batch_size):
h_half = (img_h / 2) * torch.sqrt(target_area[i] * aspect_ratio[i]) # 1/2 box h
h_mask = (h_coord > (h_mid[i] - h_half)) & (h_coord < (h_mid[i] + h_half))
w_half = (img_w / 2) * torch.sqrt(target_area[i] / aspect_ratio[i]) # 1/2 box w
w_mask = (w_coord > (w_mid[i] - w_half)) & (w_coord < (w_mid[i] + w_half))
#mask = (h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
#x[i].copy_(torch.where(mask, noise, x[i]))
mask = ~(h_mask.unsqueeze(1) & w_mask.unsqueeze(0)).any(dim=-1)
x[i] = x[i].where(mask, noise)
#x[i].masked_scatter_(mask, noise)
return x

@ -22,7 +22,10 @@ Hacked together by / Copyright 2020 Ross Wightman
# limitations under the License.
# ==============================================================================
"""ImageNet preprocessing for MnasNet."""
import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
tf.compat.v1.disable_eager_execution()
import numpy as np
IMAGE_SIZE = 224
@ -131,6 +134,39 @@ def _decode_and_center_crop(image_bytes, image_size, resize_method):
return image
def crop(image_bytes, crop_window):
"""Helper function to crop a jpeg or a decoded image."""
if image_bytes.dtype == tf.dtypes.string:
image = tf.image.decode_and_crop_jpeg(image_bytes,
tf.stack(crop_window),
channels=3)
else:
image = tf.image.crop_to_bounding_box(image_bytes, *crop_window)
return image
def _decode_and_resize_then_crop(
image_bytes: tf.Tensor,
image_size,
crop_pct: float = 32,
) -> tf.Tensor:
"""Rescales an image to image_size / crop_pct, then center crops."""
image = tf.image.decode_jpeg(image_bytes, channels=3)
# Scale image to "scaled size" before taking a center crop
if crop_pct > 1.0: # If crop_pct is >1, treat it as num pad pixels (like VGG)
scale_size = tuple([int(x + crop_pct) for x in image_size])
else:
scale_size = tuple([int(float(x) / crop_pct) for x in image_size])
image = tf.image.resize(image, scale_size, tf.image.ResizeMethod.BICUBIC)
crop_height = tf.cast(image_size[0], tf.int32)
crop_width = tf.cast(image_size[1], tf.int32)
offset_height = ((scale_size[0] - crop_height) + 1) // 2
offset_width = ((scale_size[1] - crop_width) + 1) // 2
crop_window = [offset_height, offset_width, crop_height, crop_width]
image = crop(image, crop_window)
return image
def _flip(image):
"""Random horizontal image flip."""
image = tf.image.random_flip_left_right(image)
@ -172,6 +208,7 @@ def preprocess_for_eval(image_bytes, use_bfloat16, image_size=IMAGE_SIZE, interp
"""
resize_method = tf.image.ResizeMethod.BICUBIC if interpolation == 'bicubic' else tf.image.ResizeMethod.BILINEAR
image = _decode_and_center_crop(image_bytes, image_size, resize_method)
#image = _decode_and_resize_then_crop(image_bytes, (image_size, image_size), resize_method)
image = tf.reshape(image, [image_size, image_size, 3])
image = tf.image.convert_image_dtype(
image, dtype=tf.bfloat16 if use_bfloat16 else tf.float32)

@ -1,11 +1,7 @@
import torch
import torchvision.transforms.functional as F
try:
from torchvision.transforms.functional import InterpolationMode
has_interpolation_mode = True
except ImportError:
has_interpolation_mode = False
from PIL import Image
from torchvision.transforms import InterpolationMode
import warnings
import math
import random
@ -35,65 +31,40 @@ class ToTensor:
return torch.from_numpy(np_img).to(dtype=self.dtype)
# Pillow is deprecating the top-level resampling attributes (e.g., Image.BILINEAR) in
# favor of the Image.Resampling enum. The top-level resampling attributes will be
# removed in Pillow 10.
if hasattr(Image, "Resampling"):
_pil_interpolation_to_str = {
Image.Resampling.NEAREST: 'nearest',
Image.Resampling.BILINEAR: 'bilinear',
Image.Resampling.BICUBIC: 'bicubic',
Image.Resampling.BOX: 'box',
Image.Resampling.HAMMING: 'hamming',
Image.Resampling.LANCZOS: 'lanczos',
}
else:
_pil_interpolation_to_str = {
Image.NEAREST: 'nearest',
Image.BILINEAR: 'bilinear',
Image.BICUBIC: 'bicubic',
Image.BOX: 'box',
Image.HAMMING: 'hamming',
Image.LANCZOS: 'lanczos',
}
_str_to_pil_interpolation = {b: a for a, b in _pil_interpolation_to_str.items()}
if has_interpolation_mode:
_torch_interpolation_to_str = {
InterpolationMode.NEAREST: 'nearest',
InterpolationMode.BILINEAR: 'bilinear',
InterpolationMode.BICUBIC: 'bicubic',
InterpolationMode.BOX: 'box',
InterpolationMode.HAMMING: 'hamming',
InterpolationMode.LANCZOS: 'lanczos',
}
_str_to_torch_interpolation = {b: a for a, b in _torch_interpolation_to_str.items()}
else:
_pil_interpolation_to_torch = {}
_torch_interpolation_to_str = {}
def str_to_pil_interp(mode_str):
return _str_to_pil_interpolation[mode_str]
def str_to_interp_mode(mode_str):
if has_interpolation_mode:
return _str_to_torch_interpolation[mode_str]
else:
return _str_to_pil_interpolation[mode_str]
class ToTensorNormalize:
def __init__(self, mean, std, dtype=torch.float32, device=torch.device('cpu')):
self.dtype = dtype
mean = torch.as_tensor(mean, dtype=dtype, device=device)
std = torch.as_tensor(std, dtype=dtype, device=device)
if (std == 0).any():
raise ValueError('std evaluated to zero after conversion to {}, leading to division by zero.'.format(dtype))
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
self.mean = mean
self.std = std
def interp_mode_to_str(mode):
if has_interpolation_mode:
return _torch_interpolation_to_str[mode]
else:
return _pil_interpolation_to_str[mode]
def __call__(self, pil_img):
mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32}
img = torch.from_numpy(
np.array(pil_img, mode_to_nptype.get(pil_img.mode, np.uint8))
)
if pil_img.mode == '1':
img = 255 * img
img = img.view(pil_img.size[1], pil_img.size[0], len(pil_img.getbands()))
img = img.permute((2, 0, 1))
if isinstance(img, torch.ByteTensor):
img = img.to(self.dtype)
img.sub_(self.mean * 255.).div_(self.std * 255.)
else:
img = img.to(self.dtype)
img.sub_(self.mean).div_(self.std)
return img
_RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
_RANDOM_INTERPOLATION = (InterpolationMode.BILINEAR, InterpolationMode.BICUBIC)
class RandomResizedCropAndInterpolation:
@ -123,7 +94,7 @@ class RandomResizedCropAndInterpolation:
if interpolation == 'random':
self.interpolation = _RANDOM_INTERPOLATION
else:
self.interpolation = str_to_interp_mode(interpolation)
self.interpolation = InterpolationMode(interpolation)
self.scale = scale
self.ratio = ratio
@ -187,11 +158,13 @@ class RandomResizedCropAndInterpolation:
def __repr__(self):
if isinstance(self.interpolation, (tuple, list)):
interpolate_str = ' '.join([interp_mode_to_str(x) for x in self.interpolation])
interpolate_str = ' '.join([x.value for x in self.interpolation])
else:
interpolate_str = interp_mode_to_str(self.interpolation)
interpolate_str = self.interpolation.value
format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
format_string += ', interpolation={0})'.format(interpolate_str)
return format_string

@ -4,60 +4,53 @@ Factory methods for building image transforms for use with TIMM (PyTorch Image M
Hacked together by / Copyright 2019, Ross Wightman
"""
import math
from typing import Union, Tuple
import torch
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.auto_augment import rand_augment_transform, augment_and_mix_transform, auto_augment_transform
from timm.data.transforms import str_to_interp_mode, str_to_pil_interp, RandomResizedCropAndInterpolation, ToNumpy
from timm.data.config import PreprocessCfg, AugCfg
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from timm.data.random_erasing import RandomErasing
from timm.data.transforms import RandomResizedCropAndInterpolation, ToNumpy, ToTensorNormalize
def transforms_noaug_train(
img_size=224,
img_size: Union[int, Tuple[int]] = 224,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
normalize=False,
compose=True,
):
if interpolation == 'random':
# random interpolation not supported with no-aug
interpolation = 'bilinear'
tfl = [
transforms.Resize(img_size, interpolation=str_to_interp_mode(interpolation)),
transforms.Resize(img_size, transforms.InterpolationMode(interpolation)),
transforms.CenterCrop(img_size)
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
if normalize:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))
]
return transforms.Compose(tfl)
else:
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl += [ToNumpy()]
return transforms.Compose(tfl) if compose else tfl
def transforms_imagenet_train(
img_size=224,
scale=None,
ratio=None,
hflip=0.5,
vflip=0.,
color_jitter=0.4,
auto_augment=None,
img_size: Union[int, Tuple[int]] = 224,
interpolation='random',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD,
re_prob=0.,
re_mode='const',
re_count=1,
re_num_splits=0,
aug_cfg=AugCfg(),
normalize=False,
separate=False,
compose=True,
):
"""
If separate==True, the transforms are returned as a tuple of 3 separate transforms
@ -66,18 +59,24 @@ def transforms_imagenet_train(
* a portion of the data through the secondary transform
* normalizes and converts the branches above with the third, final transform
"""
scale = tuple(scale or (0.08, 1.0)) # default imagenet scale range
ratio = tuple(ratio or (3./4., 4./3.)) # default imagenet ratio range
scale_range = tuple(aug_cfg.scale_range or (0.08, 1.0)) # default imagenet scale range
ratio_range = tuple(aug_cfg.ratio_range or (3. / 4., 4. / 3.)) # default imagenet ratio range
# 'primary' train transforms include random resize + crop w/ optional horizontal and vertical flipping aug.
# This is the core of standard ImageNet ResNet and Inception pre-processing
primary_tfl = [
RandomResizedCropAndInterpolation(img_size, scale=scale, ratio=ratio, interpolation=interpolation)]
if hflip > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=hflip)]
if vflip > 0.:
primary_tfl += [transforms.RandomVerticalFlip(p=vflip)]
RandomResizedCropAndInterpolation(img_size, scale=scale_range, ratio=ratio_range, interpolation=interpolation)]
if aug_cfg.hflip_prob > 0.:
primary_tfl += [transforms.RandomHorizontalFlip(p=aug_cfg.hflip_prob)]
if aug_cfg.vflip_prob > 0.:
primary_tfl += [transforms.RandomVerticalFlip(p=aug_cfg.vflip_prob)]
# 'secondary' transform stage includes either color jitter (could add lighting too) or auto-augmentations
# such as AutoAugment, RandAugment, AugMix, etc
secondary_tfl = []
if auto_augment:
assert isinstance(auto_augment, str)
if aug_cfg.auto_augment:
aa = aug_cfg.auto_augment
assert isinstance(aa, str)
if isinstance(img_size, (tuple, list)):
img_size_min = min(img_size)
else:
@ -87,58 +86,68 @@ def transforms_imagenet_train(
img_mean=tuple([min(255, round(255 * x)) for x in mean]),
)
if interpolation and interpolation != 'random':
aa_params['interpolation'] = str_to_pil_interp(interpolation)
if auto_augment.startswith('rand'):
secondary_tfl += [rand_augment_transform(auto_augment, aa_params)]
elif auto_augment.startswith('augmix'):
aa_params['interpolation'] = interpolation
if aa.startswith('rand'):
secondary_tfl += [rand_augment_transform(aa, aa_params)]
elif aa.startswith('augmix'):
aa_params['translate_pct'] = 0.3
secondary_tfl += [augment_and_mix_transform(auto_augment, aa_params)]
secondary_tfl += [augment_and_mix_transform(aa, aa_params)]
else:
secondary_tfl += [auto_augment_transform(auto_augment, aa_params)]
elif color_jitter is not None:
secondary_tfl += [auto_augment_transform(aa, aa_params)]
elif aug_cfg.color_jitter is not None:
# color jitter is enabled when not using AA
if isinstance(color_jitter, (list, tuple)):
cj = aug_cfg.color_jitter
if isinstance(cj, (list, tuple)):
# color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
# or 4 if also augmenting hue
assert len(color_jitter) in (3, 4)
assert len(cj) in (3, 4)
else:
# if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
color_jitter = (float(color_jitter),) * 3
secondary_tfl += [transforms.ColorJitter(*color_jitter)]
cj = (float(cj),) * 3
secondary_tfl += [transforms.ColorJitter(*cj)]
# 'final' transform stage includes normalization, followed by optional random erasing and tensor conversion
final_tfl = []
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
final_tfl += [ToNumpy()]
else:
if normalize:
final_tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
ToTensorNormalize(mean=mean, std=std)
]
if re_prob > 0.:
final_tfl.append(
RandomErasing(re_prob, mode=re_mode, max_count=re_count, num_splits=re_num_splits, device='cpu'))
if aug_cfg.re_prob > 0.:
final_tfl.append(RandomErasing(
aug_cfg.re_prob,
mode=aug_cfg.re_mode,
count=aug_cfg.re_count,
num_splits=aug_cfg.num_aug_splits))
else:
# when normalize disabled, (pre)fetcher and collate will handle tensor conversion and normalize
final_tfl += [ToNumpy()]
if separate:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
# return each transform stage separately
if compose:
return transforms.Compose(primary_tfl), transforms.Compose(secondary_tfl), transforms.Compose(final_tfl)
else:
return primary_tfl, secondary_tfl, final_tfl
else:
return transforms.Compose(primary_tfl + secondary_tfl + final_tfl)
tfl = primary_tfl + secondary_tfl + final_tfl
return transforms.Compose(tfl) if compose else tfl
def transforms_imagenet_eval(
img_size=224,
img_size: Union[int, Tuple[int]] = 224,
crop_pct=None,
interpolation='bilinear',
use_prefetcher=False,
mean=IMAGENET_DEFAULT_MEAN,
std=IMAGENET_DEFAULT_STD):
std=IMAGENET_DEFAULT_STD,
normalize=False,
compose=True,
):
crop_pct = crop_pct or DEFAULT_CROP_PCT
if isinstance(img_size, (tuple, list)):
assert len(img_size) == 2
if img_size[-1] == img_size[-2]:
# FIXME handle case where img is square and we want non aspect preserving resize
# fall-back to older behaviour so Resize scales to shortest edge if target is square
scale_size = int(math.floor(img_size[0] / crop_pct))
else:
@ -147,21 +156,87 @@ def transforms_imagenet_eval(
scale_size = int(math.floor(img_size / crop_pct))
tfl = [
transforms.Resize(scale_size, interpolation=str_to_interp_mode(interpolation)),
transforms.Resize(scale_size, transforms.InterpolationMode(interpolation)),
transforms.CenterCrop(img_size),
]
if use_prefetcher:
# prefetcher and collate will handle tensor conversion and norm
tfl += [ToNumpy()]
else:
if normalize:
tfl += [
transforms.ToTensor(),
transforms.Normalize(
mean=torch.tensor(mean),
std=torch.tensor(std))
ToTensorNormalize(mean=mean, std=std)
]
else:
# (pre)fetcher and collate will handle tensor conversion and normalize
tfl += [ToNumpy()]
return transforms.Compose(tfl)
return transforms.Compose(tfl) if compose else tfl
def create_transform_v2(
cfg=PreprocessCfg(),
is_training=False,
normalize=False,
separate=False,
compose=True,
tf_preprocessing=False,
):
"""
Args:
cfg: Pre-processing configuration
is_training (bool): Create transform for training pre-processing
normalize (bool): Enable normalization in transforms (otherwise handled by fetcher/pre-fetcher)
separate (bool): Return transforms separated into stages (for train)
compose (bool): Wrap transforms in transform.Compose(), returns list otherwise
tf_preprocessing (bool): Use Tensorflow pre-processing (for validation)
Returns:
"""
input_size = cfg.input_size
if isinstance(input_size, (tuple, list)):
img_size = input_size[-2:]
else:
img_size = input_size
if tf_preprocessing:
assert not normalize, "Expecting normalization to be handled in (pre)fetcher w/ TF preprocessing"
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
transform = TfPreprocessTransform(
is_training=is_training, size=img_size, interpolation=cfg.interpolation)
else:
if is_training and cfg.aug is None:
assert not separate, "Cannot perform split augmentation with no_aug"
transform = transforms_noaug_train(
img_size,
interpolation=cfg.interpolation,
normalize=normalize,
mean=cfg.mean,
std=cfg.std,
compose=compose,
)
elif is_training:
transform = transforms_imagenet_train(
img_size,
interpolation=cfg.interpolation,
mean=cfg.mean,
std=cfg.std,
aug_cfg=cfg.aug,
normalize=normalize,
separate=separate,
compose=compose,
)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"
transform = transforms_imagenet_eval(
img_size,
interpolation=cfg.interpolation,
crop_pct=cfg.crop_pct,
mean=cfg.mean,
std=cfg.std,
normalize=normalize,
compose=compose,
)
return transform
def create_transform(
@ -191,6 +266,7 @@ def create_transform(
else:
img_size = input_size
normalize_in_transform = not use_prefetcher
if tf_preprocessing and use_prefetcher:
assert not separate, "Separate transforms not supported for TF preprocessing"
from timm.data.tf_preprocessing import TfPreprocessTransform
@ -202,35 +278,41 @@ def create_transform(
transform = transforms_noaug_train(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std)
std=std,
normalize=normalize_in_transform,
)
elif is_training:
transform = transforms_imagenet_train(
img_size,
scale=scale,
ratio=ratio,
hflip=hflip,
vflip=vflip,
aug_cfg = AugCfg(
scale_range=scale,
ratio_range=ratio,
hflip_prob=hflip,
vflip_prob=vflip,
color_jitter=color_jitter,
auto_augment=auto_augment,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
re_prob=re_prob,
re_mode=re_mode,
re_count=re_count,
re_num_splits=re_num_splits,
separate=separate)
num_aug_splits=re_num_splits,
)
transform = transforms_imagenet_train(
img_size,
interpolation=interpolation,
mean=mean,
std=std,
aug_cfg=aug_cfg,
normalize=normalize_in_transform,
separate=separate
)
else:
assert not separate, "Separate transforms not supported for validation preprocessing"
assert not separate, "Separate transforms not supported for validation pre-processing"
transform = transforms_imagenet_eval(
img_size,
interpolation=interpolation,
use_prefetcher=use_prefetcher,
mean=mean,
std=std,
crop_pct=crop_pct)
crop_pct=crop_pct,
normalize=normalize_in_transform,
)
return transform

@ -169,6 +169,31 @@ def convert_sync_batchnorm(module, process_group=None):
return module_output
def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
# This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs.
x_shape = x.shape
x_dtype = x.dtype
if flatten:
norm_shape = (x_shape[0], groups, -1)
reduce_dim = -1
else:
norm_shape = (x_shape[0], groups, x_shape[1] // groups) + x_shape[2:]
reduce_dim = tuple(range(2, x.ndim + 1))
affine_shape = (1, -1) + (1,) * (x.ndim - 2)
x = x.reshape(norm_shape)
# x = x.to(torch.float32) # for testing w/ AMP
xm = x.mean(dim=reduce_dim, keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=reduce_dim, keepdim=True) - xm.square()).clamp(0)
else:
var = (x - xm).square().mean(dim=reduce_dim, keepdim=True)
x = (x - xm.expand(norm_shape)) / var.add(eps).sqrt().expand(norm_shape)
x = x.reshape(x_shape) * w.view(affine_shape) + b.view(affine_shape)
# x = x.to(x_dtype) # for testing w/ AMP
return x
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
@ -193,10 +218,13 @@ class GroupNormAct(nn.GroupNorm):
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if False: # FIXME TPU temporary while resolving some performance issues
x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x

@ -829,11 +829,7 @@ def checkpoint_filter_fn(state_dict, model):
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
SwinTransformerV2Cr, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs
)
model = build_model_with_cfg(SwinTransformerV2Cr, variant, pretrained, **kwargs)
return model

@ -6,3 +6,4 @@ from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from .scheduler_factory import create_scheduler
from .scheduler import Scheduler

@ -65,14 +65,16 @@ class Scheduler:
return None
def step(self, epoch: int, metric: float = None) -> None:
self.metric = metric
if metric is not None:
self.metric = metric
values = self.get_epoch_values(epoch)
if values is not None:
values = self._add_noise(values, epoch)
self.update_groups(values)
def step_update(self, num_updates: int, metric: float = None):
self.metric = metric
if metric is not None:
self.metric = metric
values = self.get_update_values(num_updates)
if values is not None:
values = self._add_noise(values, num_updates)

@ -108,7 +108,8 @@ class CheckpointSaver:
save_state['arch'] = self.args.model
save_state['args'] = self.args
if self.amp_scaler is not None:
save_state[self.amp_scaler.state_dict_key] = self.amp_scaler.state_dict()
amp_key = getattr(self.amp_scaler, 'state_dict_key', 'amp_scaler')
save_state[amp_key] = self.amp_scaler.state_dict()
if self.model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(self.model_ema, self.unwrap_fn)
if metric is not None:

@ -3,7 +3,11 @@ import torch
from timm.utils.agc import adaptive_clip_grad
def dispatch_clip_grad(parameters, value: float, mode: str = 'norm', norm_type: float = 2.0):
def dispatch_clip_grad(
parameters,
value: float,
mode: str = 'norm',
norm_type: float = 2.0):
""" Dispatch to gradient clipping method
Args:

@ -21,8 +21,8 @@ def distribute_bn(model, world_size, reduce=False):
if ('running_mean' in bn_name) or ('running_var' in bn_name):
if reduce:
# average bn stats across whole group
torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
torch.distributed.all_reduce_recursive(bn_buf, op=dist.ReduceOp.SUM)
bn_buf /= float(world_size)
else:
# broadcast bn stats from rank 0 to whole group
torch.distributed.broadcast(bn_buf, 0)
torch.distributed.broadcast_recursive(bn_buf, 0)

@ -7,14 +7,16 @@ import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma
_SUB_MODULE_ATTR = ('module', 'model')
def unwrap_model(model):
if isinstance(model, ModelEma):
return unwrap_model(model.ema)
else:
return model.module if hasattr(model, 'module') else model
def unwrap_model(model, recursive=True):
for attr in _SUB_MODULE_ATTR:
sub_module = getattr(model, attr, None)
if sub_module is not None:
return unwrap_model(sub_module) if recursive else sub_module
return model
def get_state_dict(model, unwrap_fn=unwrap_model):

@ -1 +1 @@
__version__ = '0.6.9'
__version__ = '0.8.2.dev0'

1176
train.py

File diff suppressed because it is too large Load Diff

@ -11,41 +11,20 @@ import argparse
import os
import csv
import glob
import json
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
decay_batch_step, check_batch_size_retry
has_apex = False
try:
from apex import amp
has_apex = True
except ImportError:
pass
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
torch.backends.cudnn.benchmark = True
from timm.bits import initialize_device, Tracker, Monitor, AccuracyTopK, AvgTensor
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_transform_v2, create_loader_v2, resolve_data_config, RealLabelsImagenet, \
PreprocessCfg
from timm.utils import natural_key, setup_default_logging
_logger = logging.getLogger('validate')
@ -58,8 +37,8 @@ parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
help='model architecture (default: resnet50)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -68,8 +47,6 @@ parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@ -84,76 +61,43 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
parser.add_argument('--log-freq', default=20, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
# parser.add_argument('--num-gpu', type=int, default=1,
# help='Number of GPUS to use')
parser.add_argument('--test-pool', dest='test_pool', action='store_true',
help='enable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
help='convert model torchscript for inference')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--retry', default=False, action='store_true',
help='Enable batch size decay & retry for single model validation')
parser.add_argument('--force-cpu', action='store_true', default=False,
help='Force CPU to be used even if HW accelerator exists.')
def validate(args):
# might as well try to validate something
args.pretrained = args.pretrained or not args.checkpoint
args.prefetcher = not args.no_prefetcher
amp_autocast = suppress # do nothing
if args.amp:
if has_native_amp:
args.native_amp = True
elif has_apex:
args.apex_amp = True
else:
_logger.warning("Neither APEX or Native Torch AMP is available.")
assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
if args.native_amp:
amp_autocast = torch.cuda.amp.autocast
_logger.info('Validating in mixed precision with native PyTorch AMP.')
elif args.apex_amp:
_logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
else:
_logger.info('Validating in float32. AMP not enabled.')
if args.fuser:
set_jit_fuser(args.fuser)
if args.fast_norm:
set_fast_norm()
dev_env = initialize_device(force_cpu=args.force_cpu, amp=args.amp)
# create model
model = create_model(
@ -173,34 +117,18 @@ def validate(args):
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(
vars(args),
model=model,
use_test_size=not args.use_train_size,
verbose=True
)
data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
test_time_pool = False
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)
model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)
if args.torchscript:
torch.jit.optimized_execution(True)
model = torch.jit.script(model)
if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
model = model.cuda()
if args.apex_amp:
model = amp.initialize(model, opt_level='O1')
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
criterion = nn.CrossEntropyLoss().cuda()
# FIXME device
model, criterion = dev_env.to_device(model, nn.CrossEntropyLoss())
model.to(dev_env.device)
dataset = create_dataset(
root=args.data, name=args.dataset, split=args.split,
@ -218,113 +146,82 @@ def validate(args):
else:
real_labels = None
crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
loader = create_loader(
dataset,
eval_pp_cfg = PreprocessCfg(
input_size=data_config['input_size'],
batch_size=args.batch_size,
use_prefetcher=args.prefetcher,
interpolation=data_config['interpolation'],
crop_pct=1.0 if test_time_pool else data_config['crop_pct'],
mean=data_config['mean'],
std=data_config['std'],
)
loader = create_loader_v2(
dataset,
batch_size=args.batch_size,
is_training=False,
pp_cfg=eval_pp_cfg,
num_workers=args.workers,
crop_pct=crop_pct,
pin_memory=args.pin_mem,
tf_preprocessing=args.tf_preprocessing)
pin_memory=args.pin_mem)
batch_time = AverageMeter()
losses = AverageMeter()
top1 = AverageMeter()
top5 = AverageMeter()
logger = Monitor(logger=_logger)
tracker = Tracker()
losses = AvgTensor()
accuracy = AccuracyTopK(dev_env=dev_env)
model.eval()
num_steps = len(loader)
with torch.no_grad():
# warmup, reduce variability of first batch time, especially for comparing torchscript vs non
input = torch.randn((args.batch_size,) + tuple(data_config['input_size'])).cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
with amp_autocast():
model(input)
end = time.time()
for batch_idx, (input, target) in enumerate(loader):
if args.no_prefetcher:
target = target.cuda()
input = input.cuda()
if args.channels_last:
input = input.contiguous(memory_format=torch.channels_last)
# compute output
with amp_autocast():
output = model(input)
tracker.mark_iter()
for step_idx, (sample, target) in enumerate(loader):
last_step = step_idx == num_steps - 1
tracker.mark_iter_data_end()
with dev_env.autocast():
output = model(sample)
if valid_labels is not None:
output = output[:, valid_labels]
loss = criterion(output, target)
if dev_env.type_xla:
dev_env.mark_step()
elif dev_env.type_cuda:
dev_env.synchronize()
tracker.mark_iter_step_end()
if real_labels is not None:
real_labels.add_result(output)
# measure accuracy and record loss
acc1, acc5 = accuracy(output.detach(), target, topk=(1, 5))
losses.update(loss.item(), input.size(0))
top1.update(acc1.item(), input.size(0))
top5.update(acc5.item(), input.size(0))
# measure elapsed time
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
'Acc@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
'Acc@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
batch_idx, len(loader), batch_time=batch_time,
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
losses.update(loss.detach(), sample.size(0))
accuracy.update(output.detach(), target)
tracker.mark_iter()
if last_step or step_idx % args.log_freq == 0:
top1, top5 = accuracy.compute().values()
loss_avg = losses.compute()
logger.log_step(
phase='eval',
step_idx=step_idx,
num_steps=num_steps,
rate=(tracker.get_last_iter_rate(output.shape[0]), tracker.get_avg_iter_rate(args.batch_size)),
loss=loss_avg.item(),
top1=top1.item(),
top5=top5.item(),
)
if real_labels is not None:
# real labels mode replaces topk values at the end
top1a, top5a = real_labels.get_accuracy(k=1), real_labels.get_accuracy(k=5)
else:
top1a, top5a = top1.avg, top5.avg
top1a, top5a = accuracy.compute().values()
top1a, top5a = top1a.item(), top5a.item()
results = OrderedDict(
model=args.model,
top1=round(top1a, 4), top1_err=round(100 - top1a, 4),
top5=round(top5a, 4), top5_err=round(100 - top5a, 4),
param_count=round(param_count / 1e6, 2),
img_size=data_config['input_size'][-1],
crop_pct=crop_pct,
cropt_pct=eval_pp_cfg.crop_pct,
interpolation=data_config['interpolation'])
logger.log_phase(phase='eval', name_map={'top1': 'Acc@1', 'top5': 'Acc@5'}, **results)
_logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def _try_run(args, initial_batch_size):
batch_size = initial_batch_size
results = OrderedDict()
error_str = 'Unknown'
while batch_size:
args.batch_size = batch_size * args.num_gpu # multiply by num-gpu for DataParallel case
try:
torch.cuda.empty_cache()
results = validate(args)
return results
except RuntimeError as e:
error_str = str(e)
_logger.error(f'"{error_str}" while running validation.')
if not check_batch_size_retry(error_str):
break
batch_size = decay_batch_step(batch_size)
_logger.warning(f'Reducing batch size to {batch_size} for retry.')
results['error'] = error_str
_logger.error(f'{args.model} failed to validate ({error_str}).')
return results
@ -343,7 +240,7 @@ def main():
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter
@ -360,28 +257,36 @@ def main():
_logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
results = []
try:
initial_batch_size = args.batch_size
start_batch_size = args.batch_size
for m, c in model_cfgs:
batch_size = start_batch_size
args.model = m
args.checkpoint = c
r = _try_run(args, initial_batch_size)
if 'error' in r:
continue
args.num_gpu = 1 # FIXME support data-parallel?
result = OrderedDict(model=args.model)
r = {}
while not r and batch_size >= 1:
try:
args.batch_size = batch_size
print('Validating with batch size: %d' % args.batch_size)
r = validate(args)
except RuntimeError as e:
if batch_size <= args.num_gpu:
print("Validation failed with no ability to reduce batch size. Exiting.")
raise e
batch_size = max(batch_size // 2, args.num_gpu)
print("Validation failed, reducing batch size by 50%")
result.update(r)
if args.checkpoint:
r['checkpoint'] = args.checkpoint
results.append(r)
result['checkpoint'] = args.checkpoint
results.append(result)
except KeyboardInterrupt as e:
pass
results = sorted(results, key=lambda x: x['top1'], reverse=True)
if len(results):
write_results(results_file, results)
else:
if args.retry:
results = _try_run(args, args.batch_size)
else:
results = validate(args)
# output results in JSON to stdout w/ delimiter for runner script
print(f'--result\n{json.dumps(results, indent=4)}')
validate(args)
def write_results(results_file, results):

Loading…
Cancel
Save