Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
| Ross Wightman | abb206c916 | 1 year ago |
| Ross Wightman | 8016ac3569 | 1 year ago |
| Ross Wightman | 781f174fe6 | 1 year ago |
@@ -1,112 +0,0 @@
*This guideline is very much a work-in-progress.*

Contributions to `timm` for code, documentation, and tests are more than welcome!

There haven't been any formal guidelines to date, so please bear with me, and feel free to add to this guide.

# Coding style

Code linting and auto-formatting (black) are not currently in place, but I'm open to considering them. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.

A few specific differences from Google style (or black):

1. Line length is 120 chars. Going over is okay in some cases (e.g. I prefer not to break URLs across lines).
2. Hanging indents are always preferred; please avoid aligning arguments with closing brackets or braces.

Example from the Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
                         var_three, var_four)
meal = (spam,
        beans)

# Aligned with opening delimiter in a dictionary.
foo = {
    'long_dictionary_key': value1 +
                           value2,
    ...
}
```

This is YES:

```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
    var_one, var_two, var_three,
    var_four
)
meal = (
    spam,
    beans,
)

# 4-space hanging indent in a dictionary.
foo = {
    'long_dictionary_key':
        long_dictionary_value,
    ...
}
```

When there is a discrepancy in a given source file (there are many origins for various bits of code and not all have been updated to what I consider the current goal), please follow the style in that file.

In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:

```
black --skip-string-normalization --line-length 120 <path-to-file>
```

Avoid formatting code that is unrelated to your PR though.

PRs with pure formatting / style fixes will be accepted, but only in isolation from functional changes; it's best to ask before starting such a change.

# Documentation

As with code style, docstring style is based on the Google guide: https://google.github.io/styleguide/pyguide.html

The goal for the code is to eventually have all major functions and `__init__` methods use PEP 484 type annotations.

When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings; please leave annotations as the one source of truth re: typing.

There are a LOT of gaps in the current documentation relative to the functionality in `timm`, so please, document away!

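As a quick illustration of the above, here is a sketch of the intended docstring style (using a hypothetical helper, not an existing `timm` function): Google-style sections, with types living only in the PEP 484 annotations and never repeated in `Args`.

```python
from collections import Counter


def count_tokens(text: str, lowercase: bool = True) -> dict[str, int]:
    """Count occurrences of each whitespace-separated token.

    The `Args` section describes meaning only; types are not repeated
    here because the signature annotations are the source of truth.

    Args:
        text: Input string to tokenize.
        lowercase: Whether to lowercase the text before counting.

    Returns:
        Mapping from token to its number of occurrences.
    """
    if lowercase:
        text = text.lower()
    return dict(Counter(text.split()))
```
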
# Installation

Create a Python virtual environment using Python 3.10. Inside the environment, install the following test dependencies:

```
python -m pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
```

Install `torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).

Then install the remaining dependencies:

```
python -m pip install -r requirements.txt
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests

Run the tests using:

```
pytest tests/
```

Since the whole test suite takes a long time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:

```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation

Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).

# Questions

If you have any questions about contributing, or where / how to contribute, please ask in [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

@@ -1,3 +1,2 @@
include timm/models/_pruned/*.txt
include timm/data/_info/*.txt
include timm/data/_info/*.json
include timm/models/pruned/*.txt

@@ -1,14 +0,0 @@

# Hugging Face Timm Docs

## Getting Started

```
pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder
pip install watchdog black
```

## Preview the Docs Locally

```
doc-builder preview timm hfdocs/source
```

@@ -1,160 +1,149 @@
- sections:
- local: index
title: Home
- local: quickstart
title: Quickstart
- local: installation
title: Installation
title: Get started
- sections:
- local: feature_extraction
title: Using Pretrained Models as Feature Extractors
- local: training_script
title: Training With The Official Training Script
- local: hf_hub
title: Share and Load Models from the 🤗 Hugging Face Hub
title: Tutorials
- sections:
title: Pytorch Image Models (timm)
- local: models
title: Model Summaries
- local: results
title: Results
- local: models/adversarial-inception-v3
title: Adversarial Inception v3
- local: models/advprop
title: AdvProp (EfficientNet)
- local: models/big-transfer
title: Big Transfer (BiT)
- local: models/csp-darknet
title: CSP-DarkNet
- local: models/csp-resnet
title: CSP-ResNet
- local: models/csp-resnext
title: CSP-ResNeXt
- local: models/densenet
title: DenseNet
- local: models/dla
title: Deep Layer Aggregation
- local: models/dpn
title: Dual Path Network (DPN)
- local: models/ecaresnet
title: ECA-ResNet
- local: models/efficientnet
title: EfficientNet
- local: models/efficientnet-pruned
title: EfficientNet (Knapsack Pruned)
- local: models/ensemble-adversarial
title: Ensemble Adversarial Inception ResNet v2
- local: models/ese-vovnet
title: ESE-VoVNet
- local: models/fbnet
title: FBNet
- local: models/gloun-inception-v3
title: (Gluon) Inception v3
- local: models/gloun-resnet
title: (Gluon) ResNet
- local: models/gloun-resnext
title: (Gluon) ResNeXt
- local: models/gloun-senet
title: (Gluon) SENet
- local: models/gloun-seresnext
title: (Gluon) SE-ResNeXt
- local: models/gloun-xception
title: (Gluon) Xception
- local: models/hrnet
title: HRNet
- local: models/ig-resnext
title: Instagram ResNeXt WSL
- local: models/inception-resnet-v2
title: Inception ResNet v2
- local: models/inception-v3
title: Inception v3
- local: models/inception-v4
title: Inception v4
- local: models/legacy-se-resnet
title: (Legacy) SE-ResNet
- local: models/legacy-se-resnext
title: (Legacy) SE-ResNeXt
- local: models/legacy-senet
title: (Legacy) SENet
- local: models/mixnet
title: MixNet
- local: models/mnasnet
title: MnasNet
- local: models/mobilenet-v2
title: MobileNet v2
- local: models/mobilenet-v3
title: MobileNet v3
- local: models/nasnet
title: NASNet
- local: models/noisy-student
title: Noisy Student (EfficientNet)
- local: models/pnasnet
title: PNASNet
- local: models/regnetx
title: RegNetX
- local: models/regnety
title: RegNetY
- local: models/res2net
title: Res2Net
- local: models/res2next
title: Res2NeXt
- local: models/resnest
title: ResNeSt
- local: models/resnet
title: ResNet
- local: models/resnet-d
title: ResNet-D
- local: models/resnext
title: ResNeXt
- local: models/rexnet
title: RexNet
- local: models/se-resnet
title: SE-ResNet
- local: models/selecsls
title: SelecSLS
- local: models/seresnext
title: SE-ResNeXt
- local: models/skresnet
title: SK-ResNet
- local: models/skresnext
title: SK-ResNeXt
- local: models/spnasnet
title: SPNASNet
- local: models/ssl-resnet
title: SSL ResNet
- local: models/swsl-resnet
title: SWSL ResNet
- local: models/swsl-resnext
title: SWSL ResNeXt
- local: models/tf-efficientnet
title: (Tensorflow) EfficientNet
- local: models/tf-efficientnet-condconv
title: (Tensorflow) EfficientNet CondConv
- local: models/tf-efficientnet-lite
title: (Tensorflow) EfficientNet Lite
- local: models/tf-inception-v3
title: (Tensorflow) Inception v3
- local: models/tf-mixnet
title: (Tensorflow) MixNet
- local: models/tf-mobilenet-v3
title: (Tensorflow) MobileNet v3
- local: models/tresnet
title: TResNet
- local: models/wide-resnet
title: Wide ResNet
- local: models/xception
title: Xception
title: Model Pages
isExpanded: false
- sections:
- local: reference/models
title: Models
- local: reference/data
title: Data
- local: reference/optimizers
title: Optimizers
- local: reference/schedulers
title: Learning Rate Schedulers
title: Reference
- local: scripts
title: Scripts
- local: training_hparam_examples
title: Training Examples
- local: feature_extraction
title: Feature Extraction
- local: changes
title: Recent Changes
- local: archived_changes
title: Archived Changes
- local: model_pages
title: Model Pages
isExpanded: false

@@ -0,0 +1,418 @@
# Archived Changes

### July 12, 2021

* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)

### July 5-9, 2021

* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
  * top-1 82.34 @ 288x288 and 82.54 @ 320x320
* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
  * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426

### June 23, 2021

* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)

### June 20, 2021

* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
  * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
  * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the AugReg weights
  * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
  * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
* `vit_deit_*` renamed to just `deit_*`
* Remove my old small model, replace with DeiT compatible small w/ AugReg weights
* Add 1st training of my `gmixer_24_224` MLP w/ GLU, 78.1 top-1 w/ 25M params.
* Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
* Add distilled BiT 50x1 student and 152x2 teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
* NFNets and ResNetV2-BiT models work w/ PyTorch XLA now
  * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
  * eps values adjusted, will be slight differences but should be quite close
* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
* Please report any regressions, this PR touched quite a few models.

### June 8, 2021

* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
  * NFNet inspired block layout with quad layer stem and no maxpool
  * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288

### May 25, 2021

* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Cleanup input_size/img_size override handling and testing for all vision transformer models
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.

### May 14, 2021

* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
  * 1k trained variants: `tf_efficientnetv2_s/m/l`
  * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
  * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
  * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
  * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
  * Some blank `efficientnetv2_*` models in-place for future native PyTorch training

### May 5, 2021

* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modules
  * Improve SA module inits
  * Hack together experimental stand-alone Swin based attn module and `swinnet`
  * Consistent '26t' model defs for experiments.
* Add improved EfficientNet-V2S (prelim model def) weights. 83.8 top-1.
* WandB logging support

### April 13, 2021

* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer

### April 12, 2021

* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
  * Lambda Networks - https://arxiv.org/abs/2102.08602
  * Bottleneck Transformers - https://arxiv.org/abs/2101.11605
  * Halo Nets - https://arxiv.org/abs/2103.12731
* AdaBelief optimizer contributed by Juntang Zhuang

### April 1, 2021

* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
  * Merged distilled variant into main for torchscript compatibility
  * Some `timm` cleanup/style tweaks and weights have hub download support
* Cleanup Vision Transformer (ViT) models
  * Merge distilled (DeiT) model into main so that torchscript can work
  * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
  * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
  * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (ViT) with distilled variants
  * nn.Sequential for block stack (does not break downstream compat)
* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
* Add RegNetY-160 weights from DeiT teacher model
* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b` -> `nfnet_l0`) weights 82.75 top-1 @ 288x288
* Some fixes/improvements for TFDS dataset wrapper

### March 7, 2021

* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.

### Feb 18, 2021

* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
  * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
  * These models are big, expect to run out of GPU memory. With the GELU activation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
  * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
  * Matching the original pre-processing as closely as possible I get these results:
    * `dm_nfnet_f6` - 86.352
    * `dm_nfnet_f5` - 86.100
    * `dm_nfnet_f4` - 85.834
    * `dm_nfnet_f3` - 85.676
    * `dm_nfnet_f2` - 85.178
    * `dm_nfnet_f1` - 84.696
    * `dm_nfnet_f0` - 83.464

### Feb 16, 2021

* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via a mode arg that defaults to the previous 'norm' mode. For backward arg compat, the clip-grad arg must be specified to enable it when using train.py.
  * AGC w/ default clipping factor: `--clip-grad .01 --clip-mode agc`
  * PyTorch global norm of 1.0 (old behaviour, always norm): `--clip-grad 1.0`
  * PyTorch value clipping of 10: `--clip-grad 10. --clip-mode value`
  * AGC performance is definitely sensitive to the clipping factor. More experimentation is needed to determine good values for smaller batch sizes and optimizers besides those in the paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
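
To make the clipping-factor discussion above concrete, here is a minimal pure-Python sketch of the AGC idea from the paper (scale a gradient down when its norm exceeds `clip_factor` times the parameter norm). It treats a whole tensor as one unit for brevity; this is illustrative, not the `timm` implementation, and the function name is made up for this example.

```python
import math


def adaptive_grad_clip(param, grad, clip_factor=0.01, eps=1e-3):
    """Scale `grad` so its L2 norm is at most clip_factor * L2 norm of `param`.

    `eps` guards against zero-initialized parameters, whose norm would
    otherwise force the allowed gradient norm to zero.
    """
    p_norm = max(math.sqrt(sum(p * p for p in param)), eps)
    g_norm = math.sqrt(sum(g * g for g in grad))
    max_norm = clip_factor * p_norm
    if g_norm > max_norm:
        scale = max_norm / g_norm
        return [g * scale for g in grad]
    return grad
```

With `clip_factor=.01`, a parameter vector of norm 5 permits a gradient norm of at most 0.05; anything under that limit passes through untouched, which is why the factor interacts so strongly with optimizer and batch size.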

### Feb 12, 2021

* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs

### Feb 10, 2021

* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
  * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
  * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
  * classic VGG (from torchvision, impl in `vgg`)
* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
* Fix a few bugs introduced since last pypi release

### Feb 8, 2021

* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d trained @ 256, fine-tuned @ 320, test @ 352.
  * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
  * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
  * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test.

### Jan 30, 2021

* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)

### Jan 25, 2021

* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
  * ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
  * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
  * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
  * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling

### Jan 3, 2021

* Add SE-ResNet-152D weights
  * 256x256 val, 0.94 crop top-1 - 83.75
  * 320x320 val, 1.0 crop - 84.36
* Update results files

### Dec 18, 2020

* Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256
  * 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25)
  * 288x288 val, 1.0 crop - 101D (82.64), 152D (83.48), 200D (83.76)
  * 320x320 val, 1.0 crop - 101D (83.00), 152D (83.66), 200D (84.01)

### Dec 7, 2020

* Simplify EMA module (ModelEmaV2), compatible with fully torchscripted models
* Misc fixes for SiLU ONNX export, default_cfg missing from feature extraction models, Linear layer w/ AMP + torchscript
* PyPi release @ 0.3.2 (needed by EfficientDet)
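
For context on the EMA item above, the core idea behind a model EMA can be sketched in a few lines of pure Python (a simplified illustration of the technique, not ModelEmaV2 itself; the class name is made up here):

```python
class SimpleEma:
    """Keep an exponential moving average of a flat list of float parameters.

    The update rule is the usual: ema = decay * ema + (1 - decay) * new.
    A real model EMA applies this tensor-by-tensor over a model's state dict.
    """

    def __init__(self, params, decay=0.999):
        self.decay = decay
        self.shadow = list(params)  # copy of initial parameter values

    def update(self, params):
        d = self.decay
        self.shadow = [d * s + (1.0 - d) * p for s, p in zip(self.shadow, params)]
        return self.shadow
```

The shadow copy trails the live weights, which tends to smooth out optimization noise when the shadow weights are used for evaluation.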

### Oct 30, 2020

* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!

### Oct 26, 2020

* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
  * ViT-B/16 - 84.2
  * ViT-B/32 - 81.7
  * ViT-L/16 - 85.2
  * ViT-L/32 - 81.5

### Oct 21, 2020

* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.

### Oct 13, 2020

* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...

### Sept 18, 2020

* New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D
* Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D)

### Sept 3, 2020

* New weights
  * Wide-ResNet50 - 81.5 top-1 (vs 78.5 torchvision)
  * SEResNeXt50-32x4d - 81.3 top-1 (vs 79.1 cadene)
* Support for native Torch AMP and channels_last memory format added to train/validate scripts (`--channels-last`, `--native-amp` vs `--apex-amp`)
* Models tested with channels_last on latest NGC 20.08 container. AdaptiveAvgPool in attn layers changed to mean((2,3)) to work around a bug with NHWC kernels.

### Aug 12, 2020

* New/updated weights from training experiments
  * EfficientNet-B3 - 82.1 top-1 (vs 81.6 for official with AA and 81.9 for AdvProp)
  * RegNetY-3.2GF - 82.0 top-1 (78.9 from official ver)
  * CSPResNet50 - 79.6 top-1 (76.6 from official ver)
* Add CutMix integrated w/ Mixup. See [pull request](https://github.com/rwightman/pytorch-image-models/pull/218) for some usage examples
* Some fixes for using pretrained weights with `in_chans` != 3 on several models.
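
The core blend behind Mixup (which CutMix shares for its label mixing) is small enough to sketch inline; this is an illustrative pure-Python version of the formula, not the `timm` implementation:

```python
def mixup(x1, x2, lam):
    """Blend two samples element-wise: lam * x1 + (1 - lam) * x2.

    In practice `lam` is drawn from a Beta(alpha, alpha) distribution per
    batch, and the same blend is applied to the one-hot target vectors so
    the loss reflects the mixed inputs.
    """
    return [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
```

CutMix replaces the element-wise blend of the inputs with a pasted rectangular region, but mixes the targets with the same `lam` (set to the area fraction of the patch).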
|
||||
|
||||
### Aug 5, 2020

Universal feature extraction, new models, new weights, new test sets.

* All models support the `features_only=True` argument for the `create_model` call to return a network that extracts feature maps from the deepest layer at each stride.
* New models
  * CSPResNet, CSPResNeXt, CSPDarkNet, DarkNet
  * ReXNet
  * (Modified Aligned) Xception41/65/71 (a proper port of TF models)
* New trained weights
  * SEResNet50 - 80.3 top-1
  * CSPDarkNet53 - 80.1 top-1
  * CSPResNeXt50 - 80.0 top-1
  * DPN68b - 79.2 top-1
  * EfficientNet-Lite0 (non-TF ver) - 75.5 (submitted by [@hal-314](https://github.com/hal-314))
* Add 'real' labels for ImageNet and ImageNet-Renditions test set, see [`results/README.md`](results/README.md)
* Test set ranking/top-n diff script by [@KushajveerSingh](https://github.com/KushajveerSingh)
* Train script and loader/transform tweaks to punch through more aug arguments
* README and documentation overhaul. See initial (WIP) documentation at https://rwightman.github.io/pytorch-image-models/
* AdamP and SGDP optimizers added by [@hellbell](https://github.com/hellbell)
### June 11, 2020

Bunch of changes:

* DenseNet models updated with memory efficient addition from torchvision (fixed a bug), blur pooling and deep stem additions
* VoVNet V1 and V2 models added, 39 V2 variant (ese_vovnet_39b) trained to 79.3 top-1
* Activation factory added along with new activations:
  * select act at model creation time for more flexibility in using activations compatible with scripting or tracing (ONNX export)
  * hard_mish (experimental) added with memory-efficient grad, along with ME hard_swish
  * context mgr for setting exportable/scriptable/no_jit states
* Norm + Activation combo layers added with initial trial support in DenseNet and VoVNet along with impl of EvoNorm and InplaceAbn wrapper that fit the interface
* Torchscript works for all but two of the model types as long as using Pytorch 1.5+, tests added for this
* Some import cleanup and classifier reset changes, all models will have classifier reset to nn.Identity on reset_classifier(0) call
* Prep for 0.1.28 pip release
### May 12, 2020

* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955)

### May 3, 2020

* Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)

### May 1, 2020

* Merged a number of excellent contributions in the ResNet model family over the past month
  * BlurPool2D and resnetblur models initiated by [Chris Ha](https://github.com/VRandme), I trained resnetblur50 to 79.3.
  * TResNet models and SpaceToDepth, AntiAliasDownsampleLayer layers by [mrT23](https://github.com/mrT23)
  * ecaresnet (50d, 101d, light) models and two pruned variants using pruning as per https://arxiv.org/abs/2002.08258 by [Yonathan Aflalo](https://github.com/yoniaflalo)
* 200 pretrained models in total now with updated results csv in results folder
### April 5, 2020

* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
  * 3.5M param MobileNet-V2 100 @ 73%
  * 4.5M param MobileNet-V2 110d @ 75%
  * 6.1M param MobileNet-V2 140 @ 76.5%
  * 5.8M param MobileNet-V2 120d @ 77.3%

### March 18, 2020

* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
### Feb 29, 2020

* New MobileNet-V3 Large weights trained from scratch with this code to 75.77% top-1
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
  * overall results similar to slightly better than previous training from scratch on the few smaller models tried
  * performance early in training seems consistently improved, but there is less difference by the end
  * set `fix_group_fanout=False` in the `_init_weight_goog` fn if you need to reproduce past behaviour
* Experimental LR noise feature added; applies a random perturbation to the LR each epoch within a specified range of training
### Feb 18, 2020

* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
  * Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
  * ResNet downsample paths now properly support dilation (output stride != 32) for avg_pool ('D' variant) and 3x3 (SENets) networks
  * Add Selective Kernel Nets on top of ResNet base, pretrained weights
    * skresnet18 - 73% top-1
    * skresnet34 - 76.9% top-1
    * skresnext50_32x4d (equiv to SKNet50) - 80.2% top-1
  * ECA and CECA (circular padding) attention layer contributed by [Chris Ha](https://github.com/VRandme)
  * CBAM attention experiment (not the best results so far, may remove)
  * Attention factory to allow dynamically selecting one of SE, ECA, CBAM in the `.se` position for all ResNets
  * Add DropBlock and DropPath (formerly DropConnect for EfficientNet/MobileNetv3) support to all ResNet variants
* Full dataset results updated that incl NoisyStudent weights and 2 of the 3 SK weights
### Feb 12, 2020

* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)

### Feb 6, 2020

* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)

### Feb 1/2, 2020

* Port new EfficientNet-B8 (RandAugment) weights. These differ from the B8 AdvProp weights and use a different input normalization.
* Update results csv files on all models for ImageNet validation and three other test sets
* Push PyPi package update

### Jan 31, 2020

* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
### Jan 11/12, 2020

* Master may be a bit unstable wrt training; these changes have been tested, but not in all combos
* Implementations of AugMix added to existing RA and AA, including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
* ResNet-50 AugMix trained model w/ 79% top-1 added
* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'

### Jan 3, 2020

* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
### Dec 30, 2019

* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch

### Dec 28, 2019

* Add new model weights and training hparams (see Training Hparams section)
  * `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
    * trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
  * `seresnext26d_32x4d` - 77.6 top-1, 93.6 top-5
    * deep stem (32, 32, 64), avgpool downsample
    * stem/downsample from bag-of-tricks paper
  * `seresnext26t_32x4d` - 78.0 top-1, 93.7 top-5
    * deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
    * stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
### Dec 23, 2019

* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval

### Dec 4, 2019

* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).

### Nov 29, 2019

* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
* AdvProp weights added
* Official TF MobileNetv3 weights added
* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
* Consistency in global pooling, `reset_classifier`, and `forward_features` across models
  * `forward_features` always returns unpooled feature maps now
* Reasonable chance I broke something... let me know

### Nov 22, 2019

* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.
# Sharing and Loading Models From the Hugging Face Hub

The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub.

In this short guide, we'll see how to:

1. Share a `timm` model on the Hub
2. Load that model back from the Hub

## Authenticating

First, you'll need to make sure you have the `huggingface_hub` package installed.

```bash
pip install huggingface_hub
```

Then, you'll need to authenticate yourself. You can do this by running the following command:

```bash
huggingface-cli login
```

Or, if you're using a notebook, you can use the `notebook_login` helper:

```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```

## Sharing a Model

```py
>>> import timm
>>> model = timm.create_model('resnet18', pretrained=True, num_classes=4)
```

Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial.

Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.hub.push_to_hf_hub` function.

```py
>>> model_cfg = dict(labels=['a', 'b', 'c', 'd'])
>>> timm.models.hub.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg)
```

Running the above would push the model to `<your-username>/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code!

## Loading a Model

Loading a model from the Hub is as simple as calling `timm.create_model` with a model name prefixed by `hf_hub:` and `pretrained=True`. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub.

```py
>>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True)
```
# timm

<img class="float-left !m-0 !border-0 !dark:border-0 !shadow-none !max-w-lg w-[150px]" src="https://huggingface.co/front/thumbnails/docs/timm.png"/>

`timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.

It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use.

Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library. For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora).

<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./feature_extraction"
><div class="w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Tutorials</div>
<p class="text-gray-700">Learn the basics and become familiar with timm. Start here if you are using timm for the first time!</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./reference/models"
><div class="w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Reference</div>
<p class="text-gray-700">Technical descriptions of how timm classes and methods work.</p>
</a>
</div>
</div>

## Install

The library can be installed with pip:

```
pip install timm
```

I update the PyPi (pip) packages when I'm confident there are no significant model regressions from previous releases. If you want to pip install the bleeding edge from GitHub, use:

```
pip install git+https://github.com/rwightman/pytorch-image-models.git
```

### Conda Environment

<Tip>

- All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.7, 3.8, 3.9, and 3.10.

- Little to no care has been taken to be Python 2.x friendly and it will not be supported. If you run into any challenges running on Windows or another OS, I'm definitely open to looking into those issues so long as they can be reproduced in a Conda environment.

- PyTorch versions 1.9, 1.10, and 1.11 have been tested with the latest versions of this code.

</Tip>

I've tried to keep the dependencies minimal; the setup is as per the PyTorch default install instructions for Conda:

```bash
conda create -n torch-env
conda activate torch-env
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
conda install pyyaml
```

## Load a Pretrained Model

Pretrained models can be loaded using `timm.create_model`.

```py
>>> import timm

>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```

## List Models with Pretrained Weights

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
 'adv_inception_v3',
 'cspdarknet53',
 'cspresnext50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
]
```

## List Model Architectures by Wildcard

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 ...
]
```
@ -1,74 +0,0 @@
|
||||
# Installation
|
||||
|
||||
Before you start, you'll need to setup your environment and install the appropriate packages. `timm` is tested on **Python 3+**.
|
||||
|
||||
## Virtual Environment
|
||||
|
||||
You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts.
|
||||
|
||||
1. Create and navigate to your project directory:
|
||||
|
||||
```bash
|
||||
mkdir ~/my-project
|
||||
cd ~/my-project
|
||||
```
|
||||
|
||||
2. Start a virtual environment inside your directory:
|
||||
|
||||
```bash
|
||||
python -m venv .env
|
||||
```
|
||||
|
||||
3. Activate and deactivate the virtual environment with the following commands:
|
||||
|
||||
```bash
|
||||
# Activate the virtual environment
|
||||
source .env/bin/activate
|
||||
|
||||
# Deactivate the virtual environment
|
||||
source .env/bin/deactivate
|
||||
```
|
||||
`
|
||||
Once you've created your virtual environment, you can install `timm` in it.
|
||||
|
||||
## Using pip
|
||||
|
||||
The most straightforward way to install `timm` is with pip:
|
||||
|
||||
```bash
|
||||
pip install timm
|
||||
```
|
||||
|
||||
Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version:
|
||||
|
||||
```bash
|
||||
pip install git+https://github.com/rwightman/pytorch-image-models.git
|
||||
```
|
||||
|
||||
Run the following command to check if `timm` has been properly installed:
|
||||
|
||||
```bash
|
||||
python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
|
||||
```
|
||||
|
||||
This command lists the first five pretrained models available in `timm` (which are sorted alphebetically). You should see the following output:
|
||||
|
||||
```python
|
||||
['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384']
|
||||
```
|
||||
|
||||
## From Source
|
||||
|
||||
Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/rwightman/pytorch-image-models.git
|
||||
cd timm
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Again, you can check if `timm` was properly installed with the following command:
|
||||
|
||||
```bash
|
||||
python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
|
||||
```
|
@ -0,0 +1,5 @@
|
||||
# Available Models

`timm` comes bundled with a number of model architectures and corresponding pretrained models.

In these pages, you will find the models available in the `timm` library, as well as information on how to use them.
# Quickstart

This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate `timm` into their model training workflow.

First, you'll need to install `timm`. For more information on installation, see [Installation](installation).

```bash
pip install timm
```

## Load a Pretrained Model

Pretrained models can be loaded using [`create_model`].

Here, we load the pretrained `mobilenetv3_large_100` model.

```py
>>> import timm

>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```

<Tip>
Note: The returned PyTorch model is set to train mode by default, so you must call `.eval()` on it if you plan to use it for inference.
</Tip>
## List Models with Pretrained Weights

To list models packaged with `timm`, you can use [`list_models`]. If you specify `pretrained=True`, this function will only return model names that have associated pretrained weights available.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
 'adv_inception_v3',
 'cspdarknet53',
 'cspresnext50',
 'densenet121',
 'densenet161',
 'densenet169',
 'densenet201',
 'densenetblur121d',
 'dla34',
 'dla46_c',
]
```

You can also list models with a specific pattern in their name.

```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
 'cspresnet50',
 'cspresnet50d',
 'cspresnet50w',
 'cspresnext50',
 ...
]
```
## Fine-Tune a Pretrained Model

You can fine-tune any of the pretrained models just by changing the classifier (the last layer).

```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
```

To fine-tune on your own dataset, you have to write a PyTorch training loop or adapt `timm`'s [training script](training_script) to use your dataset.
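To sketch what such a training loop can look like, here is a minimal, self-contained example. The tiny `nn.Sequential` model and the random tensors are stand-ins so the sketch runs without downloads; in practice you would use your `timm` model and a `DataLoader` over your own dataset.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

NUM_FINETUNE_CLASSES = 4

# Stand-in for timm.create_model('mobilenetv3_large_100', pretrained=True,
# num_classes=NUM_FINETUNE_CLASSES); any nn.Module classifier works here.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, NUM_FINETUNE_CLASSES))

# Random images/labels stand in for your own dataset.
dataset = TensorDataset(
    torch.randn(64, 3, 32, 32),
    torch.randint(0, NUM_FINETUNE_CLASSES, (64,)),
)
loader = DataLoader(dataset, batch_size=16, shuffle=True)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

model.train()
for epoch in range(2):
    for images, labels in loader:
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```

The same loop applies unchanged to a `timm` model; only the model construction and the dataset differ.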
## Use a Pretrained Model for Feature Extraction

Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This will bypass the network's head classifier and global pooling.

For a more in-depth guide to using `timm` for feature extraction, see [Feature Extraction](feature_extraction).

```py
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
```
## Image Augmentation

To transform images into valid inputs for a model, you can use [`timm.data.create_transform`], providing the desired `input_size` that the model expects.

This will return a generic transform that uses reasonable defaults.

```py
>>> timm.data.create_transform((3, 224, 224))
Compose(
    Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```

Pretrained models have specific transforms that were applied to images fed into them while training. If you use the wrong transform on your image, the model won't understand what it's seeing!

To figure out which transformations were used for a given pretrained model, we can start by taking a look at its `pretrained_cfg`:

```py
>>> model.pretrained_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
 'num_classes': 1000,
 'input_size': (3, 224, 224),
 'pool_size': (7, 7),
 'crop_pct': 0.875,
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'first_conv': 'conv_stem',
 'classifier': 'classifier',
 'architecture': 'mobilenetv3_large_100'}
```

We can then resolve only the data related configuration by using [`timm.data.resolve_data_config`].

```py
>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
 'interpolation': 'bicubic',
 'mean': (0.485, 0.456, 0.406),
 'std': (0.229, 0.224, 0.225),
 'crop_pct': 0.875}
```

We can pass this data config to [`timm.data.create_transform`] to initialize the model's associated transform.

```py
>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
    Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    ToTensor()
    Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```

<Tip>
Note: Here, the pretrained model's config happens to be nearly the same as the generic config we made earlier (only the resize interpolation differs). This is not always the case. So, it's safer to use the data config to create the transform as we did here instead of using the generic transform.
</Tip>
## Using Pretrained Models for Inference

Here, we will put together the above sections and use a pretrained model for inference.

First we'll need an image to do inference on. Here we load an image from the web:

```py
>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image
```

Here's the image we loaded:

<img src="https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg" alt="An Image from a link" width="300"/>

Now, we'll create our model and transforms again. This time, we make sure to set our model in evaluation mode.

```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
    **timm.data.resolve_data_config(model.pretrained_cfg)
)
```

We can prepare this image for the model by passing it to the transform.

```py
>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])
```

Now we can pass that image to the model to get the predictions. We use `unsqueeze(0)` in this case, as the model is expecting a batch dimension.

```py
>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])
```

To get the predicted probabilities, we apply softmax to the output. This leaves us with a tensor of shape `(num_classes,)`.

```py
>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])
```

Now we'll find the top 5 predicted class indices and values using `torch.topk`.

```py
>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([162, 166, 161, 164, 167])
```

If we check the ImageNet labels for the top index, we can see what the model predicted...

```py
>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'beagle', 'value': 0.8486220836639404},
 {'label': 'Walker_hound, Walker_foxhound', 'value': 0.03753996267914772},
 {'label': 'basset, basset_hound', 'value': 0.024628572165966034},
 {'label': 'bluetick', 'value': 0.010317106731235981},
 {'label': 'English_foxhound', 'value': 0.006958036217838526}]
```
# Data

[[autodoc]] timm.data.create_dataset

[[autodoc]] timm.data.create_loader

[[autodoc]] timm.data.create_transform

[[autodoc]] timm.data.resolve_data_config
# Models

[[autodoc]] timm.create_model

[[autodoc]] timm.list_models
# Optimization

This page contains the API reference documentation for the optimizers included in `timm`.

## Optimizers

### Factory functions

[[autodoc]] timm.optim.optim_factory.create_optimizer
[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
### Optimizer Classes

[[autodoc]] timm.optim.adabelief.AdaBelief
[[autodoc]] timm.optim.adafactor.Adafactor
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP
# Learning Rate Schedulers

This page contains the API reference documentation for learning rate schedulers included in `timm`.

## Schedulers

### Factory functions

[[autodoc]] timm.scheduler.scheduler_factory.create_scheduler
[[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2

### Scheduler Classes

[[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler
[[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler
[[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler
[[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler
[[autodoc]] timm.scheduler.step_lr.StepLRScheduler
[[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler
# Scripts

Train, validation, inference, and checkpoint cleaning scripts are included in the GitHub root folder. Scripts are not currently packaged in the pip release.

The training and validation scripts evolved from early versions of the [PyTorch ImageNet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA specific performance enhancements based on [NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples).

## Training Script

The variety of training args is large and not all combinations of options (or even options) have been fully tested. For the training dataset folder, specify the base folder that contains both `train` and `validation` subfolders.
To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
|
||||
|
||||
```bash
|
||||
./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
|
||||
```
|
||||
|
||||
<Tip>
|
||||
It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. --amp defaults to native AMP as of timm ver 0.4.3. --apex-amp will force use of APEX components if they are installed.
|
||||
</Tip>
|
||||
|
||||
## Validation / Inference Scripts
|
||||
|
||||
Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base as in training script.
|
||||
|
||||
To validate with the model's pretrained weights (if they exist):
|
||||
|
||||
```bash
|
||||
python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained
|
||||
```
|
||||
|
||||
To run inference from a checkpoint:
|
||||
|
||||
```bash
|
||||
python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
|
||||
```
|
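The per-image step behind the inference script's topk csv output can be sketched in plain Python (illustrative only; the actual script uses `torch.topk` on batched model outputs):

```python
# Rank class indices by logit score, highest first, and keep the k best --
# the core of producing "topk class ids" per image.

def topk_indices(logits, k=5):
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

topk_indices([0.1, 2.3, -0.5, 1.7, 0.9], k=3)  # -> [1, 3, 4]
```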
@@ -1,3 +1,4 @@
dependencies = ['torch']
import timm
globals().update(timm.models._registry._model_entrypoints)
from timm.models import registry

globals().update(registry._model_entrypoints)
@@ -1,5 +1,4 @@
mkdocs
mkdocs-material
mkdocs-redirects
mdx_truly_sane_lists
mkdocs-awesome-pages-plugin
mkdocs-awesome-pages-plugin
@@ -0,0 +1,2 @@
model-index==0.1.10
jinja2==2.11.3
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
from .version import __version__
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
    is_scriptable, is_exportable, set_scriptable, set_exportable, \
    is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
@@ -1,73 +0,0 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union


class DatasetInfo(ABC):

    def __init__(self):
        pass

    @abstractmethod
    def num_classes(self):
        pass

    @abstractmethod
    def label_names(self):
        pass

    @abstractmethod
    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        pass

    @abstractmethod
    def index_to_label_name(self, index) -> str:
        pass

    @abstractmethod
    def index_to_description(self, index: int, detailed: bool = False) -> str:
        pass

    @abstractmethod
    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        pass


class CustomDatasetInfo(DatasetInfo):
    """ DatasetInfo that wraps passed values for custom datasets."""

    def __init__(
            self,
            label_names: Union[List[str], Dict[int, str]],
            label_descriptions: Optional[Dict[str, str]] = None
    ):
        super().__init__()
        assert len(label_names) > 0
        self._label_names = label_names  # label index => label name mapping
        self._label_descriptions = label_descriptions  # label name => label description mapping
        if self._label_descriptions is not None:
            # validate descriptions (label names required)
            assert isinstance(self._label_descriptions, dict)
            for n in self._label_names:
                assert n in self._label_descriptions

    def num_classes(self):
        return len(self._label_names)

    def label_names(self):
        return self._label_names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        return self._label_descriptions

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        if self._label_descriptions:
            return self._label_descriptions[label]
        return label  # return the label name itself if a description is not present

    def index_to_label_name(self, index) -> str:
        assert 0 <= index < len(self._label_names)
        return self._label_names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)
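The lookup chain `CustomDatasetInfo` implements is just index -> label name -> description over user-supplied mappings. A standalone sketch of that logic with a hypothetical two-class label set (not an import of timm):

```python
# Minimal re-implementation of the CustomDatasetInfo lookup chain.

label_names = ['cat', 'dog']
label_descriptions = {'cat': 'a small domesticated felid', 'dog': 'a domesticated canid'}

def index_to_label_name(index):
    assert 0 <= index < len(label_names)
    return label_names[index]

def index_to_description(index):
    label = index_to_label_name(index)
    # fall back to the name itself if no description is present
    return label_descriptions.get(label, label)
```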
@@ -1,92 +0,0 @@
import csv
import os
import pkgutil
import re
from typing import Dict, List, Optional, Union

from .dataset_info import DatasetInfo


_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

_SUBSETS = {
    'imagenet1k': 'imagenet_synsets.txt',
    'imagenet12k': 'imagenet12k_synsets.txt',
    'imagenet22k': 'imagenet22k_synsets.txt',
    'imagenet21k': 'imagenet21k_goog_synsets.txt',
    'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
    'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
}
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'


def infer_imagenet_subset(model_or_cfg) -> Optional[str]:
    if isinstance(model_or_cfg, dict):
        num_classes = model_or_cfg.get('num_classes', None)
    else:
        num_classes = getattr(model_or_cfg, 'num_classes', None)
        if not num_classes:
            pretrained_cfg = getattr(model_or_cfg, 'pretrained_cfg', {})
            # FIXME at some point pretrained_cfg should include dataset-tag,
            # which will be more robust than a guess based on num_classes
            num_classes = pretrained_cfg.get('num_classes', None)
    if not num_classes or num_classes not in _NUM_CLASSES_TO_SUBSET:
        return None
    return _NUM_CLASSES_TO_SUBSET[num_classes]


class ImageNetInfo(DatasetInfo):

    def __init__(self, subset: str = 'imagenet-1k'):
        super().__init__()
        subset = re.sub(r'[-_\s]', '', subset.lower())
        assert subset in _SUBSETS, f'Unknown imagenet subset {subset}.'

        # WordNet synsets (part-of-speech + offset) are the unique class label names for ImageNet classifiers
        synset_file = _SUBSETS[subset]
        synset_data = pkgutil.get_data(__name__, os.path.join('_info', synset_file))
        self._synsets = synset_data.decode('utf-8').splitlines()

        # WordNet lemmas (canonical dictionary form of word) and definitions are used to build
        # the class descriptions. If detailed=True both are used, otherwise just the lemmas.
        lemma_data = pkgutil.get_data(__name__, os.path.join('_info', _LEMMA_FILE))
        reader = csv.reader(lemma_data.decode('utf-8').splitlines(), delimiter='\t')
        self._lemmas = dict(reader)
        definition_data = pkgutil.get_data(__name__, os.path.join('_info', _DEFINITION_FILE))
        reader = csv.reader(definition_data.decode('utf-8').splitlines(), delimiter='\t')
        self._definitions = dict(reader)

    def num_classes(self):
        return len(self._synsets)

    def label_names(self):
        return self._synsets

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        if as_dict:
            return {label: self.label_name_to_description(label, detailed=detailed) for label in self._synsets}
        else:
            return [self.label_name_to_description(label, detailed=detailed) for label in self._synsets]

    def index_to_label_name(self, index) -> str:
        assert 0 <= index < len(self._synsets), \
            f'Index ({index}) out of range for dataset with {len(self._synsets)} classes.'
        return self._synsets[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        if detailed:
            description = f'{self._lemmas[label]}: {self._definitions[label]}'
        else:
            description = f'{self._lemmas[label]}'
        return description
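The core of `infer_imagenet_subset` is a single table lookup on `num_classes`, with `None` for anything unrecognized. Isolated as a standalone sketch (table values copied from the source above; `subset_for` is a hypothetical helper name):

```python
# num_classes is the only signal used to guess which ImageNet subset a
# classifier was trained on.
_NUM_CLASSES_TO_SUBSET = {
    1000: 'imagenet-1k',
    11821: 'imagenet-12k',
    21841: 'imagenet-22k',
    21843: 'imagenet-21k-goog',
    11221: 'imagenet-21k-miil',
}

def subset_for(num_classes):
    return _NUM_CLASSES_TO_SUBSET.get(num_classes)
```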
@@ -1,50 +0,0 @@
from .activations import *
from .adaptive_avgmax_pool import \
    adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
    set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm import get_norm_layer, create_norm_layer
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
    FourierEmbed, RotaryEmbedding
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .trace_utils import _assert, _float_to_int
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
@@ -1,161 +0,0 @@
""" Classifier head and layer factory

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict
from functools import partial
from typing import Optional, Union, Callable

import torch
import torch.nn as nn
from torch.nn import functional as F

from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .create_act import get_act_layer
from .create_norm import get_norm_layer


def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    flatten_in_pool = not use_conv  # flatten when we use a Linear layer after pooling
    if not pool_type:
        assert num_classes == 0 or use_conv,\
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten_in_pool = False  # disable flattening if pooling is pass-through (no pooling)
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through (no classifier)
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        fc = nn.Linear(num_features, num_classes, bias=True)
    return fc


def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
    return global_pool, fc


class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            use_conv: bool = False,
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
        """
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.in_features = in_features
        self.use_conv = use_conv

        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes, global_pool=None):
        if global_pool is not None:
            if global_pool != self.global_pool.pool_type:
                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
        num_pooled_features = self.in_features * self.global_pool.feat_mult()
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        if pre_logits:
            return x.flatten(1)
        else:
            x = self.fc(x)
            return self.flatten(x)


class NormMlpClassifierHead(nn.Module):

    def __init__(
            self,
            in_features: int,
            num_classes: int,
            hidden_size: Optional[int] = None,
            pool_type: str = 'avg',
            drop_rate: float = 0.,
            norm_layer: Union[str, Callable] = 'layernorm2d',
            act_layer: Union[str, Callable] = 'tanh',
    ):
        """
        Args:
            in_features: The number of input features.
            num_classes: The number of classes for the final classifier layer (output).
            hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
            pool_type: Global pooling type, pooling disabled if empty string ('').
            drop_rate: Pre-classifier dropout rate.
            norm_layer: Normalization layer type.
            act_layer: MLP activation layer type (only used if hidden_size is not None).
        """
        super().__init__()
        self.drop_rate = drop_rate
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.num_features = in_features
        self.use_conv = not pool_type
        norm_layer = get_norm_layer(norm_layer)
        act_layer = get_act_layer(act_layer)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear

        self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
        self.norm = norm_layer(in_features)
        self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
        if hidden_size:
            self.pre_logits = nn.Sequential(OrderedDict([
                ('fc', linear_layer(in_features, hidden_size)),
                ('act', act_layer()),
            ]))
            self.num_features = hidden_size
        else:
            self.pre_logits = nn.Identity()
        self.drop = nn.Dropout(self.drop_rate)
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def reset(self, num_classes, global_pool=None):
        if global_pool is not None:
            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
            self.use_conv = self.global_pool.is_identity()
        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
        if self.hidden_size:
            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
                with torch.no_grad():
                    new_fc = linear_layer(self.in_features, self.hidden_size)
                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
                    new_fc.bias.copy_(self.pre_logits.fc.bias)
                    self.pre_logits.fc = new_fc
        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.pre_logits(x)
        if pre_logits:
            return x
        x = self.fc(x)
        return x
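`_create_pool` sizes the classifier input as `num_features * global_pool.feat_mult()`; concatenated avg+max pooling (`'catavgmax'`) doubles the feature width while single poolings leave it unchanged. That arithmetic, isolated as a plain-Python sketch (hypothetical helpers mirroring `SelectAdaptivePool2d.feat_mult`, not timm code):

```python
def feat_mult(pool_type):
    # 'catavgmax' concatenates avg- and max-pooled features -> 2x channels;
    # any single pooling type keeps the channel count unchanged.
    return 2 if pool_type == 'catavgmax' else 1

def pooled_features(num_features, pool_type='avg'):
    return num_features * feat_mult(pool_type)
```

This is why switching a head to `'catavgmax'` requires recreating the `fc` layer, as `reset()` above does.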
@@ -1,39 +0,0 @@
""" Global Response Normalization Module

Based on the GRN layer presented in
`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808

This implementation
* works for both NCHW and NHWC tensor layouts
* uses affine param names matching existing torch norm layers
* slightly improves eager mode performance via fused addcmul

Hacked together by / Copyright 2023 Ross Wightman
"""

import torch
from torch import nn as nn


class GlobalResponseNorm(nn.Module):
    """ Global Response Normalization layer
    """
    def __init__(self, dim, eps=1e-6, channels_last=True):
        super().__init__()
        self.eps = eps
        if channels_last:
            self.spatial_dim = (1, 2)
            self.channel_dim = -1
            self.wb_shape = (1, 1, 1, -1)
        else:
            self.spatial_dim = (2, 3)
            self.channel_dim = 1
            self.wb_shape = (1, -1, 1, 1)

        self.weight = nn.Parameter(torch.zeros(dim))
        self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
        x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
        return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)
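Written out, the forward pass above computes, with $\gamma$ the `weight`, $\beta$ the `bias`, $G(x)$ the per-channel L2 norm over the spatial dimensions, and the mean taken over channels:

$$
G(x) = \lVert x \rVert_{2,\,\text{spatial}}, \qquad
N(x) = \frac{G(x)}{\operatorname{mean}_{c} G(x) + \epsilon}, \qquad
y = x + \beta + \gamma \odot \bigl(x \cdot N(x)\bigr)
$$

The fused `addcmul` evaluates $\beta + \gamma \odot (x \cdot N(x))$ in one kernel; with $\gamma, \beta$ initialized to zero the layer starts as an identity residual.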
@@ -1,184 +0,0 @@
""" Image to Patch Embedding using Conv2d

A convolution based approach to patchifying a 2D image w/ embedding projection.

Based on code in:
* https://github.com/google-research/vision_transformer
* https://github.com/google-research/big_vision/tree/main/big_vision

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List

import torch
from torch import nn as nn
import torch.nn.functional as F

from .helpers import to_2tuple
from .trace_utils import _assert

_logger = logging.getLogger(__name__)


class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            embed_dim=768,
            norm_layer=None,
            flatten=True,
            bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        x = self.proj(x)
        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x


def resample_patch_embed(
        patch_embed,
        new_size: List[int],
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    """Resample the weights of the patch embedding kernel to target resolution.
    We resample the patch embedding kernel by approximately inverting the effect
    of patch resizing.

    Code based on:
      https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py

    With this resizing, we can for example load a B/8 filter into a B/16 model
    and, on 2x larger input image, the result will match.

    Args:
        patch_embed: original parameter to be resized.
        new_size (tuple(int, int)): target shape (height, width) only.
        interpolation (str): interpolation for resize
        antialias (bool): use anti-aliasing filter in resize
        verbose (bool): log operation
    Returns:
        Resized patch embedding kernel.
    """
    import numpy as np
    try:
        import functorch
        vmap = functorch.vmap
    except ImportError:
        if hasattr(torch, 'vmap'):
            vmap = torch.vmap
        else:
            assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."

    assert len(patch_embed.shape) == 4, "Four dimensions expected"
    assert len(new_size) == 2, "New shape should only be hw"
    old_size = patch_embed.shape[-2:]
    if tuple(old_size) == tuple(new_size):
        return patch_embed

    if verbose:
        _logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")

    def resize(x_np, _new_size):
        x_tf = torch.Tensor(x_np)[None, None, ...]
        x_upsampled = F.interpolate(
            x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
        return x_upsampled

    def get_resize_mat(_old_size, _new_size):
        mat = []
        for i in range(np.prod(_old_size)):
            basis_vec = np.zeros(_old_size)
            basis_vec[np.unravel_index(i, _old_size)] = 1.
            mat.append(resize(basis_vec, _new_size).reshape(-1))
        return np.stack(mat).T

    resize_mat = get_resize_mat(old_size, new_size)
    resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))

    def resample_kernel(kernel):
        resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
        return resampled_kernel.reshape(new_size)

    v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
    return v_resample_kernel(patch_embed)


# def divs(n, m=None):
#     m = m or n // 2
#     if m == 1:
#         return [1]
#     if n % m == 0:
#         return [m] + divs(n, m - 1)
#     return divs(n, m - 1)
#
#
# class FlexiPatchEmbed(nn.Module):
#     """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
#     FIXME WIP
#     """
#     def __init__(
#             self,
#             img_size=240,
#             patch_size=16,
#             in_chans=3,
#             embed_dim=768,
#             base_img_size=240,
#             base_patch_size=32,
#             norm_layer=None,
#             flatten=True,
#             bias=True,
#     ):
#         super().__init__()
#         self.img_size = to_2tuple(img_size)
#         self.patch_size = to_2tuple(patch_size)
#         self.num_patches = 0
#
#         # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
#         self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
#
#         self.base_img_size = to_2tuple(base_img_size)
#         self.base_patch_size = to_2tuple(base_patch_size)
#         self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
#         self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
#
#         self.flatten = flatten
#         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
#         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#
#     def forward(self, x):
#         B, C, H, W = x.shape
#
#         if self.patch_size == self.base_patch_size:
#             weight = self.proj.weight
#         else:
#             weight = resample_patch_embed(self.proj.weight, self.patch_size)
#         patch_size = self.patch_size
#         x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
#         if self.flatten:
#             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
#         x = self.norm(x)
#         return x
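The pseudo-inverse trick in `resample_patch_embed` can be stated compactly. Resizing is linear, so let $R$ be the matrix whose $i$-th column is the resized $i$-th basis vector (what `get_resize_mat` builds); resizing a patch $p$ is then $\operatorname{vec}(p') = R\,\operatorname{vec}(p)$. Requiring the resampled kernel $\hat{w}$ to give the same response on resized patches, $\langle \hat{w}, p' \rangle = \langle w, p \rangle$ for all $p$, gives

$$
R^{\top}\hat{w} = \operatorname{vec}(w)
\quad\Longrightarrow\quad
\hat{w} = \bigl(R^{\top}\bigr)^{+}\operatorname{vec}(w),
$$

which is exactly `pinv(resize_mat.T) @ kernel.reshape(-1)` above, applied per input/output channel pair via the nested `vmap`.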
@@ -1,52 +0,0 @@
""" Position Embedding Utilities

Hacked together by / Copyright 2022 Ross Wightman
"""
import logging
import math
from typing import List, Tuple, Optional, Union

import torch
import torch.nn.functional as F

from .helpers import to_2tuple

_logger = logging.getLogger(__name__)


def resample_abs_pos_embed(
        posemb,
        new_size: List[int],
        old_size: Optional[List[int]] = None,
        num_prefix_tokens: int = 1,
        interpolation: str = 'bicubic',
        antialias: bool = True,
        verbose: bool = False,
):
    # sort out sizes, assume square if old size not provided
    new_size = to_2tuple(new_size)
    new_ntok = new_size[0] * new_size[1]
    if not old_size:
        old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
    old_size = to_2tuple(old_size)
    if new_size == old_size:  # might not both be same container type
        return posemb

    if num_prefix_tokens:
        posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
    else:
        posemb_prefix, posemb = None, posemb

    # do the interpolation
    posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
    posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)

    if verbose:
        _logger.info(f'Resized position embedding: {old_size} to {new_size}.')

    # add back extra (class, etc) prefix tokens
    if posemb_prefix is not None:
        posemb = torch.cat([posemb_prefix, posemb], dim=1)
    return posemb
@ -1,270 +0,0 @@
""" Relative position embedding modules and functions

Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .mlp import Mlp
from .weight_init import trunc_normal_


def gen_relative_position_index(
        q_size: Tuple[int, int],
        k_size: Optional[Tuple[int, int]] = None,
        class_token: bool = False,
) -> torch.Tensor:
    # Adapted with significant modifications from Swin / BeiT codebases
    # get pair-wise relative position index for each token inside the window
    q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1)  # 2, Wh, Ww
    if k_size is None:
        k_coords = q_coords
        k_size = q_size
    else:
        # different q vs k sizes is a WIP
        k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
    relative_coords = q_coords[:, :, None] - k_coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0)  # Wh*Ww, Wh*Ww, 2
    _, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
    # view the flat inverse indices back to a 2d (q_area, k_area) index matrix so prefix padding works
    relative_position_index = relative_position_index.view(q_size[0] * q_size[1], k_size[0] * k_size[1])

    if class_token:
        # handle cls to token & token to cls & cls to cls as per beit for rel pos bias
        # NOTE not intended or tested with MLP log-coords
        max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
        num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
        relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
        relative_position_index[0, 0:] = num_relative_distance - 3
        relative_position_index[0:, 0] = num_relative_distance - 2
        relative_position_index[0, 0] = num_relative_distance - 1

    return relative_position_index.contiguous()
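The index math above can be illustrated without torch. The following is a minimal pure-Python sketch (the name `relative_index` is hypothetical, not part of timm): each (query, key) pair in an h×w window maps to the index of its unique relative offset (dy, dx), so all pairs at the same offset share one bias table entry.

```python
from itertools import product

def relative_index(h, w):
    # map each (query, key) pair in an h*w window to the index of its
    # unique relative offset (dy, dx), mirroring gen_relative_position_index
    coords = list(product(range(h), range(w)))
    offsets = sorted({(qy - ky, qx - kx) for (qy, qx) in coords for (ky, kx) in coords})
    lut = {off: i for i, off in enumerate(offsets)}
    return [[lut[(qy - ky, qx - kx)] for (ky, kx) in coords] for (qy, qx) in coords]

idx = relative_index(2, 2)
# a 2x2 window has (2*2 - 1)**2 = 9 distinct relative offsets, indices 0..8
assert max(max(row) for row in idx) == 8
# pairs with the same relative offset share an index: (0,0)->(0,1) and (1,0)->(1,1)
assert idx[0][1] == idx[2][3]
```

This is why the bias table only needs (2·Wh−1)·(2·Ww−1) rows rather than one per token pair.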
class RelPosBias(nn.Module):
    """ Relative Position Bias
    Adapted from Swin-V1 relative position bias impl, modularized.
    """

    def __init__(self, window_size, num_heads, prefix_tokens=0):
        super().__init__()
        assert prefix_tokens <= 1
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)

        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
        self.register_buffer(
            "relative_position_index",
            gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
            persistent=False,
        )

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.relative_position_bias_table, std=.02)

    def get_bias(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        # win_h * win_w, win_h * win_w, num_heads
        relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
        return relative_position_bias.unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()


def gen_relative_log_coords(
        win_size: Tuple[int, int],
        pretrained_win_size: Tuple[int, int] = (0, 0),
        mode='swin',
):
    assert mode in ('swin', 'cr')
    # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
    relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
    relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
    relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
    relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous()  # 2*Wh-1, 2*Ww-1, 2
    if mode == 'swin':
        if pretrained_win_size[0] > 0:
            relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
        else:
            relative_coords_table[:, :, 0] /= (win_size[0] - 1)
            relative_coords_table[:, :, 1] /= (win_size[1] - 1)
        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
            1.0 + relative_coords_table.abs()) / math.log2(8)
    else:
        # mode == 'cr'
        relative_coords_table = torch.sign(relative_coords_table) * torch.log(
            1.0 + relative_coords_table.abs())

    return relative_coords_table


class RelPosMlp(nn.Module):
    """ Log-Coordinate Relative Position MLP
    Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)

    This impl covers the 'swin' implementation as well as a timm specific 'cr' mode
    """
    def __init__(
            self,
            window_size,
            num_heads=8,
            hidden_dim=128,
            prefix_tokens=0,
            mode='cr',
            pretrained_window_size=(0, 0),
    ):
        super().__init__()
        self.window_size = window_size
        self.window_area = self.window_size[0] * self.window_size[1]
        self.prefix_tokens = prefix_tokens
        self.num_heads = num_heads
        self.bias_shape = (self.window_area,) * 2 + (num_heads,)
        if mode == 'swin':
            self.bias_act = nn.Sigmoid()
            self.bias_gain = 16
            mlp_bias = (True, False)
        else:
            self.bias_act = nn.Identity()
            self.bias_gain = None
            mlp_bias = True

        self.mlp = Mlp(
            2,  # x, y
            hidden_features=hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            bias=mlp_bias,
            drop=(0.125, 0.)
        )

        self.register_buffer(
            "relative_position_index",
            gen_relative_position_index(window_size),
            persistent=False)

        # get relative_coords_table
        self.register_buffer(
            "rel_coords_log",
            gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
            persistent=False)

    def get_bias(self) -> torch.Tensor:
        relative_position_bias = self.mlp(self.rel_coords_log)
        if self.relative_position_index is not None:
            relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
                self.relative_position_index.view(-1)]  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.view(self.bias_shape)
        relative_position_bias = relative_position_bias.permute(2, 0, 1)
        relative_position_bias = self.bias_act(relative_position_bias)
        if self.bias_gain is not None:
            relative_position_bias = self.bias_gain * relative_position_bias
        if self.prefix_tokens:
            relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
        return relative_position_bias.unsqueeze(0).contiguous()

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()


def generate_lookup_tensor(
        length: int,
        max_relative_position: Optional[int] = None,
):
    """Generate a one_hot lookup tensor to reindex embeddings along one dimension.

    Args:
        length: the length to reindex to.
        max_relative_position: the maximum relative position to consider.
            Relative position embeddings for distances above this threshold
            are zeroed out.
    Returns:
        a lookup Tensor of size [length, length, vocab_size] that satisfies
            ret[n, m, v] = 1{m - n + max_relative_position = v}.
    """
    if max_relative_position is None:
        max_relative_position = length - 1
    vocab_size = 2 * max_relative_position + 1
    ret = torch.zeros(length, length, vocab_size)
    for i in range(length):
        for x in range(length):
            v = x - i + max_relative_position
            if abs(x - i) > max_relative_position:
                continue
            ret[i, x, v] = 1
    return ret
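The one-hot semantics `ret[n, m, v] = 1{m - n + max_relative_position = v}` from the docstring above can be sketched with plain nested lists, no torch required (the name `lookup_tensor` is illustrative, not from timm):

```python
def lookup_tensor(length, max_rel=None):
    # nested-list version of generate_lookup_tensor:
    # ret[n][m][v] = 1 iff v == m - n + max_rel and |m - n| <= max_rel
    if max_rel is None:
        max_rel = length - 1
    vocab = 2 * max_rel + 1
    ret = [[[0] * vocab for _ in range(length)] for _ in range(length)]
    for n in range(length):
        for m in range(length):
            if abs(m - n) <= max_rel:
                ret[n][m][m - n + max_rel] = 1
    return ret

t = lookup_tensor(3)
assert t[0][2][4] == 1  # offset +2 lands in vocab slot 2 - 0 + 2 = 4
assert t[1][1][2] == 1 and sum(t[1][1]) == 1  # zero offset sits at the center slot
```

Contracting such a tensor against a bias table (as the einsum lookups below do) is just a gather expressed as matrix multiplication, which keeps the reindexing differentiable and exportable.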
def reindex_2d_einsum_lookup(
        relative_position_tensor,
        height: int,
        width: int,
        height_lookup: torch.Tensor,
        width_lookup: torch.Tensor,
) -> torch.Tensor:
    """Reindex 2d relative position bias with 2 independent einsum lookups.

    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py

    Args:
        relative_position_tensor: tensor of shape
            [..., vocab_height, vocab_width, ...].
        height: height to reindex to.
        width: width to reindex to.
        height_lookup: one-hot height lookup
        width_lookup: one-hot width lookup
    Returns:
        reindexed_tensor: a Tensor of shape
            [..., height * width, height * width, ...]
    """
    reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
    reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
    area = height * width
    return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)


class RelPosBiasTf(nn.Module):
    """ Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
    Adapted from:
     https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
    """
    def __init__(self, window_size, num_heads, prefix_tokens=0):
        super().__init__()
        assert prefix_tokens <= 1
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        self.num_heads = num_heads

        vocab_height = 2 * window_size[0] - 1
        vocab_width = 2 * window_size[1] - 1
        self.bias_shape = (self.num_heads, vocab_height, vocab_width)
        self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
        self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
        self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
        self.init_weights()

    def init_weights(self):
        nn.init.normal_(self.relative_position_bias_table, std=.02)

    def get_bias(self) -> torch.Tensor:
        # FIXME change to not use one-hot/einsum?
        return reindex_2d_einsum_lookup(
            self.relative_position_bias_table,
            self.window_size[0],
            self.window_size[1],
            self.height_lookup,
            self.width_lookup,
        )

    def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
        return attn + self.get_bias()
@ -1,409 +0,0 @@
import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple

from torch import nn as nn
from torch.hub import load_state_dict_from_url

from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
from timm.models._registry import get_pretrained_cfg

_logger = logging.getLogger(__name__)

# Global variables for rarely used pretrained checkpoint download progress and hash check.
# Use set_pretrained_download_progress / set_pretrained_check_hash functions to toggle.
_DOWNLOAD_PROGRESS = False
_CHECK_HASH = False


__all__ = ['set_pretrained_download_progress', 'set_pretrained_check_hash', 'load_custom_pretrained', 'load_pretrained',
           'pretrained_cfg_for_features', 'resolve_pretrained_cfg', 'build_model_with_cfg']


def _resolve_pretrained_source(pretrained_cfg):
    cfg_source = pretrained_cfg.get('source', '')
    pretrained_url = pretrained_cfg.get('url', None)
    pretrained_file = pretrained_cfg.get('file', None)
    hf_hub_id = pretrained_cfg.get('hf_hub_id', None)

    # resolve where to load pretrained weights from
    load_from = ''
    pretrained_loc = ''
    if cfg_source == 'hf-hub' and has_hf_hub(necessary=True):
        # hf-hub specified as source via model identifier
        load_from = 'hf-hub'
        assert hf_hub_id
        pretrained_loc = hf_hub_id
    else:
        # default source == timm or unspecified
        if pretrained_file:
            # file load override is the highest priority if set
            load_from = 'file'
            pretrained_loc = pretrained_file
        else:
            # next, HF hub is prioritized unless a valid cached version of weights exists already
            cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
            if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
                # hf-hub available as alternate weight source in default_cfg
                load_from = 'hf-hub'
                pretrained_loc = hf_hub_id
            elif pretrained_url:
                load_from = 'url'
                pretrained_loc = pretrained_url

    if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
        # if a filename override is set, return tuple for location w/ (hub_id, filename)
        pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
    return load_from, pretrained_loc
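The priority order implemented by `_resolve_pretrained_source` can be boiled down to a small standalone sketch (the helper name `resolve_source` and its `hf_available` / `url_cached` flags are hypothetical stand-ins for `has_hf_hub()` and `check_cached_file()`): an explicit `file` entry beats `hf_hub_id`, which beats `url` unless the url's weights are already cached locally.

```python
def resolve_source(cfg, hf_available=True, url_cached=False):
    # standalone sketch of _resolve_pretrained_source's priority logic
    if cfg.get('source') == 'hf-hub' and hf_available:
        return 'hf-hub', cfg['hf_hub_id']
    if cfg.get('file'):
        return 'file', cfg['file']  # file override is highest priority
    if cfg.get('hf_hub_id') and hf_available and not url_cached:
        return 'hf-hub', cfg['hf_hub_id']
    if cfg.get('url'):
        return 'url', cfg['url']
    return '', ''  # no weight source; caller falls back to random init

assert resolve_source({'file': 'w.pth', 'url': 'http://x/w.pth'}) == ('file', 'w.pth')
assert resolve_source({'url': 'http://x/w.pth', 'hf_hub_id': 'org/m'}) == ('hf-hub', 'org/m')
assert resolve_source({'url': 'http://x/w.pth', 'hf_hub_id': 'org/m'}, url_cached=True) == ('url', 'http://x/w.pth')
```

The cached-url exception avoids re-downloading weights from the hub when a previously fetched copy already validates.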
def set_pretrained_download_progress(enable=True):
    """ Set download progress for pretrained weights on/off (globally). """
    global _DOWNLOAD_PROGRESS
    _DOWNLOAD_PROGRESS = enable


def set_pretrained_check_hash(enable=True):
    """ Set hash checking for pretrained weights on/off (globally). """
    global _CHECK_HASH
    _CHECK_HASH = enable


def load_custom_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        load_fn: Optional[Callable] = None,
):
    r"""Loads a custom (read non .pth) weight file

    Downloads checkpoint file into cache-dir like torch.hub based loaders, but calls
    a passed in custom load fn, or the `load_pretrained` model member fn.

    If the object is already present in `model_dir`, it's deserialized and returned.
    The default value of `model_dir` is ``<hub_dir>/checkpoints`` where
    `hub_dir` is the directory returned by :func:`~torch.hub.get_dir`.

    Args:
        model: The instantiated model to load weights into
        pretrained_cfg (dict): Default pretrained model cfg
        load_fn: An external standalone fn that loads weights into provided model, otherwise a fn named
            'load_pretrained' on the model will be called if it exists
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if not load_from:
        _logger.warning("No pretrained weights exist for this model. Using random initialization.")
        return
    if load_from == 'hf-hub':  # FIXME
        _logger.warning("Hugging Face hub not currently supported for custom load pretrained models.")
    elif load_from == 'url':
        pretrained_loc = download_cached_file(
            pretrained_loc,
            check_hash=_CHECK_HASH,
            progress=_DOWNLOAD_PROGRESS,
        )

    if load_fn is not None:
        load_fn(model, pretrained_loc)
    elif hasattr(model, 'load_pretrained'):
        model.load_pretrained(pretrained_loc)
    else:
        _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")


def load_pretrained(
        model: nn.Module,
        pretrained_cfg: Optional[Dict] = None,
        num_classes: int = 1000,
        in_chans: int = 3,
        filter_fn: Optional[Callable] = None,
        strict: bool = True,
):
    """ Load pretrained checkpoint

    Args:
        model (nn.Module): PyTorch model module
        pretrained_cfg (Optional[Dict]): configuration for pretrained weights / target dataset
        num_classes (int): num_classes for target model
        in_chans (int): in_chans for target model
        filter_fn (Optional[Callable]): state_dict filter fn for load (takes state_dict, model as args)
        strict (bool): strict load of checkpoint
    """
    pretrained_cfg = pretrained_cfg or getattr(model, 'pretrained_cfg', None)
    if not pretrained_cfg:
        _logger.warning("Invalid pretrained config, cannot load weights.")
        return

    load_from, pretrained_loc = _resolve_pretrained_source(pretrained_cfg)
    if load_from == 'file':
        _logger.info(f'Loading pretrained weights from file ({pretrained_loc})')
        state_dict = load_state_dict(pretrained_loc)
    elif load_from == 'url':
        _logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
        if pretrained_cfg.get('custom_load', False):
            pretrained_loc = download_cached_file(
                pretrained_loc,
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
            model.load_pretrained(pretrained_loc)
            return
        else:
            state_dict = load_state_dict_from_url(
                pretrained_loc,
                map_location='cpu',
                progress=_DOWNLOAD_PROGRESS,
                check_hash=_CHECK_HASH,
            )
    elif load_from == 'hf-hub':
        _logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
        if isinstance(pretrained_loc, (list, tuple)):
            state_dict = load_state_dict_from_hf(*pretrained_loc)
        else:
            state_dict = load_state_dict_from_hf(pretrained_loc)
    else:
        _logger.warning("No pretrained weights exist or were found for this model. Using random initialization.")
        return

    if filter_fn is not None:
        try:
            state_dict = filter_fn(state_dict, model)
        except TypeError as e:
            # for backwards compat with filter fns that take one arg
            state_dict = filter_fn(state_dict)

    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
        if isinstance(input_convs, str):
            input_convs = (input_convs,)
        for input_conv_name in input_convs:
            weight_name = input_conv_name + '.weight'
            try:
                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
                _logger.info(
                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
            except NotImplementedError as e:
                del state_dict[weight_name]
                strict = False
                _logger.warning(
                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')

    classifiers = pretrained_cfg.get('classifier', None)
    label_offset = pretrained_cfg.get('label_offset', 0)
    if classifiers is not None:
        if isinstance(classifiers, str):
            classifiers = (classifiers,)
        if num_classes != pretrained_cfg['num_classes']:
            for classifier_name in classifiers:
                # completely discard fully connected if model num_classes doesn't match pretrained weights
                state_dict.pop(classifier_name + '.weight', None)
                state_dict.pop(classifier_name + '.bias', None)
            strict = False
        elif label_offset > 0:
            for classifier_name in classifiers:
                # special case for pretrained weights with an extra background class in pretrained weights
                classifier_weight = state_dict[classifier_name + '.weight']
                state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
                classifier_bias = state_dict[classifier_name + '.bias']
                state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]

    model.load_state_dict(state_dict, strict=strict)


def pretrained_cfg_for_features(pretrained_cfg):
    pretrained_cfg = deepcopy(pretrained_cfg)
    # remove default pretrained cfg fields that don't have much relevance for feature backbone
    to_remove = ('num_classes', 'classifier', 'global_pool')  # add default final pool size?
    for tr in to_remove:
        pretrained_cfg.pop(tr, None)
    return pretrained_cfg


def _filter_kwargs(kwargs, names):
    if not kwargs or not names:
        return
    for n in names:
        kwargs.pop(n, None)


def _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter):
    """ Update the default_cfg and kwargs before passing to model

    Args:
        pretrained_cfg: input pretrained cfg (updated in-place)
        kwargs: keyword args passed to model build fn (updated in-place)
        kwargs_filter: keyword arg keys that must be removed before model __init__
    """
    # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
    default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
    if pretrained_cfg.get('fixed_input_size', False):
        # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
        default_kwarg_names += ('img_size',)

    for n in default_kwarg_names:
        # for legacy reasons, model __init__ args use img_size + in_chans as separate args while
        # pretrained_cfg has one input_size=(C, H, W) entry
        if n == 'img_size':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[-2:])
        elif n == 'in_chans':
            input_size = pretrained_cfg.get('input_size', None)
            if input_size is not None:
                assert len(input_size) == 3
                kwargs.setdefault(n, input_size[0])
        else:
            default_val = pretrained_cfg.get(n, None)
            if default_val is not None:
                kwargs.setdefault(n, pretrained_cfg[n])

    # Filter keyword args for task specific model variants (some 'features only' models, etc.)
    _filter_kwargs(kwargs, names=kwargs_filter)
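The default-filling behavior of `_update_default_kwargs` reduces to a simple idea worth a sketch (the function name `fill_defaults` is hypothetical): cfg values only fill kwargs the caller did not set, via `dict.setdefault`, and the single `input_size=(C, H, W)` entry is split into the model's separate `in_chans` arg.

```python
def fill_defaults(pretrained_cfg, kwargs):
    # standalone sketch of _update_default_kwargs (img_size handling omitted)
    for n in ('num_classes', 'global_pool', 'in_chans'):
        if n == 'in_chans':
            input_size = pretrained_cfg.get('input_size')
            if input_size is not None:
                kwargs.setdefault(n, input_size[0])  # C from (C, H, W)
        elif pretrained_cfg.get(n) is not None:
            kwargs.setdefault(n, pretrained_cfg[n])
    return kwargs

cfg = {'num_classes': 1000, 'global_pool': 'avg', 'input_size': (3, 224, 224)}
# caller's explicit num_classes wins; the rest falls back to cfg
assert fill_defaults(cfg, {'num_classes': 10}) == {'num_classes': 10, 'global_pool': 'avg', 'in_chans': 3}
```

Using `setdefault` keeps explicit caller arguments authoritative while still letting each pretrained cfg describe sensible per-model defaults.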
def resolve_pretrained_cfg(
        variant: str,
        pretrained_cfg=None,
        pretrained_cfg_overlay=None,
) -> PretrainedCfg:
    model_with_tag = variant
    pretrained_tag = None
    if pretrained_cfg:
        if isinstance(pretrained_cfg, dict):
            # pretrained_cfg dict passed as arg, validate by converting to PretrainedCfg
            pretrained_cfg = PretrainedCfg(**pretrained_cfg)
        elif isinstance(pretrained_cfg, str):
            pretrained_tag = pretrained_cfg
            pretrained_cfg = None

    # fallback to looking up pretrained cfg in model registry by variant identifier
    if not pretrained_cfg:
        if pretrained_tag:
            model_with_tag = '.'.join([variant, pretrained_tag])
        pretrained_cfg = get_pretrained_cfg(model_with_tag)

    if not pretrained_cfg:
        _logger.warning(
            f"No pretrained configuration specified for {model_with_tag} model. Using a default."
            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
        pretrained_cfg = PretrainedCfg()  # instance with defaults

    pretrained_cfg_overlay = pretrained_cfg_overlay or {}
    if not pretrained_cfg.architecture:
        pretrained_cfg_overlay.setdefault('architecture', variant)
    pretrained_cfg = dataclasses.replace(pretrained_cfg, **pretrained_cfg_overlay)

    return pretrained_cfg


def build_model_with_cfg(
        model_cls: Callable,
        variant: str,
        pretrained: bool,
        pretrained_cfg: Optional[Dict] = None,
        pretrained_cfg_overlay: Optional[Dict] = None,
        model_cfg: Optional[Any] = None,
        feature_cfg: Optional[Dict] = None,
        pretrained_strict: bool = True,
        pretrained_filter_fn: Optional[Callable] = None,
        kwargs_filter: Optional[Tuple[str, ...]] = None,
        **kwargs,
):
    """ Build model with specified default_cfg and optional model_cfg

    This helper fn aids in the construction of a model including:
      * handling default_cfg and associated pretrained weight loading
      * passing through optional model_cfg for models with config based arch spec
      * features_only model adaptation
      * pruning config / model adaptation

    Args:
        model_cls (nn.Module): model class
        variant (str): model variant name
        pretrained (bool): load pretrained weights
        pretrained_cfg (dict): model's pretrained weight/task config
        model_cfg (Optional[Dict]): model's architecture config
        feature_cfg (Optional[Dict]): feature extraction adapter config
        pretrained_strict (bool): load pretrained weights strictly
        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
        **kwargs: model args passed through to model __init__
    """
    pruned = kwargs.pop('pruned', False)
    features = False
    feature_cfg = feature_cfg or {}

    # resolve and update model pretrained config and model kwargs
    pretrained_cfg = resolve_pretrained_cfg(
        variant,
        pretrained_cfg=pretrained_cfg,
        pretrained_cfg_overlay=pretrained_cfg_overlay,
    )

    # FIXME converting back to dict, PretrainedCfg use should be propagated further, but not into model
    pretrained_cfg = pretrained_cfg.to_dict()

    _update_default_kwargs(pretrained_cfg, kwargs, kwargs_filter)

    # Setup for feature extraction wrapper done at end of this fn
    if kwargs.pop('features_only', False):
        features = True
        feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
        if 'out_indices' in kwargs:
            feature_cfg['out_indices'] = kwargs.pop('out_indices')

    # Instantiate the model
    if model_cfg is None:
        model = model_cls(**kwargs)
    else:
        model = model_cls(cfg=model_cfg, **kwargs)
    model.pretrained_cfg = pretrained_cfg
    model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    if pruned:
        model = adapt_model_from_file(model, variant)

    # For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
    num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
    if pretrained:
        load_pretrained(
            model,
            pretrained_cfg=pretrained_cfg,
            num_classes=num_classes_pretrained,
            in_chans=kwargs.get('in_chans', 3),
            filter_fn=pretrained_filter_fn,
            strict=pretrained_strict,
        )

    # Wrap the model in a feature extraction module if enabled
    if features:
        feature_cls = FeatureListNet
        if 'feature_cls' in feature_cfg:
            feature_cls = feature_cfg.pop('feature_cls')
            if isinstance(feature_cls, str):
                feature_cls = feature_cls.lower()
                if 'hook' in feature_cls:
                    feature_cls = FeatureHookNet
                elif feature_cls == 'fx':
                    feature_cls = FeatureGraphNet
                else:
                    assert False, f'Unknown feature class {feature_cls}'
        model = feature_cls(model, **feature_cfg)
        model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained_cfg
        model.default_cfg = model.pretrained_cfg  # alias for backwards compat

    return model
@ -1,103 +0,0 @@
import os
from typing import Any, Dict, Optional, Union
from urllib.parse import urlsplit

from timm.layers import set_layer_config
from ._pretrained import PretrainedCfg, split_model_name_tag
from ._helpers import load_checkpoint
from ._hub import load_model_config_from_hf
from ._registry import is_model, model_entrypoint


__all__ = ['parse_model_name', 'safe_model_name', 'create_model']


def parse_model_name(model_name):
    if model_name.startswith('hf_hub'):
        # NOTE for backwards compat, deprecate hf_hub use
        model_name = model_name.replace('hf_hub', 'hf-hub')
    parsed = urlsplit(model_name)
    assert parsed.scheme in ('', 'timm', 'hf-hub')
    if parsed.scheme == 'hf-hub':
        # FIXME may use fragment as revision, currently `@` in URI path
        return parsed.scheme, parsed.path
    else:
        model_name = os.path.split(parsed.path)[-1]
        return 'timm', model_name


def safe_model_name(model_name, remove_source=True):
    # return a filename / path safe model name
    def make_safe(name):
        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
    if remove_source:
        model_name = parse_model_name(model_name)[-1]
    return make_safe(model_name)
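`parse_model_name` leans on `urlsplit` treating the text before the first `:` as a URL scheme. A self-contained sketch of that behavior (the name `parse_name` is illustrative, stripped of the asserts and deprecation notes above):

```python
import os
from urllib.parse import urlsplit

def parse_name(model_name):
    # sketch of parse_model_name: split an optional 'hf-hub:' / 'timm:' scheme
    # from the model identifier; no scheme defaults to the timm registry
    parsed = urlsplit(model_name.replace('hf_hub', 'hf-hub'))
    if parsed.scheme == 'hf-hub':
        return 'hf-hub', parsed.path  # hub repo path kept intact
    return 'timm', os.path.split(parsed.path)[-1]

assert parse_name('resnet50') == ('timm', 'resnet50')
assert parse_name('timm:resnet50') == ('timm', 'resnet50')
assert parse_name('hf-hub:timm/resnet50.a1_in1k') == ('hf-hub', 'timm/resnet50.a1_in1k')
```

Because `hf-hub` is a syntactically valid URL scheme (letters and hyphens), no custom string parsing is needed; `urlsplit` does the split and leaves schemeless names untouched.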
def create_model(
        model_name: str,
        pretrained: bool = False,
        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
        checkpoint_path: str = '',
        scriptable: Optional[bool] = None,
        exportable: Optional[bool] = None,
        no_jit: Optional[bool] = None,
        **kwargs,
):
    """Create a model

    Lookup model's entrypoint function and pass relevant args to create a new model.

    **kwargs will be passed through entrypoint fn to timm.models.build_model_with_cfg()
    and then the model class __init__(). kwargs values set to None are pruned before passing.

    Args:
        model_name (str): name of model to instantiate
        pretrained (bool): load pretrained ImageNet-1k weights if true
        pretrained_cfg (Union[str, dict, PretrainedCfg]): pass in external pretrained_cfg for model
        pretrained_cfg_overlay (dict): replace key-values in base pretrained_cfg with these
        checkpoint_path (str): path of checkpoint to load _after_ the model is initialized
        scriptable (bool): set layer config so that model is jit scriptable (not working for all models yet)
        exportable (bool): set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet)
        no_jit (bool): set layer config so that model doesn't utilize jit scripted layers (so far activations only)

    Keyword Args:
        drop_rate (float): dropout rate for training (default: 0.0)
        global_pool (str): global pool type (default: 'avg')
        **: other kwargs are consumed by builder or model __init__()
    """
    # Parameters that aren't supported by all models or are intended to only override model defaults if set
    # should default to None in command line args/cfg. Remove them if they are present and not set so that
    # non-supporting models don't break and default args remain in effect.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    model_source, model_name = parse_model_name(model_name)
    if model_source == 'hf-hub':
        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
        # load model weights + pretrained_cfg from Hugging Face hub.
        pretrained_cfg, model_name = load_model_config_from_hf(model_name)
    else:
        model_name, pretrained_tag = split_model_name_tag(model_name)
        if not pretrained_cfg:
            # a valid pretrained_cfg argument takes priority over tag in model name
            pretrained_cfg = pretrained_tag

    if not is_model(model_name):
        raise RuntimeError('Unknown model (%s)' % model_name)

    create_fn = model_entrypoint(model_name)
    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
        model = create_fn(
            pretrained=pretrained,
            pretrained_cfg=pretrained_cfg,
            pretrained_cfg_overlay=pretrained_cfg_overlay,
            **kwargs,
        )

    if checkpoint_path:
        load_checkpoint(model, checkpoint_path)

    return model
@ -1,359 +0,0 @@
""" PyTorch Feature Extraction Helpers

A collection of classes, functions, modules to help extract features from models
and provide a common interface for describing them.

The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py

Hacked together by / Copyright 2020 Ross Wightman
"""
from collections import OrderedDict, defaultdict
from copy import deepcopy
from functools import partial
from typing import Dict, List, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']

class FeatureInfo:

    def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
        prev_reduction = 1
        for fi in feature_info:
            # sanity check the mandatory fields, there may be additional fields depending on the model
            assert 'num_chs' in fi and fi['num_chs'] > 0
            assert 'reduction' in fi and fi['reduction'] >= prev_reduction
            prev_reduction = fi['reduction']
            assert 'module' in fi
        self.out_indices = out_indices
        self.info = feature_info

    def from_other(self, out_indices: Tuple[int]):
        return FeatureInfo(deepcopy(self.info), out_indices)

    def get(self, key, idx=None):
        """ Get value by key at specified index (indices)
        if idx is None, returns value for key at each output index
        if idx is an integer, return value for that feature module index (ignoring output indices)
        if idx is a list/tuple, return value for each module index (ignoring output indices)
        """
        if idx is None:
            return [self.info[i][key] for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i][key] for i in idx]
        else:
            return self.info[idx][key]

    def get_dicts(self, keys=None, idx=None):
        """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
        """
        if idx is None:
            if keys is None:
                return [self.info[i] for i in self.out_indices]
            else:
                return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
        if isinstance(idx, (tuple, list)):
            return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
        else:
            return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}

    def channels(self, idx=None):
        """ feature channels accessor
        """
        return self.get('num_chs', idx)

    def reduction(self, idx=None):
        """ feature reduction (output stride) accessor
        """
        return self.get('reduction', idx)

    def module_name(self, idx=None):
        """ feature module name accessor
        """
        return self.get('module', idx)

    def __getitem__(self, item):
        return self.info[item]

    def __len__(self):
        return len(self.info)


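FeatureInfo's accessors reduce to simple index selection over the per-stage metadata list. A minimal plain-dict illustration of the `get(key, idx=None)` semantics, without the class itself:

```python
# Plain-data sketch of FeatureInfo.get: with idx=None the out_indices select
# which stages are reported; an explicit integer idx ignores out_indices.
info = [
    {'num_chs': 64, 'reduction': 4, 'module': 'layer1'},
    {'num_chs': 128, 'reduction': 8, 'module': 'layer2'},
    {'num_chs': 256, 'reduction': 16, 'module': 'layer3'},
]
out_indices = (0, 2)

channels = [info[i]['num_chs'] for i in out_indices]  # idx=None case
single = info[1]['num_chs']                           # integer idx case
print(channels, single)  # [64, 256] 128
```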
class FeatureHooks:
    """ Feature Hook Helper

    This module helps with the setup and extraction of hooks for extracting features from
    internal nodes in a model by node name.

    FIXME This works well in eager Python but needs redesign for torchscript.
    """

    def __init__(
            self,
            hooks: Sequence[str],
            named_modules: dict,
            out_map: Sequence[Union[int, str]] = None,
            default_hook_type: str = 'forward',
    ):
        # setup feature hooks
        self._feature_outputs = defaultdict(OrderedDict)
        modules = {k: v for k, v in named_modules}
        for i, h in enumerate(hooks):
            hook_name = h['module']
            m = modules[hook_name]
            hook_id = out_map[i] if out_map else hook_name
            hook_fn = partial(self._collect_output_hook, hook_id)
            hook_type = h.get('hook_type', default_hook_type)
            if hook_type == 'forward_pre':
                m.register_forward_pre_hook(hook_fn)
            elif hook_type == 'forward':
                m.register_forward_hook(hook_fn)
            else:
                assert False, "Unsupported hook type"

    def _collect_output_hook(self, hook_id, *args):
        x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
        if isinstance(x, tuple):
            x = x[0]  # unwrap input tuple
        self._feature_outputs[x.device][hook_id] = x

    def get_output(self, device) -> Dict[str, torch.Tensor]:
        output = self._feature_outputs[device]
        self._feature_outputs[device] = OrderedDict()  # clear after reading
        return output


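The hook bookkeeping above boils down to a per-device buffer that is cleared on read. A stand-alone, torch-free sketch of that contract (plain strings stand in for devices and tensors):

```python
from collections import OrderedDict, defaultdict

# Stand-in for FeatureHooks' buffering: hook outputs are stored per device and
# the buffer is cleared when read, mirroring _collect_output_hook/get_output.
_buffers = defaultdict(OrderedDict)

def collect(device, hook_id, value):
    _buffers[device][hook_id] = value

def read(device):
    out = _buffers[device]
    _buffers[device] = OrderedDict()  # clear after reading
    return out

collect('cpu', 'layer1', 1.0)
collect('cpu', 'layer2', 2.0)
print(list(read('cpu')))  # ['layer1', 'layer2']
```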
def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml


def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


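The renaming `_module_list` performs when flattening a `Sequential` can be shown in isolation: the new attribute name joins parent and child with `_`, while the original path keeps `.` (valid attribute names cannot contain dots):

```python
# Sketch of _module_list's flatten naming: ('features_0', 'features.0'), etc.
def flatten_names(parent, children):
    return [('_'.join((parent, c)), '.'.join((parent, c))) for c in children]

print(flatten_names('features', ['0', '1']))
# [('features_0', 'features.0'), ('features_1', 'features.1')]
```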
class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)


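The first/last skip rule in `_collect` above is easy to verify in isolation: activation checkpointing is applied only to interior modules, so a gradient exists at the network input and in-place ops in the final module cannot break recomputation.

```python
# Which module indices get wrapped with checkpoint() for a 5-module extractor.
def use_checkpoint(i, n):
    first_or_last = i == 0 or i == max(n - 1, 0)
    return not first_or_last

print([use_checkpoint(i, 5) for i in range(5)])
# [False, True, True, True, False]
```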
class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureListNet, self).__init__(
            model,
            out_indices=out_indices,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            out_as_dict: bool = False,
            no_rewrite: bool = False,
            flatten_sequential: bool = False,
            default_hook_type: str = 'forward',
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            out_as_dict: Output features as a dict.
            no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
                flatten_sequential arg must also be False if this is set True.
            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
            default_hook_type: The default hook type to use if not specified in model.feature_info.
        """
        super(FeatureHookNet, self).__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.out_as_dict = out_as_dict
        self.grad_checkpointing = False

        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {
                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                for f in self.feature_info.get_dicts()
            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.out_as_dict else list(out.values())
@ -1,110 +0,0 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type

import torch
from torch import nn

from ._features import _get_feature_info

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame

# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
}

try:
    from timm.layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


__all__ = ['register_notrace_module', 'register_notrace_function', 'create_feature_extractor',
           'FeatureGraphNet', 'GraphExtractNet']


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = {
            info['module']: out_map[i] if out_map is not None else info['module']
            for i, info in enumerate(self.feature_info) if i in out_indices}
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor
    NOTE:
      * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
        metadata for builtin feature extraction mode
      * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out
@ -1,126 +0,0 @@
""" Model creation / weight loading / state_dict helpers

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
import os
from collections import OrderedDict

import torch
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

import timm.models._builder

_logger = logging.getLogger(__name__)

__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']


def clean_state_dict(state_dict):
    # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
    cleaned_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        cleaned_state_dict[name] = v
    return cleaned_state_dict


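The `module.` prefix stripped by `clean_state_dict` is what `nn.DataParallel` / DDP wrappers prepend to every parameter key. A stand-alone demo with strings standing in for tensors:

```python
# Mirror of clean_state_dict's key handling; 'module.' is 7 characters long.
def strip_module_prefix(state_dict):
    return {k[7:] if k.startswith('module.') else k: v for k, v in state_dict.items()}

sd = {'module.conv1.weight': 'w0', 'fc.bias': 'b0'}
print(strip_module_prefix(sd))  # {'conv1.weight': 'w0', 'fc.bias': 'b0'}
```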
def load_state_dict(checkpoint_path, use_ema=True):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # Check if safetensors or not and load weights accordingly
        if str(checkpoint_path).endswith(".safetensors"):
            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
        else:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

        state_dict_key = ''
        if isinstance(checkpoint, dict):
            if use_ema and checkpoint.get('state_dict_ema', None) is not None:
                state_dict_key = 'state_dict_ema'
            elif use_ema and checkpoint.get('model_ema', None) is not None:
                state_dict_key = 'model_ema'
            elif 'state_dict' in checkpoint:
                state_dict_key = 'state_dict'
            elif 'model' in checkpoint:
                state_dict_key = 'model'
        state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
        _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return
    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def remap_checkpoint(model, state_dict, allow_reshape=True):
    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
    This assumes models (and originating state dict) were created with params registered in same order.
    """
    out_dict = {}
    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        if va.shape != vb.shape:
            if allow_reshape:
                vb = vb.reshape(va.shape)
            else:
                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
        out_dict[ka] = vb
    return out_dict


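A pure-Python illustration of the positional matching performed by `remap_checkpoint`: values pair up by registration order and the checkpoint's own keys are discarded (tuples stand in for tensor shapes here):

```python
# Keys come from the model's state dict, values from the checkpoint, paired
# purely by iteration order (dicts preserve insertion order in Python 3.7+).
model_sd = {'stem.weight': None, 'head.weight': None}
ckpt_sd = {'conv.w': (3, 3), 'fc.w': (10,)}

remapped = {ka: vb for (ka, _), (_, vb) in zip(model_sd.items(), ckpt_sd.items())}
print(remapped)  # {'stem.weight': (3, 3), 'head.weight': (10,)}
```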
def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            if log_info:
                _logger.info('Restoring model state from checkpoint...')
            state_dict = clean_state_dict(checkpoint['state_dict'])
            model.load_state_dict(state_dict)

            if optimizer is not None and 'optimizer' in checkpoint:
                if log_info:
                    _logger.info('Restoring optimizer state from checkpoint...')
                optimizer.load_state_dict(checkpoint['optimizer'])

            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
                if log_info:
                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

            if 'epoch' in checkpoint:
                resume_epoch = checkpoint['epoch']
                if 'version' in checkpoint and checkpoint['version'] > 1:
                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

            if log_info:
                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
        else:
            model.load_state_dict(checkpoint)
            if log_info:
                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return resume_epoch
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()
@ -1,370 +0,0 @@
import hashlib
import json
import logging
import os
import sys
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse

try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

if sys.version_info >= (3, 8):
    from typing import Literal
else:
    from typing_extensions import Literal

from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg

try:
    from huggingface_hub import (
        create_repo, get_hf_file_metadata,
        hf_hub_download, hf_hub_url,
        repo_type_and_id_from_hf_id, upload_folder)
    from huggingface_hub.utils import EntryNotFoundError
    hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False

_logger = logging.getLogger(__name__)

__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version


def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    child_dir = () if not child_dir else (child_dir,)
    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def download_cached_file(url, check_hash=True, progress=False):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file


def check_cached_file(url, check_hash=True):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
            if hash_prefix:
                with open(cached_file, 'rb') as f:
                    hd = hashlib.sha256(f.read()).hexdigest()
                if hd[:len(hash_prefix)] != hash_prefix:
                    return False
        return True
    return False


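The hash check above in miniature: a filename embeds a short sha256 prefix (matched by `HASH_REGEX`), which is compared against the digest of the actual file contents.

```python
import hashlib

# b'hello' stands in for the cached file's bytes; the prefix below is the
# well-known leading hex of sha256(b'hello').
hash_prefix = '2cf24dba'
hd = hashlib.sha256(b'hello').hexdigest()
print(hd[:len(hash_prefix)] == hash_prefix)  # True
```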
def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def hf_split(hf_id: str):
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    rev_split = hf_id.split('@')
    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id = rev_split[0]
    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
    return hf_model_id, hf_revision


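The `@` revision split performed by `hf_split` can be demonstrated in two lines (sketch, not timm's function):

```python
# 'repo@revision' -> ('repo', 'revision'); no '@' yields revision None.
def split_revision(hf_id):
    model_id, _, revision = hf_id.partition('@')
    return model_id, revision or None

print(split_revision('timm/resnet50.a1_in1k@main'))  # ('timm/resnet50.a1_in1k', 'main')
print(split_revision('timm/resnet50.a1_in1k'))       # ('timm/resnet50.a1_in1k', None)
```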
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)


def download_from_hf(model_id: str, filename: str):
    hf_model_id, hf_revision = hf_split(model_id)
    return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
    assert has_hf_hub(True)
    cached_file = download_from_hf(model_id, 'config.json')

    hf_config = load_cfg_from_json(cached_file)
    if 'pretrained_cfg' not in hf_config:
        # old form, pull pretrain_cfg out of the base dict
        pretrained_cfg = hf_config
        hf_config = {}
        hf_config['architecture'] = pretrained_cfg.pop('architecture')
        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # model should be created with base config num_classes if it exists
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # label meta-data in base config overrides saved pretrained_cfg on load
    if 'label_names' in hf_config:
        pretrained_cfg['label_names'] = hf_config.pop('label_names')
    if 'label_descriptions' in hf_config:
        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

    model_name = hf_config['architecture']
    return pretrained_cfg, model_name


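The old-form config upgrade above can be exercised with plain dicts: a flat legacy config is lifted into the `{'architecture', 'num_features', 'pretrained_cfg'}` shape, with the deprecated `labels` key renamed along the way (a minimal sketch, not the function itself):

```python
# Lift a flat legacy hub config into the nested pretrained_cfg form.
def upgrade_config(cfg):
    if 'pretrained_cfg' in cfg:
        return cfg  # already new form
    pretrained_cfg = dict(cfg)
    out = {
        'architecture': pretrained_cfg.pop('architecture'),
        'num_features': pretrained_cfg.pop('num_features', None),
    }
    if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
        pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
    out['pretrained_cfg'] = pretrained_cfg
    return out

old = {'architecture': 'resnet50', 'num_features': 2048, 'labels': ['cat', 'dog']}
print(upgrade_config(old))
```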
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
                _logger.info(
                    f"[{model_id}] Safe alternative available for '{filename}' "
                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                return safetensors.torch.load_file(cached_safe_file, device="cpu")
            except EntryNotFoundError:
                pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')


def save_config_for_hf(
        model,
        config_path: Path,
        model_config: Optional[dict] = None
):
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
    # set some values at root config level
    hf_config['architecture'] = pretrained_cfg.pop('architecture')
    hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
    hf_config['num_features'] = model_config.get('num_features', model.num_features)
    global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
    if isinstance(global_pool_type, str) and global_pool_type:
        hf_config['global_pool'] = global_pool_type

    if 'labels' in model_config:
        _logger.warning(
            "'labels' as a config field is deprecated. Please use 'label_names' and 'label_descriptions'."
            " Renaming provided 'labels' field to 'label_names'.")
        model_config.setdefault('label_names', model_config.pop('labels'))

    label_names = model_config.pop('label_names', None)
    if label_names:
        assert isinstance(label_names, (dict, list, tuple))
        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
        hf_config['label_names'] = label_names

    label_descriptions = model_config.pop('label_descriptions', None)
    if label_descriptions:
        assert isinstance(label_descriptions, dict)
        # maps label names -> descriptions
        hf_config['label_descriptions'] = label_descriptions

    hf_config['pretrained_cfg'] = pretrained_cfg
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)

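The resulting `config.json` is an ordinary JSON document with a handful of root-level keys plus the nested `pretrained_cfg`. A sketch of the layout, built with stdlib `json` only; the concrete values are illustrative:

```python
# Sketch of the config.json layout produced by save_config_for_hf above;
# values are illustrative. Note JSON serializes tuples as lists.
import json

hf_config = {
    'architecture': 'resnet50',           # popped out of pretrained_cfg to the root
    'num_classes': 1000,
    'num_features': 2048,
    'global_pool': 'avg',
    'label_names': ['n01440764', '...'],  # label id -> unique label name
    'pretrained_cfg': {'input_size': (3, 224, 224), 'interpolation': 'bicubic'},
}
config_text = json.dumps(hf_config, indent=2)
```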
def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), using the legacy pytorch approach, or both.
    tensors = model.state_dict()
    if safe_serialization is True or safe_serialization == "both":
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if safe_serialization is False or safe_serialization == "both":
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    config_path = save_directory / 'config.json'
    save_config_for_hf(model, config_path, model_config=model_config)

def push_to_hf_hub(
        model,
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_config: Optional[dict] = None,
        model_card: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """
    Arguments:
        (...)
        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            Can be set to `"both"` in order to push both safe and unsafe weights.
    """
    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if a README file already exists in the repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )

def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- image-classification\n- timm\n"
    readme_text += "library_tag: timm\n"
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
        if 'Pretrain Dataset' in model_card['details']:
            readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"
    if 'description' in model_card:
        readme_text += f"\n{model_card['description']}\n"
    if 'details' in model_card:
        readme_text += "\n## Model Details\n"
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                readme_text += f"- **{k}:**\n"
                for vi in v:
                    readme_text += f"  - {vi}\n"
            elif isinstance(v, dict):
                readme_text += f"- **{k}:**\n"
                for ki, vi in v.items():
                    readme_text += f"  - {ki}: {vi}\n"
            else:
                readme_text += f"- **{k}:** {v}\n"
    if 'usage' in model_card:
        readme_text += "\n## Model Usage\n"
        readme_text += model_card['usage']
        readme_text += '\n'

    if 'comparison' in model_card:
        readme_text += "\n## Model Comparison\n"
        readme_text += model_card['comparison']
        readme_text += '\n'

    if 'citation' in model_card:
        readme_text += "\n## Citation\n"
        if not isinstance(model_card['citation'], (list, tuple)):
            citations = [model_card['citation']]
        else:
            citations = model_card['citation']
        for c in citations:
            readme_text += f"```bibtex\n{c}\n```\n"
    return readme_text

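The README starts with a YAML frontmatter block that the hub parses for tags, license, and datasets. A condensed, self-contained sketch of just that frontmatter portion of `generate_readme`; the model card contents are illustrative:

```python
# Condensed sketch of the YAML frontmatter emitted by generate_readme above.
def frontmatter(model_card: dict) -> str:
    text = "---\n"
    text += "tags:\n- image-classification\n- timm\n"
    text += "library_tag: timm\n"
    text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    details = model_card.get('details', {})
    if 'Dataset' in details:
        text += 'datasets:\n'
        text += f"- {details['Dataset'].lower()}\n"
    text += "---\n"
    return text

card = {'license': 'mit', 'details': {'Dataset': 'ImageNet-1k'}}
header = frontmatter(card)
```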
def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Returns potential safetensors alternatives for a given filename.

    Use case:
        When downloading a model from the Huggingface Hub, we first check if a .safetensors file exists and if so, use it.
        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
        # NOTE must `yield` here; a `return` with a value inside a generator
        # ends iteration without ever producing the alternative filename.
        yield filename[:-4] + ".safetensors"

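The `yield` fix above matters: a generator that `return`s a value simply stops without emitting it, so the `.bin` alternative was silently never produced. A standalone sketch of the same logic; the two weight-name constants mirror timm's defaults but are assumptions here:

```python
# Standalone sketch of _get_safe_alternatives; constants are assumed defaults.
from typing import Iterable

HF_WEIGHTS_NAME = "pytorch_model.bin"
HF_SAFE_WEIGHTS_NAME = "model.safetensors"

def get_safe_alternatives(filename: str) -> Iterable[str]:
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
        # `yield`, not `return`: returning a value from a generator ends
        # iteration without producing anything.
        yield filename[:-4] + ".safetensors"
```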
import collections.abc
import math
import re
from collections import defaultdict
from itertools import chain
from typing import Callable, Union, Dict

import torch
from torch import nn as nn
from torch.utils.checkpoint import checkpoint

__all__ = ['model_parameters', 'named_apply', 'named_modules', 'named_modules_with_params', 'adapt_input_conv',
           'group_with_matcher', 'group_modules', 'group_parameters', 'flatten_modules', 'checkpoint_seq']


def model_parameters(model, exclude_head=False):
    if exclude_head:
        # FIXME this is a bit of a quick and dirty hack to skip classifier head params based on ordering
        return [p for p in model.parameters()][:-2]
    else:
        return model.parameters()

def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
    if not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        yield name, module


def named_modules_with_params(module: nn.Module, name='', depth_first=True, include_root=False):
    if module._parameters and not depth_first and include_root:
        yield name, module
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_modules_with_params(
            module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if module._parameters and depth_first and include_root:
        yield name, module

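The three helpers above share one recursive pattern: `depth_first` picks post-order vs pre-order yield of the current node, and `include_root` is forced to `True` for children so only the outermost call can skip itself. A torch-free sketch with a minimal stand-in node class, so the traversal order can be seen in isolation:

```python
# Torch-free sketch of the named_modules traversal pattern above.
class Node:
    def __init__(self, **children):
        self._children = children

    def named_children(self):
        return self._children.items()

def named_nodes(node, name='', depth_first=True, include_root=False):
    if not depth_first and include_root:
        yield name, node  # pre-order
    for child_name, child in node.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        yield from named_nodes(child, child_name, depth_first, include_root=True)
    if depth_first and include_root:
        yield name, node  # post-order

tree = Node(stem=Node(), blocks=Node(attn=Node(), mlp=Node()))
post = [n for n, _ in named_nodes(tree, include_root=True)]
pre = [n for n, _ in named_nodes(tree, depth_first=False, include_root=True)]
```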
MATCH_PREV_GROUP = (99999,)


def group_with_matcher(
        named_objects,
        group_matcher: Union[Dict, Callable],
        output_values: bool = False,
        reverse: bool = False
):
    if isinstance(group_matcher, dict):
        # dictionary matcher contains a dict of raw-string regex expr that must be compiled
        compiled = []
        for group_ordinal, (group_name, mspec) in enumerate(group_matcher.items()):
            if mspec is None:
                continue
            # map all matching specifications into 3-tuple (compiled re, prefix, suffix)
            if isinstance(mspec, (tuple, list)):
                # multi-entry match specifications require each sub-spec to be a 2-tuple (re, suffix)
                for sspec in mspec:
                    compiled += [(re.compile(sspec[0]), (group_ordinal,), sspec[1])]
            else:
                compiled += [(re.compile(mspec), (group_ordinal,), None)]
        group_matcher = compiled

    def _get_grouping(name):
        if isinstance(group_matcher, (list, tuple)):
            for match_fn, prefix, suffix in group_matcher:
                r = match_fn.match(name)
                if r:
                    parts = (prefix, r.groups(), suffix)
                    # map all tuple elems to float for numeric sort, filter out None entries
                    return tuple(map(float, chain.from_iterable(filter(None, parts))))
            return float('inf'),  # un-matched layers (neck, head) mapped to largest ordinal
        else:
            ord = group_matcher(name)
            if not isinstance(ord, collections.abc.Iterable):
                return ord,
            return tuple(ord)

    # map layers into groups via ordinals (ints or tuples of ints) from matcher
    grouping = defaultdict(list)
    for k, v in named_objects:
        grouping[_get_grouping(k)].append(v if output_values else k)

    # remap to integers
    layer_id_to_param = defaultdict(list)
    lid = -1
    for k in sorted(filter(lambda x: x is not None, grouping.keys())):
        if lid < 0 or k[-1] != MATCH_PREV_GROUP[0]:
            lid += 1
        layer_id_to_param[lid].extend(grouping[k])

    if reverse:
        assert not output_values, "reverse mapping only sensible for name output"
        # output reverse mapping
        param_to_layer_id = {}
        for lid, lm in layer_id_to_param.items():
            for n in lm:
                param_to_layer_id[n] = lid
        return param_to_layer_id

    return layer_id_to_param

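The core of the dict matcher above is pure stdlib: each regex gets an ordinal, numeric capture groups extend that ordinal for finer sorting, and unmatched names (e.g. the head) fall into a final group. A condensed self-contained sketch; the names and patterns are illustrative:

```python
# Condensed sketch of the dict-matcher grouping in group_with_matcher above.
import re
from collections import defaultdict
from itertools import chain

matcher = {'stem': r'^stem', 'blocks': r'^blocks\.(\d+)'}
compiled = [(re.compile(p), (i,)) for i, p in enumerate(matcher.values())]

def ordinal(name):
    for pat, prefix in compiled:
        m = pat.match(name)
        if m:
            # numeric capture groups extend the ordinal for a finer sort
            return tuple(map(float, chain(prefix, filter(None, m.groups()))))
    return (float('inf'),)  # unmatched (e.g. head) sorts last

names = ['stem.conv', 'blocks.0.attn', 'blocks.1.mlp', 'head.fc']
grouping = defaultdict(list)
for n in names:
    grouping[ordinal(n)].append(n)
layer_ids = {k: i for i, k in enumerate(sorted(grouping))}
```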
def group_parameters(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    return group_with_matcher(
        module.named_parameters(), group_matcher, output_values=output_values, reverse=reverse)


def group_modules(
        module: nn.Module,
        group_matcher,
        output_values=False,
        reverse=False,
):
    return group_with_matcher(
        named_modules_with_params(module), group_matcher, output_values=output_values, reverse=reverse)

def flatten_modules(named_modules, depth=1, prefix='', module_types='sequential'):
    prefix_is_tuple = isinstance(prefix, tuple)
    if isinstance(module_types, str):
        if module_types == 'container':
            module_types = (nn.Sequential, nn.ModuleList, nn.ModuleDict)
        else:
            module_types = (nn.Sequential,)
    for name, module in named_modules:
        if depth and isinstance(module, module_types):
            yield from flatten_modules(
                module.named_children(),
                depth - 1,
                prefix=(name,) if prefix_is_tuple else name,
                module_types=module_types,
            )
        else:
            if prefix_is_tuple:
                name = prefix + (name,)
                yield name, module
            else:
                if prefix:
                    name = '.'.join([prefix, name])
                yield name, module

def checkpoint_seq(
        functions,
        x,
        every=1,
        flatten=False,
        skip_last=False,
        preserve_rng_state=True
):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a sequence into segments
    and checkpoint each segment. All segments except the last run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate activations. The inputs of each
    checkpointed segment will be saved for re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or functions to run sequentially.
        x: A Tensor that is input to :attr:`functions`
        every: checkpoint every-n functions (default: 1)
        flatten (bool): flatten nn.Sequential of nn.Sequentials
        skip_last (bool): skip checkpointing the last function in the sequence if True
        preserve_rng_state (bool, optional, default=True): Stash and restore
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_seq(model, input_var, every=2)
    """
    def run_function(start, end, functions):
        def forward(_x):
            for j in range(start, end + 1):
                _x = functions[j](_x)
            return _x
        return forward

    if isinstance(functions, torch.nn.Sequential):
        functions = functions.children()
    if flatten:
        functions = chain.from_iterable(functions)
    if not isinstance(functions, (tuple, list)):
        functions = tuple(functions)

    num_checkpointed = len(functions)
    if skip_last:
        num_checkpointed -= 1
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        x = checkpoint(run_function(start, end, functions), x, preserve_rng_state=preserve_rng_state)
    if skip_last:
        return run_function(end + 1, len(functions) - 1, functions)(x)
    return x

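The segment arithmetic in `checkpoint_seq` above is independent of torch: functions are sliced into inclusive `(start, end)` ranges of `every` functions, with an optional un-checkpointed remainder when `skip_last` is set. A pure-Python sketch of just that slicing:

```python
# Pure-Python sketch of how checkpoint_seq slices a function sequence.
def segment_ranges(num_functions, every=1, skip_last=False):
    num_checkpointed = num_functions - 1 if skip_last else num_functions
    segments = []
    end = -1
    for start in range(0, num_checkpointed, every):
        end = min(start + every - 1, num_checkpointed - 1)
        segments.append((start, end))  # inclusive checkpointed segment
    # with skip_last, the tail runs outside checkpointing
    remainder = (end + 1, num_functions - 1) if skip_last else None
    return segments, remainder

segs, rem = segment_ranges(5, every=2)
segs_skip, rem_skip = segment_ranges(5, every=2, skip_last=True)
```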
def adapt_input_conv(in_chans, conv_weight):
    conv_type = conv_weight.dtype
    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
    O, I, J, K = conv_weight.shape
    if in_chans == 1:
        if I > 3:
            assert conv_weight.shape[1] % 3 == 0
            # For models with space2depth stems
            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
            conv_weight = conv_weight.sum(dim=2, keepdim=False)
        else:
            conv_weight = conv_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        if I != 3:
            raise NotImplementedError('Weight format not supported by conversion.')
        else:
            # NOTE this strategy should be better than random init, but there could be other combinations of
            # the original RGB input layer weights that'd work better for specific cases.
            repeat = int(math.ceil(in_chans / 3))
            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
            conv_weight *= (3 / float(in_chans))
    conv_weight = conv_weight.to(conv_type)
    return conv_weight

import os
import pkgutil
from copy import deepcopy

from torch import nn as nn

from timm.layers import Conv2dSame, BatchNormAct2d, Linear

__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']

def extract_layer(model, layer):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    if not hasattr(model, 'module') and layer[0] == 'module':
        layer = layer[1:]
    for l in layer:
        if hasattr(module, l):
            if not l.isdigit():
                module = getattr(module, l)
            else:
                module = module[int(l)]
        else:
            return module
    return module


def set_layer(model, layer, val):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    lst_index = 0
    module2 = module
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    lst_index -= 1
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    setattr(module, l, val)

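Both helpers above walk a dotted path where digit segments index into containers and the rest are attribute lookups. A torch-free sketch of that traversal over plain namespaces and lists; the model layout is illustrative:

```python
# Torch-free sketch of the dotted-path traversal used by extract_layer /
# set_layer above; digit parts index, others are attribute lookups.
from types import SimpleNamespace

def extract(root, path):
    obj = root
    for part in path.split('.'):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj

model = SimpleNamespace(
    stem=SimpleNamespace(conv='conv3x3'),
    blocks=[SimpleNamespace(attn='mhsa0'), SimpleNamespace(attn='mhsa1')],
)
```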
def adapt_model_from_string(parent_module, model_string):
    separator = '***'
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            g = 1
            if old_module.groups > 1:
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        elif isinstance(old_module, BatchNormAct2d):
            new_bn = BatchNormAct2d(
                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            new_bn.drop = old_module.drop
            new_bn.act = old_module.act
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module

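The pruned-model description parsed above is a flat string of `key:[shape]` entries joined by `***`; empty shapes are skipped. A standalone sketch of just the parsing step, with illustrative keys and shapes:

```python
# Standalone sketch of the pruned-model string format parsed above.
def parse_model_string(model_string):
    state_dict = {}
    for entry in model_string.split('***'):
        key, shape = entry.split(':')
        dims = shape[1:-1].split(',')  # strip the surrounding brackets
        if dims[0] != '':
            state_dict[key] = [int(d) for d in dims]
    return state_dict

spec = 'conv1.weight:[8, 3, 3, 3]***bn1.weight:[8]***head.bias:[]'
shapes = parse_model_string(spec)
```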
def adapt_model_from_file(parent_module, model_variant):
    adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())

""" Model Registry
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
import sys
|
||||
from collections import defaultdict, deque
|
||||
from copy import deepcopy
|
||||
from dataclasses import replace
|
||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
|
||||
|
||||
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
|
||||
|
||||
__all__ = [
|
||||
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
|
||||
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
|
||||
|
||||
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
|
||||
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
|
||||
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
|
||||
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
|
||||
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
|
||||
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
|
||||
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
|
||||
|
||||
|
||||
def get_arch_name(model_name: str) -> str:
|
||||
return split_model_name_tag(model_name)[0]
|
||||
|
||||
|
||||
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
    module_name = module_name_split[-1] if len(module_name_split) else ''

    # add model to __all__ in module
    model_name = fn.__name__
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
        mod.__all__ = [model_name]  # type: ignore

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
    _model_to_module[model_name] = module_name
    _module_to_models[module_name].add(model_name)
    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
        # entrypoints or non-matching combos
        default_cfg = mod.default_cfgs[model_name]
        if not isinstance(default_cfg, DefaultCfg):
            # old style cfg dict per model-arch, convert to new style DefaultCfg
            # dataclass w/ multiple pretrained tag entries per model-arch
            assert isinstance(default_cfg, dict)
            pretrained_cfg = PretrainedCfg(**default_cfg)
            default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})

        for tag_idx, tag in enumerate(default_cfg.tags):
            is_default = tag_idx == 0
            pretrained_cfg = default_cfg.cfgs[tag]
            model_name_tag = '.'.join([model_name, tag]) if tag else model_name
            replace_items = dict(architecture=model_name, tag=tag if tag else None)
            if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
                # auto-complete hub name w/ architecture.tag
                replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
            pretrained_cfg = replace(pretrained_cfg, **replace_items)

            if is_default:
                _model_pretrained_cfgs[model_name] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add tagless entry if it's default and has weights
                    _model_has_pretrained.add(model_name)

            if tag:
                _model_pretrained_cfgs[model_name_tag] = pretrained_cfg
                if pretrained_cfg.has_weights:
                    # add model w/ tag if tag is valid
                    _model_has_pretrained.add(model_name_tag)
                _model_with_tags[model_name].append(model_name_tag)
            else:
                _model_with_tags[model_name].append(model_name)  # has empty tag (to slowly remove these instances)

        _model_default_cfgs[model_name] = default_cfg

    return fn

def _natural_key(string_: str) -> List[Union[int, str]]:
    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]

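`_natural_key` above splits a name on digit runs so numbers compare numerically rather than lexicographically. A standalone demo over illustrative model names:

```python
# Standalone demo of natural-sort keys as used by _natural_key above.
import re

def natural_key(s):
    return [int(p) if p.isdigit() else p for p in re.split(r'(\d+)', s.lower())]

names = ['resnet10', 'resnet2', 'resnet101']
ordered = sorted(names, key=natural_key)
```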
def list_models(
        filter: Union[str, List[str]] = '',
        module: str = '',
        pretrained: bool = False,
        exclude_filters: Union[str, List[str]] = '',
        name_matches_cfg: bool = False,
        include_tags: Optional[bool] = None,
) -> List[str]:
    """ Return list of available model names, sorted alphabetically

    Args:
        filter - Wildcard filter string that works with fnmatch
        module - Limit model selection to a specific submodule (ie 'vision_transformer')
        pretrained - Include only models with valid pretrained weights if True
        exclude_filters - Wildcard filters to exclude models after including them with filter
        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
            set to True when pretrained=True else False (default: None)

    Returns:
        models - The sorted list of models

    Example:
        model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        model_list('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
    """
    if include_tags is None:
        # FIXME should this be default behaviour? or default to include_tags=True?
        include_tags = pretrained

    if module:
        all_models: Iterable[str] = list(_module_to_models[module])
    else:
        all_models = _model_entrypoints.keys()

    if include_tags:
        # expand model names to include names w/ pretrained tags
        models_with_tags = []
        for m in all_models:
            models_with_tags.extend(_model_with_tags[m])
        all_models = models_with_tags

    if filter:
        models: Set[str] = set()
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
                models = models.union(include_models)
    else:
        models = set(all_models)

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
            exclude_filters = [exclude_filters]
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if len(exclude_models):
                models = models.difference(exclude_models)

    if pretrained:
        models = _model_has_pretrained.intersection(models)

    if name_matches_cfg:
        models = set(_model_pretrained_cfgs).intersection(models)

    return sorted(models, key=_natural_key)

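The include/exclude filtering in `list_models` above is plain `fnmatch` set algebra: union the names matching the include patterns, then subtract those matching the exclude patterns. A standalone demo over illustrative model names:

```python
# Standalone demo of the fnmatch include/exclude filtering in list_models.
import fnmatch

all_models = ['resnet18', 'resnet50', 'resnext50_32x4d', 'vit_base_patch16']
included = set(fnmatch.filter(all_models, 'res*'))
excluded = set(fnmatch.filter(included, '*resnext*'))
models = sorted(included - excluded)
```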
def list_pretrained(
        filter: Union[str, List[str]] = '',
        exclude_filters: str = '',
) -> List[str]:
    return list_models(
        filter=filter,
        pretrained=True,
        exclude_filters=exclude_filters,
        include_tags=True,
    )

def is_model(model_name: str) -> bool:
    """ Check if a model name exists
    """
    arch_name = get_arch_name(model_name)
    return arch_name in _model_entrypoints


def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch a model entrypoint for specified model name
    """
    arch_name = get_arch_name(model_name)
    if module_filter and arch_name not in _module_to_models.get(module_filter, {}):
        raise RuntimeError(f'Model ({model_name}) not found in module {module_filter}.')
    return _model_entrypoints[arch_name]

def list_modules() -> List[str]:
    """ Return list of module names that contain models / model entrypoints
    """
    modules = _module_to_models.keys()
    return sorted(modules)


def is_model_in_modules(
        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
    """Check if a model exists within a subset of modules

    Args:
        model_name - name of model to check
        module_names - names of modules to search in
    """
    arch_name = get_arch_name(model_name)
    assert isinstance(module_names, (tuple, list, set))
    return any(arch_name in _module_to_models[n] for n in module_names)


def is_model_pretrained(model_name: str) -> bool:
    return model_name in _model_has_pretrained

def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
    if model_name in _model_pretrained_cfgs:
        return deepcopy(_model_pretrained_cfgs[model_name])
    arch_name, tag = split_model_name_tag(model_name)
    if arch_name in _model_default_cfgs:
        # if model arch exists, but the tag is wrong, error out
        raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
    if allow_unregistered:
        # if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
        return None
    raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')


def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
    """ Get a specific model default_cfg value by key. None if key doesn't exist.
    """
    cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
    return getattr(cfg, cfg_key, None)