Compare commits


28 Commits

Author SHA1 Message Date
Benjamin Bossan  a5b01ec04e  Add type annotations to _registry.py  2 years ago
Benjamin Bossan  c9406ce608  Some additions to the CONTRIBUTING guide (#1685)  2 years ago
Ross Wightman  a32c4eff69  Create CONTRIBUTING.md  2 years ago
Ross Wightman  a0772f03e0  Update README.md  2 years ago
Ross Wightman  47f1de9bec  Version bump  2 years ago
Ross Wightman  11f7b589e5  Update setup.py for huggingface changes.  2 years ago
Ross Wightman  4d9c3ae2fb  Add laion2b 320x320 ConvNeXt-Large CLIP weights  2 years ago
Ross Wightman  d0b45c9b4d  Make safetensor import option for now. Improve avg/clean checkpoints ext handling a bit (more consistent).  2 years ago
Ross Wightman  7d9e321b76  Improve tracing of window attn models with simpler reshape logic  2 years ago
Ross Wightman  a3c6685e20  Delete requirements-modelindex.txt  2 years ago
Ross Wightman  022403ce0a  Update README  2 years ago
Ross Wightman  2e38d53dca  Remove dead line  2 years ago
Ross Wightman  f77c04ff36  Torchscript fixes/hacks for rms_norm, refactor ParallelScalingBlock with manual combination of input projections, closer paper match  2 years ago
Ross Wightman  122621daef  Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit  2 years ago
Ross Wightman  621e1b2182  Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet.  2 years ago
Ross Wightman  a3d528524a  Version 0.8.12dev0  2 years ago
testbot  a09d403c24  changed warning to info  2 years ago
testbot  8470e29541  Add support to load safetensors weights  2 years ago
Ross Wightman  f35d6ea57b  Add multi-tensor (foreach) version of Lion in style of upcoming PyTorch 2.0 optimizers  2 years ago
Ross Wightman  709d5e0d9d  Add Lion optimizer  2 years ago
Ross Wightman  624266148d  Remove unused imports from _hub helpers  2 years ago
Ross Wightman  2cfff0581b  Add grad_checkpointing support to features_only, test in EfficientDet.  2 years ago
Ross Wightman  45af496197  Version 0.8.11dev0  2 years ago
Ross Wightman  9c14654a0d  Improve support for custom dataset label name/description through HF hub export, via pretrained_cfg  2 years ago
Ross Wightman  1e0b347227  Fix README  2 years ago
Ross Wightman  497be8343c  Update README and version  2 years ago
Ross Wightman  0d33127df2  Add 384x384 convnext_large_mlp laion2b fine-tune on in1k  2 years ago
Ross Wightman  88a5b8491d  Merge pull request #1662 from rwightman/dataset_info  2 years ago

@ -0,0 +1,112 @@
*This guideline is very much a work-in-progress.*
Contributions to `timm` for code, documentation, and tests are more than welcome!
There haven't been any formal guidelines to date so please bear with me, and feel free to add to this guide.
# Coding style
Code linting and auto-formatting (black) are not currently in place, but are open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.
A few specific differences from Google style (or black):
1. Line length is 120 characters. Going over is okay in some cases (e.g. I prefer not to break URLs across lines).
2. Hanging indents are always preferred; please avoid aligning arguments with closing brackets or braces.
Example from the Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
                         var_three, var_four)
meal = (spam,
        beans)

# Aligned with opening delimiter in a dictionary.
foo = {
    'long_dictionary_key': value1 +
                           value2,
    ...
}
```
This is YES:
```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
    var_one, var_two, var_three,
    var_four
)
meal = (
    spam,
    beans,
)

# 4-space hanging indent in a dictionary.
foo = {
    'long_dictionary_key':
        long_dictionary_value,
    ...
}
```
When there is a discrepancy in a given source file (there are many origins for various bits of code, and not all have been updated to what I consider the current goal), please follow the style in that file.
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
```
black --skip-string-normalization --line-length 120 <path-to-file>
```
Avoid formatting code that is unrelated to your PR, though.
PRs with pure formatting / style fixes will be accepted, but only in isolation from functional changes; it is best to ask before starting such a change.
# Documentation
As with code style, docstring style is based on the Google guide: https://google.github.io/styleguide/pyguide.html
The goal for the code is to eventually move to have all major functions and `__init__` methods use PEP484 type annotations.
When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings; please leave the annotations as the single source of truth for typing.
There are a LOT of gaps in the current documentation relative to the functionality in `timm`, so please, document away!
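For illustration, a function in this style might look like the following (a hypothetical helper, not part of `timm`; types live only in the signature and the docstring describes meaning):
```
from typing import Optional

import torch


def scale_features(x: torch.Tensor, gamma: Optional[float] = None) -> torch.Tensor:
    """Scale a feature tensor by an optional gain factor.

    Args:
        x: Input feature tensor.
        gamma: Optional multiplicative gain. If None, the input is returned unchanged.

    Returns:
        The (optionally) scaled tensor.
    """
    # No type info repeated in the docstring; the signature is the source of truth.
    return x if gamma is None else x * gamma
```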
# Installation
Create a Python virtual environment using Python 3.10. Inside the environment, install the following test dependencies:
```
python -m pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
```
Install `torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests
Run the tests using:
```
pytest tests/
```
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
# Questions
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

@ -24,6 +24,35 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗ * ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch. * Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Feb 20, 2023
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_large_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
### Feb 16, 2023
* `safetensors` checkpoint support added
* Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
* Add F.scaled_dot_product_attention support (PyTorch 2.0 only) to `vit_*`, `vit_relpos*`, `coatnet` / `maxxvit` (to start)
* Lion optimizer (w/ multi-tensor option) added (https://arxiv.org/abs/2302.06675)
* gradient checkpointing works with `features_only=True` (see the usage sketch just below this list)
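A minimal sketch of exercising a couple of these additions from Python (model and hyper-parameter choices are illustrative; assumes the new Lion optimizer is registered with the optimizer factory under the name `'lion'`):
```
import timm
from timm.optim import create_optimizer_v2

# Feature-extraction backbone; gradient checkpointing now works in this mode.
model = timm.create_model('convnext_small', features_only=True)
model.set_grad_checkpointing(True)  # checkpoints all but the first/last feature stages

# New Lion optimizer (assumed to be registered as 'lion' in the factory).
optimizer = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.01)
```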
### Feb 7, 2023
* New inference benchmark numbers added in [results](results/) folder.
* Add convnext LAION CLIP trained weights and initial set of in1k fine-tunes
* `convnext_base.clip_laion2b_augreg_ft_in1k` - 86.2% @ 256x256
* `convnext_base.clip_laiona_augreg_ft_in1k_384` - 86.5% @ 384x384
* `convnext_large_mlp.clip_laion2b_augreg_ft_in1k` - 87.3% @ 256x256
* `convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384` - 87.9% @ 384x384
* Add DaViT models. Supports `features_only=True`. Adapted from https://github.com/dingmyu/davit by [Fredo](https://github.com/fffffgggg54).
* Use a common NormMlpClassifierHead across MaxViT, ConvNeXt, DaViT
* Add EfficientFormer-V2 model, update EfficientFormer, and refactor LeViT (closely related architectures). Weights on HF hub.
* New EfficientFormer-V2 arch, significant refactor from original at (https://github.com/snap-research/EfficientFormer). Supports `features_only=True`.
* Minor updates to EfficientFormer.
* Refactor LeViT models to stages, add `features_only=True` support to new `conv` variants, weight remap required.
* Move ImageNet meta-data (synsets, indices) from `/results` to [`timm/data/_info`](timm/data/_info/).
* Add ImageNetInfo / DatasetInfo classes to provide labelling for various ImageNet classifier layouts in `timm`
* Update `inference.py` to use, try: `python inference.py /folder/to/images --model convnext_small.in12k --label-type detail --topk 5`
* Ready for 0.8.10 pypi pre-release (final testing).
### Jan 20, 2023 ### Jan 20, 2023
* Add two convnext 12k -> 1k fine-tunes at 384x384 * Add two convnext 12k -> 1k fine-tunes at 384x384
* `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384 * `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
@ -571,7 +600,7 @@ Several (less common) features that I often utilize in my projects are included.
## Results ## Results
Model validation results can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/results/) and in the [results tables](results/README.md) Model validation results can be found in the [results tables](results/README.md)
## Getting Started (Documentation) ## Getting Started (Documentation)

@ -17,20 +17,30 @@ import os
import glob import glob
import hashlib import hashlib
from timm.models import load_state_dict from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH', parser.add_argument('--input', default='', type=str, metavar='PATH',
help='path to base input folder containing checkpoints') help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD', parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
help='checkpoint filter (path wildcard)') help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH', parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
help='output filename') help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='Force not using ema version of weights (if present)') help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true', parser.add_argument('--no-sort', dest='no_sort', action='store_true',
help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant') help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N', parser.add_argument('-n', type=int, default=10, metavar='N',
help='Number of checkpoints to average') help='Number of checkpoints to average')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path): def checkpoint_metric(checkpoint_path):
@ -55,8 +65,23 @@ def main():
# by default sort by checkpoint metric (if present) and avg top n checkpoints # by default sort by checkpoint metric (if present) and avg top n checkpoints
args.sort = not args.no_sort args.sort = not args.no_sort
if os.path.exists(args.output): if args.safetensors and args.output == DEFAULT_OUTPUT:
print("Error: Output filename ({}) already exists.".format(args.output)) # Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
print(
"Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}"
)
if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(output))
exit(1) exit(1)
pattern = args.input pattern = args.input
@ -73,22 +98,27 @@ def main():
checkpoint_metrics.append((metric, c)) checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics)) checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:] checkpoint_metrics = checkpoint_metrics[-args.n:]
if checkpoint_metrics:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics] [print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics] avg_checkpoints = [c for m, c in checkpoint_metrics]
else: else:
avg_checkpoints = checkpoints avg_checkpoints = checkpoints
if avg_checkpoints:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(c) for c in checkpoints] [print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {} avg_state_dict = {}
avg_counts = {} avg_counts = {}
for c in avg_checkpoints: for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema) new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict: if not new_state_dict:
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) print(f"Error: Checkpoint ({c}) doesn't exist")
continue continue
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
if k not in avg_state_dict: if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64) avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@ -107,14 +137,15 @@ def main():
v = v.clamp(float32_info.min, float32_info.max) v = v.clamp(float32_info.min, float32_info.max)
final_state_dict[k] = v.to(dtype=torch.float32) final_state_dict[k] = v.to(dtype=torch.float32)
try: if args.safetensors:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False) assert _has_safetensors, "`pip install safetensors` to use .safetensors"
except: safetensors.torch.save_file(final_state_dict, output)
torch.save(final_state_dict, args.output) else:
torch.save(final_state_dict, output)
with open(args.output, 'rb') as f: with open(output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest() sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
if __name__ == '__main__': if __name__ == '__main__':
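Given the argument changes above, a run of the averaging script with the new flag might look like this (the input path is illustrative); with `--safetensors` and no explicit `--output`, the default output becomes `./averaged.safetensors`:
```
python avg_checkpoints.py --input ./output/train --filter '*.pth.tar' -n 5 --safetensors
```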

@ -12,8 +12,13 @@ import argparse
import os import os
import hashlib import hashlib
import shutil import shutil
from collections import OrderedDict import tempfile
from timm.models import load_state_dict from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -22,10 +27,12 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path') help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present') help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
_TEMP_NAME = './_checkpoint.pth' help='Save weights using safetensors instead of the default torch way (pickle).')
def main(): def main():
@ -35,10 +42,24 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output)) print("Error: Output filename ({}) already exists.".format(args.output))
exit(1) exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn) clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False): def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint): if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint)) print("=> Loading checkpoint '{}'".format(checkpoint))
@ -53,22 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
new_state_dict[name] = v new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint)) print("=> Loaded state_dict from '{}'".format(checkpoint))
try: ext = ''
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if output: if output:
checkpoint_root, checkpoint_base = os.path.split(output) checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base = os.path.splitext(checkpoint_base)[0] checkpoint_base, ext = os.path.splitext(checkpoint_base)
else: else:
checkpoint_root = '' checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0] checkpoint_base = os.path.split(checkpoint)[1]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth' checkpoint_base = os.path.splitext(checkpoint_base)[0]
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
temp_filename = '__' + checkpoint_base
if safe_serialization:
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
else:
torch.save(new_state_dict, temp_filename)
with open(temp_filename, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if ext:
final_ext = ext
else:
final_ext = ('.safetensors' if safe_serialization else '.pth')
if no_hash:
final_filename = checkpoint_base + final_ext
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename return final_filename
else: else:
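Similarly, the cleaner can now emit safetensors output, hashed by default or unhashed with `--no-hash`; a hypothetical invocation (paths are illustrative):
```
python clean_checkpoint.py --checkpoint ./output/train/model_best.pth.tar --output ./convnext_demo --safetensors
```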

@ -1,2 +0,0 @@
model-index==0.1.10
jinja2==2.11.3

@ -2,3 +2,4 @@ torch>=1.7
torchvision torchvision
pyyaml pyyaml
huggingface_hub huggingface_hub
safetensors>=0.2

@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
setup( setup(
name='timm', name='timm',
version=__version__, version=__version__,
description='(Unofficial) PyTorch Image Models', description='PyTorch Image Models',
long_description=long_description, long_description=long_description,
long_description_content_type='text/markdown', long_description_content_type='text/markdown',
url='https://github.com/rwightman/pytorch-image-models', url='https://github.com/huggingface/pytorch-image-models',
author='Ross Wightman', author='Ross Wightman',
author_email='hello@rwightman.com', author_email='ross@huggingface.co',
classifiers=[ classifiers=[
# How mature is this project? Common values are # How mature is this project? Common values are
# 3 - Alpha # 3 - Alpha
@ -29,11 +29,11 @@ setup(
'Intended Audience :: Education', 'Intended Audience :: Education',
'Intended Audience :: Science/Research', 'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License', 'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10', 'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Topic :: Scientific/Engineering', 'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence', 'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development', 'Topic :: Software Development',
@ -45,7 +45,7 @@ setup(
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit', keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']), packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True, include_package_data=True,
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'], install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
python_requires='>=3.6', python_requires='>=3.7',
) )

@ -4,7 +4,7 @@ from .config import resolve_data_config, resolve_model_data_config
from .constants import * from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset from .dataset_factory import create_dataset
from .dataset_info import DatasetInfo from .dataset_info import DatasetInfo, CustomDatasetInfo
from .imagenet_info import ImageNetInfo, infer_imagenet_subset from .imagenet_info import ImageNetInfo, infer_imagenet_subset
from .loader import create_loader from .loader import create_loader
from .mixup import Mixup, FastCollateMixup from .mixup import Mixup, FastCollateMixup

@ -1,5 +1,5 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Dict, List, Union from typing import Dict, List, Optional, Union
class DatasetInfo(ABC): class DatasetInfo(ABC):
@ -30,3 +30,44 @@ class DatasetInfo(ABC):
@abstractmethod @abstractmethod
def label_name_to_description(self, label: str, detailed: bool = False) -> str: def label_name_to_description(self, label: str, detailed: bool = False) -> str:
pass pass
class CustomDatasetInfo(DatasetInfo):
""" DatasetInfo that wraps passed values for custom datasets."""
def __init__(
self,
label_names: Union[List[str], Dict[int, str]],
label_descriptions: Optional[Dict[str, str]] = None
):
super().__init__()
assert len(label_names) > 0
self._label_names = label_names # label index => label name mapping
self._label_descriptions = label_descriptions # label name => label description mapping
if self._label_descriptions is not None:
# validate descriptions (label names required)
assert isinstance(self._label_descriptions, dict)
for n in self._label_names:
assert n in self._label_descriptions
def num_classes(self):
return len(self._label_names)
def label_names(self):
return self._label_names
def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
return self._label_descriptions
def label_name_to_description(self, label: str, detailed: bool = False) -> str:
if self._label_descriptions:
return self._label_descriptions[label]
return label # return label name itself if a descriptions is not present
def index_to_label_name(self, index) -> str:
assert 0 <= index < len(self._label_names)
return self._label_names[index]
def index_to_description(self, index: int, detailed: bool = False) -> str:
label = self.index_to_label_name(index)
return self.label_name_to_description(label, detailed=detailed)
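A small sketch of the new class with a made-up 3-class dataset (exported from `timm.data` per the `__init__` change earlier in this diff):
```
from timm.data import CustomDatasetInfo

# Hypothetical 3-class dataset: index -> name, plus optional name -> description.
info = CustomDatasetInfo(
    label_names=['cat', 'dog', 'bird'],
    label_descriptions={'cat': 'domestic cat', 'dog': 'domestic dog', 'bird': 'small bird'},
)
print(info.num_classes())            # 3
print(info.index_to_label_name(1))   # dog
print(info.index_to_description(1))  # domestic dog
```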

@ -28,7 +28,7 @@ from .linear import Linear
from .mixed_conv2d import MixedConv2d from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same from .padding import get_padding, get_same_padding, pad_same

@ -17,6 +17,12 @@ try:
except ImportError: except ImportError:
has_apex = False has_apex = False
try:
from apex.normalization.fused_layer_norm import fused_rms_norm_affine, fused_rms_norm
has_apex_rmsnorm = True
except ImportError:
has_apex_rmsnorm = False
# fast (ie lower precision LN) can be disabled with this flag if issues crop up # fast (ie lower precision LN) can be disabled with this flag if issues crop up
_USE_FAST_NORM = False # defaulting to False for now _USE_FAST_NORM = False # defaulting to False for now
@ -76,3 +82,45 @@ def fast_layer_norm(
with torch.cuda.amp.autocast(enabled=False): with torch.cuda.amp.autocast(enabled=False):
return F.layer_norm(x, normalized_shape, weight, bias, eps) return F.layer_norm(x, normalized_shape, weight, bias, eps)
def rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
):
norm_ndim = len(normalized_shape)
if torch.jit.is_scripting():
# ndim = len(x.shape)
# dims = list(range(ndim - norm_ndim, ndim)) # this doesn't work on pytorch <= 1.13.x
# NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
assert norm_ndim == 1
v = torch.var(x, dim=-1).unsqueeze(-1) # ts crashes with -ve dim + keepdim=True
else:
dims = tuple(range(-1, -norm_ndim - 1, -1))
v = torch.var(x, dim=dims, keepdim=True)
x = x * torch.rsqrt(v + eps)
if weight is not None:
x = x * weight
return x
def fast_rms_norm(
x: torch.Tensor,
normalized_shape: List[int],
weight: Optional[torch.Tensor] = None,
eps: float = 1e-5,
) -> torch.Tensor:
if torch.jit.is_scripting():
# this must be by itself, cannot merge with has_apex_rmsnorm
return rms_norm(x, normalized_shape, weight, eps)
if has_apex_rmsnorm:
if weight is None:
return fused_rms_norm(x, normalized_shape, eps)
else:
return fused_rms_norm_affine(x, weight, normalized_shape, eps)
# fallback
return rms_norm(x, normalized_shape, weight, eps)

@ -4,12 +4,14 @@ Norm layer definitions that support fast norm and consistent channel arg order (
Hacked together by / Copyright 2022 Ross Wightman Hacked together by / Copyright 2022 Ross Wightman
""" """
import numbers
from typing import Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
class GroupNorm(nn.GroupNorm): class GroupNorm(nn.GroupNorm):
@ -115,3 +117,39 @@ class LayerNormExp2d(nn.LayerNorm):
else: else:
x = _layer_norm_cf(x, self.weight, self.bias, self.eps) x = _layer_norm_cf(x, self.weight, self.bias, self.eps)
return x return x
class RmsNorm(nn.Module):
""" RmsNorm w/ fast (apex) norm if available
"""
__constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
normalized_shape: Tuple[int, ...]
eps: float
elementwise_affine: bool
def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
normalized_shape = channels
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = affine
if self.elementwise_affine:
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter('weight', None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_affine:
nn.init.ones_(self.weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# NOTE fast norm fallback needs our rms norm impl, so both paths through here.
# Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
return x
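A quick sketch of using the new layer, assuming it is exported from `timm.layers` as the updated layer `__init__` above suggests:
```
import torch
from timm.layers import RmsNorm

norm = RmsNorm(64)               # RMS norm over the last (channel) dimension
x = torch.randn(2, 196, 64)      # (batch, tokens, channels)
y = norm(x)
print(y.shape)                   # torch.Size([2, 196, 64])
```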

@ -83,8 +83,8 @@ def gen_relative_log_coords(
pretrained_win_size: Tuple[int, int] = (0, 0), pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin', mode='swin',
): ):
assert mode in ('swin', 'cr', 'rw') assert mode in ('swin', 'cr')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well # as per official swin-v2 impl, supporting timm specific 'cr' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
@ -99,15 +99,6 @@ def gen_relative_log_coords(
relative_coords_table *= 8 # normalize to -8, 8 relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2( relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8) 1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else: else:
# mode == 'cr' # mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log( relative_coords_table = torch.sign(relative_coords_table) * torch.log(
@ -141,10 +132,6 @@ class RelPosMlp(nn.Module):
self.bias_act = nn.Sigmoid() self.bias_act = nn.Sigmoid()
self.bias_gain = 16 self.bias_gain = 16
mlp_bias = (True, False) mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else: else:
self.bias_act = nn.Identity() self.bias_act = nn.Identity()
self.bias_gain = None self.bias_gain = None

@ -11,10 +11,11 @@ Hacked together by / Copyright 2020 Ross Wightman
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from copy import deepcopy from copy import deepcopy
from functools import partial from functools import partial
from typing import Dict, List, Tuple from typing import Dict, List, Sequence, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch.utils.checkpoint import checkpoint
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
@ -88,12 +89,20 @@ class FeatureHooks:
""" Feature Hook Helper """ Feature Hook Helper
This module helps with the setup and extraction of hooks for extracting features from This module helps with the setup and extraction of hooks for extracting features from
internal nodes in a model by node name. This works quite well in eager Python but needs internal nodes in a model by node name.
redesign for torchscript.
FIXME This works well in eager Python but needs redesign for torchscript.
""" """
def __init__(self, hooks, named_modules, out_map=None, default_hook_type='forward'): def __init__(
self,
hooks: Sequence[str],
named_modules: dict,
out_map: Sequence[Union[int, str]] = None,
default_hook_type: str = 'forward',
):
# setup feature hooks # setup feature hooks
self._feature_outputs = defaultdict(OrderedDict)
modules = {k: v for k, v in named_modules} modules = {k: v for k, v in named_modules}
for i, h in enumerate(hooks): for i, h in enumerate(hooks):
hook_name = h['module'] hook_name = h['module']
@ -107,7 +116,6 @@ class FeatureHooks:
m.register_forward_hook(hook_fn) m.register_forward_hook(hook_fn)
else: else:
assert False, "Unsupported hook type" assert False, "Unsupported hook type"
self._feature_outputs = defaultdict(OrderedDict)
def _collect_output_hook(self, hook_id, *args): def _collect_output_hook(self, hook_id, *args):
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
@ -167,23 +175,30 @@ class FeatureDictNet(nn.ModuleDict):
one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
All Sequential containers that are directly assigned to the original model will have their All Sequential containers that are directly assigned to the original model will have their
modules assigned to this module with the name `model.features.1` being changed to `model.features_1` modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
Arguments:
model (nn.Module): model from which we will extract the features
out_indices (tuple[int]): model output indices to extract features for
out_map (sequence): list or tuple specifying desired return id for each out index,
otherwise str(index) is used
feature_concat (bool): whether to concatenate intermediate features that are lists or tuples
vs select element [0]
flatten_sequential (bool): whether to flatten sequential modules assigned to model
""" """
def __init__( def __init__(
self, model, self,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
feature_concat: bool = False,
flatten_sequential: bool = False,
):
"""
Args:
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureDictNet, self).__init__() super(FeatureDictNet, self).__init__()
self.feature_info = _get_feature_info(model, out_indices) self.feature_info = _get_feature_info(model, out_indices)
self.concat = feature_concat self.concat = feature_concat
self.grad_checkpointing = False
self.return_layers = {} self.return_layers = {}
return_layers = _get_return_layers(self.feature_info, out_map) return_layers = _get_return_layers(self.feature_info, out_map)
modules = _module_list(model, flatten_sequential=flatten_sequential) modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = set(return_layers.keys()) remaining = set(return_layers.keys())
@ -200,10 +215,21 @@ class FeatureDictNet(nn.ModuleDict):
f'Return layers ({remaining}) are not present in model' f'Return layers ({remaining}) are not present in model'
self.update(layers) self.update(layers)
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def _collect(self, x) -> (Dict[str, torch.Tensor]): def _collect(self, x) -> (Dict[str, torch.Tensor]):
out = OrderedDict() out = OrderedDict()
for name, module in self.items(): for i, (name, module) in enumerate(self.items()):
if self.grad_checkpointing and not torch.jit.is_scripting():
# Skipping checkpoint of first module because need a gradient at input
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
x = module(x) if first_or_last_module else checkpoint(module, x)
else:
x = module(x) x = module(x)
if name in self.return_layers: if name in self.return_layers:
out_id = self.return_layers[name] out_id = self.return_layers[name]
if isinstance(x, (tuple, list)): if isinstance(x, (tuple, list)):
@ -221,15 +247,29 @@ class FeatureDictNet(nn.ModuleDict):
class FeatureListNet(FeatureDictNet): class FeatureListNet(FeatureDictNet):
""" Feature extractor with list return """ Feature extractor with list return
See docstring for FeatureDictNet above, this class exists only to appease Torchscript typing constraints. A specialization of FeatureDictNet that always returns features as a list (values() of dict).
In eager Python we could have returned List[Tensor] vs Dict[id, Tensor] based on a member bool.
""" """
def __init__( def __init__(
self, model, self,
out_indices=(0, 1, 2, 3, 4), out_map=None, feature_concat=False, flatten_sequential=False): model: nn.Module,
out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
feature_concat: bool = False,
flatten_sequential: bool = False,
):
"""
Args:
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
first element e.g. `x[0]`
flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
"""
super(FeatureListNet, self).__init__( super(FeatureListNet, self).__init__(
model, out_indices=out_indices, out_map=out_map, feature_concat=feature_concat, model,
flatten_sequential=flatten_sequential) out_indices=out_indices,
feature_concat=feature_concat,
flatten_sequential=flatten_sequential,
)
def forward(self, x) -> (List[torch.Tensor]): def forward(self, x) -> (List[torch.Tensor]):
return list(self._collect(x).values()) return list(self._collect(x).values())
@ -249,13 +289,33 @@ class FeatureHookNet(nn.ModuleDict):
FIXME this does not currently work with Torchscript, see FeatureHooks class FIXME this does not currently work with Torchscript, see FeatureHooks class
""" """
def __init__( def __init__(
self, model, self,
out_indices=(0, 1, 2, 3, 4), out_map=None, out_as_dict=False, no_rewrite=False, model: nn.Module,
feature_concat=False, flatten_sequential=False, default_hook_type='forward'): out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
out_map: Sequence[Union[int, str]] = None,
out_as_dict: bool = False,
no_rewrite: bool = False,
flatten_sequential: bool = False,
default_hook_type: str = 'forward',
):
"""
Args:
model: Model from which to extract features.
out_indices: Output indices of the model features to extract.
out_map: Return id mapping for each output index, otherwise str(index) is used.
out_as_dict: Output features as a dict.
no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
flatten_sequential arg must also be False if this is set True.
flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
default_hook_type: The default hook type to use if not specified in model.feature_info.
"""
super(FeatureHookNet, self).__init__() super(FeatureHookNet, self).__init__()
assert not torch.jit.is_scripting() assert not torch.jit.is_scripting()
self.feature_info = _get_feature_info(model, out_indices) self.feature_info = _get_feature_info(model, out_indices)
self.out_as_dict = out_as_dict self.out_as_dict = out_as_dict
self.grad_checkpointing = False
layers = OrderedDict() layers = OrderedDict()
hooks = [] hooks = []
if no_rewrite: if no_rewrite:
@ -266,8 +326,10 @@ class FeatureHookNet(nn.ModuleDict):
hooks.extend(self.feature_info.get_dicts()) hooks.extend(self.feature_info.get_dicts())
else: else:
modules = _module_list(model, flatten_sequential=flatten_sequential) modules = _module_list(model, flatten_sequential=flatten_sequential)
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type remaining = {
for f in self.feature_info.get_dicts()} f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
for f in self.feature_info.get_dicts()
}
for new_name, old_name, module in modules: for new_name, old_name, module in modules:
layers[new_name] = module layers[new_name] = module
for fn, fm in module.named_modules(prefix=old_name): for fn, fm in module.named_modules(prefix=old_name):
@ -280,8 +342,18 @@ class FeatureHookNet(nn.ModuleDict):
self.update(layers) self.update(layers)
self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
def set_grad_checkpointing(self, enable: bool = True):
self.grad_checkpointing = enable
def forward(self, x): def forward(self, x):
for name, module in self.items(): for i, (name, module) in enumerate(self.items()):
if self.grad_checkpointing and not torch.jit.is_scripting():
# Skipping checkpoint of first module because need a gradient at input
# Skipping last because networks with in-place ops might fail w/ checkpointing enabled
# NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
x = module(x) if first_or_last_module else checkpoint(module, x)
else:
x = module(x) x = module(x)
out = self.hooks.get_output(x.device) out = self.hooks.get_output(x.device)
return out if self.out_as_dict else list(out.values()) return out if self.out_as_dict else list(out.values())

@ -7,6 +7,11 @@ import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder import timm.models._builder
@ -26,7 +31,13 @@ def clean_state_dict(state_dict):
def load_state_dict(checkpoint_path, use_ema=True): def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = '' state_dict_key = ''
if isinstance(checkpoint, dict): if isinstance(checkpoint, dict):
if use_ema and checkpoint.get('state_dict_ema', None) is not None: if use_ema and checkpoint.get('state_dict_ema', None) is not None:
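With this change, a `.safetensors` checkpoint (for example, one written by the updated `avg_checkpoints.py` above) loads through the same helper; a minimal sketch with an illustrative path:
```
from timm.models import load_state_dict

# Requires the safetensors package; assumes the file was produced with --safetensors.
state_dict = load_state_dict('averaged.safetensors', use_ema=True)
print(f'{len(state_dict)} tensors loaded')
```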

@ -2,10 +2,11 @@ import hashlib
import json import json
import logging import logging
import os import os
import sys
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional, Union from typing import Iterable, Optional, Union
import torch import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse from torch.hub import HASH_REGEX, download_url_to_file, urlparse
@ -15,6 +16,17 @@ try:
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from timm import __version__ from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg from timm.models._pretrained import filter_pretrained_cfg
@ -35,6 +47,10 @@ _logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf', __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub'] 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''): def get_cache_dir(child_dir=''):
""" """
@ -96,7 +112,7 @@ def has_hf_hub(necessary=False):
return _has_hf_hub return _has_hf_hub
def hf_split(hf_id): def hf_split(hf_id: str):
# FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
rev_split = hf_id.split('@') rev_split = hf_id.split('@')
assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.' assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
@ -127,30 +143,56 @@ def load_model_config_from_hf(model_id: str):
hf_config = {} hf_config = {}
hf_config['architecture'] = pretrained_cfg.pop('architecture') hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_features'] = pretrained_cfg.pop('num_features', None) hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
if 'labels' in pretrained_cfg: if 'labels' in pretrained_cfg: # deprecated name for 'label_names'
hf_config['label_name'] = pretrained_cfg.pop('labels') pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
hf_config['pretrained_cfg'] = pretrained_cfg hf_config['pretrained_cfg'] = pretrained_cfg
# NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
pretrained_cfg = hf_config['pretrained_cfg'] pretrained_cfg = hf_config['pretrained_cfg']
pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
pretrained_cfg['source'] = 'hf-hub' pretrained_cfg['source'] = 'hf-hub'
# model should be created with base config num_classes if its exist
if 'num_classes' in hf_config: if 'num_classes' in hf_config:
# model should be created with parent num_classes if they exist
pretrained_cfg['num_classes'] = hf_config['num_classes'] pretrained_cfg['num_classes'] = hf_config['num_classes']
model_name = hf_config['architecture']
# label meta-data in base config overrides saved pretrained_cfg on load
if 'label_names' in hf_config:
pretrained_cfg['label_names'] = hf_config.pop('label_names')
if 'label_descriptions' in hf_config:
pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
model_name = hf_config['architecture']
return pretrained_cfg, model_name return pretrained_cfg, model_name
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'): def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
assert has_hf_hub(True) assert has_hf_hub(True)
cached_file = download_from_hf(model_id, filename) hf_model_id, hf_revision = hf_split(model_id)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict # Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path, model_config=None): def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
model_config = model_config or {} model_config = model_config or {}
hf_config = {} hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@ -164,22 +206,22 @@ def save_config_for_hf(model, config_path, model_config=None):
if 'labels' in model_config: if 'labels' in model_config:
_logger.warning( _logger.warning(
"'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. " "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
"Using provided 'label' field as 'label_name'.") " Renaming provided 'labels' field to 'label_names'.")
model_config['label_name'] = model_config.pop('labels') model_config.setdefault('label_names', model_config.pop('labels'))
label_name = model_config.pop('label_name', None) label_names = model_config.pop('label_names', None)
if label_name: if label_names:
assert isinstance(label_name, (dict, list, tuple)) assert isinstance(label_names, (dict, list, tuple))
# map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages) # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
# can be a dict id: name if there are id gaps, or tuple/list if no gaps. # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
hf_config['label_name'] = model_config['label_name'] hf_config['label_names'] = label_names
display_name = model_config.pop('display_name', None) label_descriptions = model_config.pop('label_descriptions', None)
if display_name: if label_descriptions:
assert isinstance(display_name, dict) assert isinstance(label_descriptions, dict)
# map label_name -> user interface display name # maps label names -> descriptions
hf_config['display_name'] = model_config['display_name'] hf_config['label_descriptions'] = label_descriptions
hf_config['pretrained_cfg'] = pretrained_cfg hf_config['pretrained_cfg'] = pretrained_cfg
hf_config.update(model_config) hf_config.update(model_config)
@ -188,13 +230,23 @@ def save_config_for_hf(model, config_path, model_config=None):
json.dump(hf_config, f, indent=2) json.dump(hf_config, f, indent=2)
def save_for_hf(model, save_directory, model_config=None): def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
):
assert has_hf_hub(True) assert has_hf_hub(True)
save_directory = Path(save_directory) save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True) save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin' # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
torch.save(model.state_dict(), weights_path) tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
config_path = save_directory / 'config.json' config_path = save_directory / 'config.json'
save_config_for_hf(model, config_path, model_config=model_config) save_config_for_hf(model, config_path, model_config=model_config)
@ -210,7 +262,15 @@ def push_to_hf_hub(
create_pr: bool = False, create_pr: bool = False,
model_config: Optional[dict] = None, model_config: Optional[dict] = None,
model_card: Optional[dict] = None, model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
): ):
"""
Arguments:
(...)
safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
Can be set to `"both"` in order to push both safe and unsafe weights.
"""
# Create repo if it doesn't exist yet # Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@ -229,7 +289,7 @@ def push_to_hf_hub(
# Dump model and push to Hub # Dump model and push to Hub
with TemporaryDirectory() as tmpdir: with TemporaryDirectory() as tmpdir:
# Save model weights and config. # Save model weights and config.
save_for_hf(model, tmpdir, model_config=model_config) save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
# Add readme if it does not exist # Add readme if it does not exist
if not has_readme: if not has_readme:
@ -249,7 +309,7 @@ def push_to_hf_hub(
) )
def generate_readme(model_card, model_name): def generate_readme(model_card: dict, model_name: str):
readme_text = "---\n" readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n" readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n" readme_text += "library_tag: timm\n"
@ -295,3 +355,16 @@ def generate_readme(model_card, model_name):
for c in citations: for c in citations:
readme_text += f"```bibtex\n{c}\n```\n" readme_text += f"```bibtex\n{c}\n```\n"
return readme_text return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
Use case:
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
return filename[:-4] + ".safetensors"
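A sketch of pushing a model to the Hub with the new option (the repo id is illustrative and requires `huggingface_hub` authentication; `safe_serialization='both'` writes both `model.safetensors` and `pytorch_model.bin`):
```
import timm
from timm.models import push_to_hf_hub

model = timm.create_model('convnext_small', pretrained=True)
push_to_hf_hub(model, 'my-user/convnext_small-demo', safe_serialization='both')
```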

@ -34,9 +34,11 @@ class PretrainedCfg:
mean: Tuple[float, ...] = (0.485, 0.456, 0.406) mean: Tuple[float, ...] = (0.485, 0.456, 0.406)
std: Tuple[float, ...] = (0.229, 0.224, 0.225) std: Tuple[float, ...] = (0.229, 0.224, 0.225)
# head config # head / classifier config and meta-data
num_classes: int = 1000 num_classes: int = 1000
label_offset: Optional[int] = None label_offset: Optional[int] = None
label_names: Optional[Tuple[str]] = None
label_descriptions: Optional[Dict[str, str]] = None
# model attributes that vary with above or required for pretrained adaptation
pool_size: Optional[Tuple[int, ...]] = None
@ -91,7 +93,7 @@ class DefaultCfg:
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
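A short sketch of the `arch.tag` convention handled by `split_model_name_tag`, together with the new label metadata fields; the custom cfg values are made up for illustration.

```
from timm.models._pretrained import PretrainedCfg, split_model_name_tag

print(split_model_name_tag('convnext_large_mlp.clip_laion2b_ft_320'))
# ('convnext_large_mlp', 'clip_laion2b_ft_320')
print(split_model_name_tag('resnet50'))
# ('resnet50', '')

# Hypothetical custom pretrained cfg carrying dataset label metadata
cfg = PretrainedCfg(
    num_classes=2,
    label_names=('cat', 'dog'),
    label_descriptions={'cat': 'domestic cat', 'dog': 'domestic dog'},
)
```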

@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -16,20 +16,20 @@ __all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {}  # mapping of model names to module names
_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns
_model_has_pretrained: Set[str] = set()  # set of model names that have pretrained weight url present
_model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> str:
return split_model_name_tag(model_name)[0]
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
@ -40,7 +40,7 @@ def register_model(fn):
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name]  # type: ignore
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
@ -87,28 +87,33 @@ def register_model(fn):
return fn
def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
) -> List[str]:
""" Return list of available model names, sorted alphabetically
Args:
filter - Wildcard filter string that works with fnmatch
module - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained - Include only models with valid pretrained weights if True
exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None)
Returns:
models - The sorted list of models
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
model_list('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
@ -118,7 +123,7 @@ def list_models(
include_tags = pretrained
if module:
all_models: Iterable[str] = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
@ -130,14 +135,14 @@ def list_models(
all_models = models_with_tags
if filter:
models: Set[str] = set()
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f)  # include these models
if len(include_models):
models = models.union(include_models)
else:
models = set(all_models)
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
@ -145,7 +150,7 @@ def list_models(
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf)  # exclude these models
if len(exclude_models):
models = models.difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
@ -153,13 +158,13 @@ def list_models(
if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models)
return sorted(models, key=_natural_key)
def list_pretrained(
filter: Union[str, List[str]] = '',
exclude_filters: str = '',
) -> List[str]:
return list_models(
filter=filter,
pretrained=True,
@ -168,14 +173,14 @@ def list_pretrained(
)
def is_model(model_name: str) -> bool:
""" Check if a model name exists
"""
arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
"""Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name(model_name)
@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
return _model_entrypoints[arch_name]
def list_modules() -> List[str]:
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models.keys()
return sorted(modules)
def is_model_in_modules(
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
"""Check if a model exists within a subset of modules
Args:
model_name - name of model to check
module_names - names of modules to search in
"""
arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name: str) -> bool:
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
arch_name, tag = split_model_name_tag(model_name)
@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
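The registry functions above form the public listing/lookup API; a quick sketch of typical calls (output truncated, and the exact tag names depend on the installed timm version).

```
import timm

# Wildcard filter; include_tags defaults to True when pretrained=True,
# so results come back as 'arch.tag' names
print(timm.list_models('convnext_large_mlp*', pretrained=True))

# Restrict to a module and exclude some matches
print(timm.list_models('vit_*patch16*', module='vision_transformer', exclude_filters=['*_384'])[:5])

# Tag-aware existence check
print(timm.is_model('convnext_large_mlp.clip_laion2b_ft_320'))
```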

@ -735,6 +735,11 @@ default_cfgs = generate_default_cfgs({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
),
'convnext_large_mlp.clip_laion2b_augreg_ft_in1k_384': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'
),
# CLIP based weights, original image tower weights and fine-tunes
@ -768,6 +773,16 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
})
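A sketch of instantiating one of the newly registered ConvNeXt-Large CLIP tower tags; `pretrained=True` would pull the open_clip weights referenced above, and the 768-class head is the CLIP projection dimension rather than a classifier.

```
import timm

model = timm.create_model(
    'convnext_large_mlp.clip_laion2b_ft_soup_320',
    pretrained=False,  # set True to download the laion/CLIP-convnext weights listed above
)
cfg = model.pretrained_cfg
print(cfg['input_size'], cfg['num_classes'])  # (3, 320, 320) 768
```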

@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x
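The `window_reverse` rewrites in this change avoid deriving the batch size from `windows.shape[0]` (an int division on a traced `Proxy`), recovering it implicitly via `-1` instead. A self-contained sketch of the round trip, assuming standard Swin-style partitioning with square windows:

```
import torch

def window_partition(x, window_size):
    # (B, H, W, C) -> (num_windows * B, ws, ws, C), standard Swin-style partitioning
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    # batch dim recovered implicitly via -1, as in the change above (trace friendly)
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)

x = torch.randn(2, 8, 8, 16)
out = window_reverse(window_partition(x, 4), 4, 8, 8)
assert torch.equal(out, x)
print(out.shape)  # torch.Size([2, 8, 8, 16])
```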

@ -41,6 +41,7 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import create_conv2d, create_classifier, get_norm_act_layer, GroupNormAct
@ -211,6 +212,7 @@ class EfficientNetFeatures(nn.Module):
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -241,6 +243,10 @@ class EfficientNetFeatures(nn.Module):
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
@ -249,6 +255,9 @@ class EfficientNetFeatures(nn.Module):
if 0 in self._stage_out_idx:
features.append(x)  # add stem out
for i, b in enumerate(self.blocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(b, x)
else:
x = b(x)
if i + 1 in self._stage_out_idx:
features.append(x)
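Sketch of turning the new checkpointing on for a features-only EfficientNet; this assumes, as in this diff, that the features model class is what `features_only=True` returns for this family and that it now exposes `set_grad_checkpointing`.

```
import torch
import timm

model = timm.create_model('efficientnet_b0', features_only=True, pretrained=False)
model.set_grad_checkpointing(True)  # recompute blocks in backward to save activation memory

x = torch.randn(2, 3, 224, 224, requires_grad=True)
features = model(x)  # list of feature maps, one per output stage
print([f.shape for f in features])
sum(f.mean() for f in features).backward()
```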

@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]):
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x

@ -42,6 +42,7 @@ from typing import Callable, Optional, Union, Tuple, List
import torch
from torch import nn
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
@ -140,6 +141,8 @@ class MaxxVitCfg:
class Attention2d(nn.Module):
fast_attn: Final[bool]
""" multi-head attention for 2D NCHW tensors"""
def __init__(
self,
@ -160,6 +163,7 @@ class Attention2d(nn.Module):
self.dim_head = dim_head
self.head_first = head_first
self.scale = dim_head ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@ -175,15 +179,31 @@ class Attention2d(nn.Module):
else:
q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q.transpose(-1, -2),
k.transpose(-1, -2),
v.transpose(-1, -2),
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
).transpose(-1, -2).reshape(B, -1, H, W)
else:
q = q * self.scale
attn = q.transpose(-2, -1) @ k
if self.rel_pos is not None:
attn = self.rel_pos(attn)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -191,6 +211,8 @@ class Attention2d(nn.Module):
class AttentionCl(nn.Module):
""" Channels-last multi-head attention (B, ..., C) """
fast_attn: Final[bool]
def __init__(
self,
dim: int,
@ -211,6 +233,7 @@ class AttentionCl(nn.Module):
self.dim_head = dim_head
self.head_first = head_first
self.scale = dim_head ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
@ -227,15 +250,30 @@ class AttentionCl(nn.Module):
else:
q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(restore_shape + (-1,))
x = self.proj(x)
x = self.proj_drop(x)
return x
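All of the attention changes in this diff follow the same pattern: call `F.scaled_dot_product_attention` when it exists (PyTorch 2.0), otherwise fall back to the explicit softmax path with the relative-position bias added manually. A standalone sketch of that guard:

```
import torch
import torch.nn.functional as F

def sdpa_or_fallback(q, k, v, attn_bias=None, dropout_p=0.):
    # q, k, v: (B, num_heads, N, head_dim)
    if hasattr(F, 'scaled_dot_product_attention'):
        # fused kernel path; additive bias passed as attn_mask
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias, dropout_p=dropout_p)
    scale = q.shape[-1] ** -0.5
    attn = (q * scale) @ k.transpose(-2, -1)
    if attn_bias is not None:
        attn = attn + attn_bias
    attn = F.dropout(attn.softmax(dim=-1), p=dropout_p)
    return attn @ v

q = k = v = torch.randn(2, 4, 16, 32)
print(sdpa_or_fallback(q, k, v).shape)  # torch.Size([2, 4, 16, 32])
```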

@ -12,6 +12,7 @@ from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
@ -188,6 +189,7 @@ class MobileNetV3Features(nn.Module):
norm_layer = norm_layer or nn.BatchNorm2d
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
self.grad_checkpointing = False
# Stem
if not fix_stem:
@ -220,6 +222,10 @@ class MobileNetV3Features(nn.Module):
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
@ -229,6 +235,9 @@ class MobileNetV3Features(nn.Module):
if 0 in self._stage_out_idx:
features.append(x)  # add stem out
for i, b in enumerate(self.blocks):
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(b, x)
else:
x = b(x)
if i + 1 in self._stage_out_idx:
features.append(x)

@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x

@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C)
"""
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x

@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C)
"""
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
return x

@ -33,11 +33,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.jit import Final
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD, \
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, resample_patch_embed, \
resample_abs_pos_embed, RmsNorm
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
from ._pretrained import generate_default_cfgs
@ -51,28 +52,51 @@ _logger = logging.getLogger(__name__)
class Attention(nn.Module):
fast_attn: Final[bool]
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q, k = self.q_norm(q), self.k_norm(k)
if self.fast_attn:
x = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -96,6 +120,7 @@ class Block(nn.Module):
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
@ -105,13 +130,25 @@ class Block(nn.Module):
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -129,6 +166,7 @@ class ResPostBlock(nn.Module):
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
@ -139,11 +177,24 @@ class ResPostBlock(nn.Module):
super().__init__()
self.init_values = init_values
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -161,8 +212,105 @@ class ResPostBlock(nn.Module):
return x
class ParallelScalingBlock(nn.Module):
""" Parallel ViT block (MLP & Attention in parallel)
Based on:
'Scaling Vision Transformers to 22 Billion Parameters` - https://arxiv.org/abs/2302.05442
"""
fast_attn: Final[bool]
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
drop=0.,
attn_drop=0.,
init_values=None,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention') # FIXME
mlp_hidden_dim = int(mlp_ratio * dim)
in_proj_out_dim = mlp_hidden_dim + 3 * dim
self.in_norm = norm_layer(dim)
self.in_proj = nn.Linear(dim, in_proj_out_dim, bias=qkv_bias)
self.in_split = [mlp_hidden_dim] + [dim] * 3
if qkv_bias:
self.register_buffer('qkv_bias', None)
self.register_parameter('mlp_bias', None)
else:
self.register_buffer('qkv_bias', torch.zeros(3 * dim), persistent=False)
self.mlp_bias = nn.Parameter(torch.zeros(mlp_hidden_dim))
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.attn_drop = nn.Dropout(attn_drop)
self.attn_out_proj = nn.Linear(dim, dim)
self.mlp_drop = nn.Dropout(drop)
self.mlp_act = act_layer()
self.mlp_out_proj = nn.Linear(mlp_hidden_dim, dim)
self.ls = LayerScale(dim, init_values=init_values) if init_values is not None else nn.Identity()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
B, N, C = x.shape
# Combined MLP fc1 & qkv projections
y = self.in_norm(x)
if self.mlp_bias is not None:
# Concat constant zero-bias for qkv w/ trainable mlp_bias.
# Appears faster than adding to x_mlp separately
y = F.linear(y, self.in_proj.weight, torch.cat((self.qkv_bias, self.mlp_bias)))
else:
y = self.in_proj(y)
x_mlp, q, k, v = torch.split(y, self.in_split, dim=-1)
# Dot product attention w/ qk norm
q = self.q_norm(q.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
if self.fast_attn:
x_attn = F.scaled_dot_product_attention(
q, k, v,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x_attn = attn @ v
x_attn = x_attn.transpose(1, 2).reshape(B, N, C)
x_attn = self.attn_out_proj(x_attn)
# MLP activation, dropout, fc2
x_mlp = self.mlp_act(x_mlp)
x_mlp = self.mlp_drop(x_mlp)
x_mlp = self.mlp_out_proj(x_mlp)
# Add residual w/ drop path & layer scale applied
y = self.drop_path(self.ls(x_attn + x_mlp))
x = x + y
return x
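The block above follows the 22B-parameter ViT recipe of fusing the attention qkv projection with the MLP's first linear layer into a single `in_proj`, then splitting the result. A small sketch of just that fusion, with illustrative dimensions, to make the shapes concrete:

```
import torch
import torch.nn as nn

dim, mlp_ratio = 256, 4.0
mlp_hidden = int(dim * mlp_ratio)

# One matmul produces the MLP hidden features and q, k, v together
in_proj = nn.Linear(dim, mlp_hidden + 3 * dim, bias=False)
x = torch.randn(2, 197, dim)
x_mlp, q, k, v = torch.split(in_proj(x), [mlp_hidden, dim, dim, dim], dim=-1)
print(x_mlp.shape, q.shape, k.shape, v.shape)
# torch.Size([2, 197, 1024]) torch.Size([2, 197, 256]) torch.Size([2, 197, 256]) torch.Size([2, 197, 256])
```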
class ParallelThingsBlock(nn.Module):
""" Parallel ViT block (N parallel attention followed by N parallel MLP)
Based on:
`Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
def __init__(
self,
dim,
@ -170,6 +318,7 @@ class ParallelBlock(nn.Module):
num_parallel=2,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
init_values=None,
drop=0.,
attn_drop=0.,
@ -184,13 +333,26 @@ class ParallelBlock(nn.Module):
for _ in range(num_parallel):
self.attns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('attn', Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
attn_drop=attn_drop,
proj_drop=drop,
norm_layer=norm_layer,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
self.ffns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('mlp', Mlp(
dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
@ -232,6 +394,7 @@ class VisionTransformer(nn.Module):
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=None,
class_token=True,
no_embed_class=False,
@ -305,6 +468,7 @@ class VisionTransformer(nn.Module):
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
@ -641,9 +805,8 @@ def checkpoint_filter_fn(
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
import re import re
out_dict = {} out_dict = {}
if 'model' in state_dict: state_dict = state_dict.get('model', state_dict)
# For deit models state_dict = state_dict.get('state_dict', state_dict)
state_dict = state_dict['model']
if 'visual.class_embedding' in state_dict: if 'visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model) return _convert_openai_clip(state_dict, model)
@ -1129,6 +1292,10 @@ default_cfgs = generate_default_cfgs({
url='https://storage.googleapis.com/big_vision/flexivit/vit_b30_i21k_300ep.npz', custom_load=True,
hf_hub_id='timm/',
input_size=(3, 240, 240), crop_pct=0.95, num_classes=21843),
'vit_base_patch16_xp_224.untrained': _cfg(url=''),
'vit_large_patch14_xp_224.untrained': _cfg(url=''),
'vit_huge_patch14_xp_224.untrained': _cfg(url=''),
})
@ -1566,7 +1733,7 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelThingsBlock)
model = _create_vision_transformer(
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1577,7 +1744,8 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock) model_kwargs = dict(
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelThingsBlock)
model = _create_vision_transformer( model = _create_vision_transformer(
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model return model
@ -1625,3 +1793,42 @@ def flexivit_large(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_base_patch16_xp_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, no_embed_class=True,
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
)
model = _create_vision_transformer(
'vit_base_patch16_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_large_patch14_xp_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) w/ parallel blocks and qk norm enabled.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, no_embed_class=True,
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
)
model = _create_vision_transformer(
'vit_large_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_huge_patch14_xp_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) w/ parallel blocks and qk norm enabled.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, no_embed_class=True,
norm_layer=RmsNorm, block_fn=ParallelScalingBlock, qkv_bias=False, qk_norm=True,
)
model = _create_vision_transformer(
'vit_huge_patch14_xp_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
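The `*_xp_*` registrations above combine pre-norm, `RmsNorm`, qk normalization and `ParallelScalingBlock`. A sketch instantiating the base variant (random init, since only `.untrained` cfgs are registered) and, assuming model kwargs are forwarded as usual, switching qk norm on for a standard ViT:

```
import torch
import timm

model = timm.create_model('vit_base_patch16_xp_224')  # .untrained cfg, random weights
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 1000])

# qk_norm is now a constructor arg, so it can also be enabled on existing configs
model = timm.create_model('vit_base_patch16_224', qk_norm=True)
```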

@ -11,6 +11,7 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
from torch.jit import Final
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
@ -25,14 +26,29 @@ _logger = logging.getLogger(__name__)
class RelPosAttention(nn.Module):
fast_attn: Final[bool]
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
attn_drop=0.,
proj_drop=0.,
norm_layer=nn.LayerNorm,
):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.scale = self.head_dim ** -0.5
self.fast_attn = hasattr(torch.nn.functional, 'scaled_dot_product_attention')  # FIXME
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
@ -40,18 +56,35 @@ class RelPosAttention(nn.Module):
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = self.q_norm(q)
k = self.k_norm(k)
if self.fast_attn:
if self.rel_pos is not None:
attn_bias = self.rel_pos.get_bias()
elif shared_rel_pos is not None:
attn_bias = shared_rel_pos
else:
attn_bias = None
x = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
attn_mask=attn_bias,
dropout_p=self.attn_drop.p,
)
else:
q = q * self.scale
attn = q @ k.transpose(-2, -1)
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@ -70,18 +103,42 @@ class LayerScale(nn.Module):
class RelPosBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RelPosAttention(
dim,
num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -94,17 +151,41 @@ class RelPosBlock(nn.Module):
class ResPostRelPosBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
qk_norm=False,
rel_pos_cls=None,
init_values=None,
drop=0.,
attn_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
):
super().__init__()
self.init_values = init_values
self.attn = RelPosAttention(
dim,
num_heads,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
attn_drop=attn_drop,
proj_drop=drop,
)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=drop,
)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -144,6 +225,7 @@ class VisionTransformerRelPos(nn.Module):
num_heads=12,
mlp_ratio=4.,
qkv_bias=True,
qk_norm=False,
init_values=1e-6,
class_token=False,
fc_norm=False,
@ -171,6 +253,7 @@ class VisionTransformerRelPos(nn.Module):
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
qk_norm (bool): Enable normalization of query and key in attention
init_values: (float): layer-scale init values
class_token (bool): use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool
@ -197,18 +280,19 @@ class VisionTransformerRelPos(nn.Module):
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
)
feat_size = self.patch_embed.grid_size
rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
if rel_pos_type.startswith('mlp'):
if rel_pos_dim:
rel_pos_args['hidden_dim'] = rel_pos_dim
# FIXME experimenting with different relpos log coord configs
if 'swin' in rel_pos_type:
rel_pos_args['mode'] = 'swin'
elif 'rw' in rel_pos_type:
rel_pos_args['mode'] = 'rw'
rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
else:
rel_pos_cls = partial(RelPosBias, **rel_pos_args)
@ -223,9 +307,19 @@ class VisionTransformerRelPos(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
self.blocks = nn.ModuleList([
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_norm=qk_norm,
rel_pos_cls=rel_pos_cls,
init_values=init_values,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()

@ -0,0 +1,226 @@
""" Lion Optimizer
Paper: `Symbolic Discovery of Optimization Algorithms` - https://arxiv.org/abs/2302.06675
Original Impl: https://github.com/google/automl/tree/master/lion
"""
# Copyright 2023 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import List
import torch
from torch.optim.optimizer import Optimizer
class Lion(Optimizer):
r"""Implements Lion algorithm."""
def __init__(
self,
params,
lr=1e-4,
betas=(0.9, 0.99),
weight_decay=0.0,
maximize=False,
foreach=None,
):
"""Initialize the hyperparameters.
Args:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-4)
betas (Tuple[float, float], optional): coefficients used for computing
the update direction and the running average of the gradient (default: (0.9, 0.99))
weight_decay (float, optional): weight decay coefficient (default: 0)
"""
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= betas[0] < 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
defaults = dict(
lr=lr,
betas=betas,
weight_decay=weight_decay,
foreach=foreach,
maximize=maximize,
)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('maximize', False)
group.setdefault('foreach', None)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
Returns:
the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
beta1, beta2 = group['betas']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('Lion does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
exp_avgs.append(state['exp_avg'])
lion(
params_with_grad,
grads,
exp_avgs,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
maximize=group['maximize'],
foreach=group['foreach'],
)
return loss
def lion(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
maximize: bool = False,
foreach: bool = None,
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
):
r"""Functional API that performs Lion algorithm computation.
"""
if foreach is None:
# Placeholder for more complex foreach logic to be added when value is not set
foreach = False
if foreach and torch.jit.is_scripting():
raise RuntimeError('torch.jit.script not supported with foreach optimizers')
if foreach and not torch.jit.is_scripting():
func = _multi_tensor_lion
else:
func = _single_tensor_lion
func(
params,
grads,
exp_avgs,
beta1=beta1,
beta2=beta2,
lr=lr,
weight_decay=weight_decay,
maximize=maximize,
)
def _single_tensor_lion(
params: List[torch.Tensor],
grads: List[torch.Tensor],
exp_avgs: List[torch.Tensor],
*,
beta1: float,
beta2: float,
lr: float,
weight_decay: float,
maximize: bool,
):
for i, param in enumerate(params):
grad = grads[i] if not maximize else -grads[i]
exp_avg = exp_avgs[i]
if torch.is_complex(param):
grad = torch.view_as_real(grad)
exp_avg = torch.view_as_real(exp_avg)
param = torch.view_as_real(param)
# Perform stepweight decay
param.mul_(1 - lr * weight_decay)
# Weight update
update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
param.add_(torch.sign(update), alpha=-lr)
# Decay the momentum running average coefficient
exp_avg.lerp_(grad, 1 - beta2)


def _multi_tensor_lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
        maximize: bool,
):
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]

    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]

    # Perform stepweight decay
    torch._foreach_mul_(params, 1 - lr * weight_decay)

    # Weight update
    updates = torch._foreach_mul(exp_avgs, beta1)
    torch._foreach_add_(updates, grads, alpha=1 - beta1)

    updates = [u.sign() for u in updates]
    torch._foreach_add_(params, updates, alpha=-lr)

    # Decay the momentum running average coefficient
    torch._foreach_mul_(exp_avgs, beta2)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)
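A minimal usage sketch, not part of the diff: it assumes this version of `timm` is installed and that `Lion` is exported from `timm.optim` (see the `__init__.py` hunk below); the class otherwise behaves like any `torch.optim` optimizer.

```python
# Sketch only: assumes timm with this change installed; Lion is exported from timm.optim.
import torch
import torch.nn as nn
from timm.optim import Lion

model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
# Lion is typically run with a smaller lr and larger weight decay than AdamW.
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01, foreach=True)

x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))

optimizer.zero_grad()
loss = criterion(model(x), y)
loss.backward()
optimizer.step()  # dispatches to _multi_tensor_lion because foreach=True
```

Because the step is `sign`-based, each weight element moves by ±`lr` (or 0) before weight decay, which is why the learning rate alone controls the step size.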

@@ -18,6 +18,7 @@ from .adamp import AdamP
 from .adan import Adan
 from .lamb import Lamb
 from .lars import Lars
+from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
 from .nadam import Nadam
@@ -35,6 +36,12 @@ except ImportError:
 _logger = logging.getLogger(__name__)
 
 
+# optimizers to default to multi-tensor
+_DEFAULT_FOREACH = {
+    'lion',
+}
+
+
 def param_groups_weight_decay(
         model: nn.Module,
         weight_decay=1e-5,
@@ -161,7 +168,8 @@ def optimizer_kwargs(cfg):
         opt=cfg.opt,
         lr=cfg.lr,
         weight_decay=cfg.weight_decay,
-        momentum=cfg.momentum)
+        momentum=cfg.momentum,
+    )
     if getattr(cfg, 'opt_eps', None) is not None:
         kwargs['eps'] = cfg.opt_eps
     if getattr(cfg, 'opt_betas', None) is not None:
@@ -170,6 +178,8 @@ def optimizer_kwargs(cfg):
         kwargs['layer_decay'] = cfg.layer_decay
     if getattr(cfg, 'opt_args', None) is not None:
         kwargs.update(cfg.opt_args)
+    if getattr(cfg, 'opt_foreach', None) is not None:
+        kwargs['foreach'] = cfg.opt_foreach
     return kwargs
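For context, a sketch of how a train-script config flows through `optimizer_kwargs` with the new flag; the `timm.optim.optim_factory` import path and the `SimpleNamespace` stand-in for argparse args are assumptions for illustration.

```python
# Illustration only: SimpleNamespace stands in for the argparse namespace a train script builds.
from types import SimpleNamespace
from timm.optim.optim_factory import optimizer_kwargs  # import path assumed

cfg = SimpleNamespace(
    opt='lion', lr=1e-4, weight_decay=0.05, momentum=0.9,
    opt_eps=None, opt_betas=None, layer_decay=None, opt_args=None,
    opt_foreach=True,  # the new flag picked up by the hunk above
)
kwargs = optimizer_kwargs(cfg)
# kwargs now carries 'foreach': True alongside opt / lr / weight_decay / momentum,
# ready to be passed on as create_optimizer_v2(model, **kwargs).
```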
@@ -190,6 +200,7 @@ def create_optimizer_v2(
         lr: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
+        foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
         param_group_fn: Optional[Callable] = None,
@@ -208,6 +219,7 @@ def create_optimizer_v2(
         lr: initial learning rate
         weight_decay: weight decay to apply in optimizer
         momentum: momentum for momentum based optimizers (others may use betas via kwargs)
+        foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
         filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
         **kwargs: extra optimizer specific kwargs to pass through
@@ -227,7 +239,8 @@ def create_optimizer_v2(
                 model_or_params,
                 weight_decay=weight_decay,
                 layer_decay=layer_decay,
-                no_weight_decay_list=no_weight_decay)
+                no_weight_decay_list=no_weight_decay,
+            )
             weight_decay = 0.
         elif weight_decay and filter_bias_and_bn:
             parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
@@ -245,9 +258,16 @@ def create_optimizer_v2(
         assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
 
     opt_args = dict(weight_decay=weight_decay, **kwargs)
     if lr is not None:
         opt_args.setdefault('lr', lr)
+
+    if foreach is None:
+        if opt in _DEFAULT_FOREACH:
+            opt_args.setdefault('foreach', True)
+    else:
+        opt_args['foreach'] = foreach
+
     # basic SGD & related
     if opt_lower == 'sgd' or opt_lower == 'nesterov':
         # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
@@ -313,6 +333,8 @@ def create_optimizer_v2(
         optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
     elif opt_lower == 'rmsproptf':
         optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
+    elif opt_lower == 'lion':
+        optimizer = Lion(parameters, **opt_args)
 
     # second order
     elif opt_lower == 'adahessian':
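Tying the factory changes together, a small sketch, again not part of the diff and assuming `create_optimizer_v2` is importable from `timm.optim`, of how the new `'lion'` branch and the `foreach` plumbing are reached:

```python
# Sketch only: exercises the new 'lion' branch and the foreach handling in create_optimizer_v2.
import torch.nn as nn
from timm.optim import create_optimizer_v2  # import path assumed

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

# 'lion' is in _DEFAULT_FOREACH, so leaving foreach=None defaults it to the multi-tensor path.
opt = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.05)

# Explicitly force the single-tensor loop (e.g. when the functional path must stay scriptable).
opt_single = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.05, foreach=False)
```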

@@ -1 +1 @@
-__version__ = '0.8.9dev0'
+__version__ = '0.8.13dev0'
