Compare commits


@@ -1,112 +0,0 @@
*This guideline is very much a work-in-progress.*
Contributions to `timm` for code, documentation, and tests are more than welcome!
There haven't been any formal guidelines to date, so please bear with me, and feel free to add to this guide.
# Coding style
Code linting and auto-formatting (black) are not currently in place, but are open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.
A few specific differences from Google style (or black):
1. Line length is 120 chars. Going over is okay in some cases (e.g. I prefer not to break URLs across lines).
2. Hanging indents are always preferred; please avoid aligning arguments with closing brackets or braces.
An example from the Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
                         var_three, var_four)
meal = (spam,
        beans)

# Aligned with opening delimiter in a dictionary.
foo = {
    'long_dictionary_key': value1 +
                           value2,
    ...
}
```
This is YES:
```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
    var_one, var_two, var_three,
    var_four
)
meal = (
    spam,
    beans,
)

# 4-space hanging indent in a dictionary.
foo = {
    'long_dictionary_key':
        long_dictionary_value,
    ...
}
```
When there is a discrepancy in a given source file (the various bits of code have many origins and not all have been updated to what I consider the current goal), please follow the style in that file.
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
```
black --skip-string-normalization --line-length 120 <path-to-file>
```
Avoid formatting code that is unrelated to your PR though.
PRs with pure formatting / style fixes will be accepted, but only in isolation from functional changes; it's best to ask before starting such a change.
# Documentation
As with code style, docstring style is based on the Google guide: https://google.github.io/styleguide/pyguide.html
The goal is for all major functions and `__init__` methods to eventually use PEP 484 type annotations.
When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings; please leave annotations as the one source of truth for typing.
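For example, a function documented in this style (a hypothetical helper, not from the codebase) keeps the types in the signature only:
```
import torch


def count_parameters(model: torch.nn.Module, trainable_only: bool = False) -> int:
    """Count the parameters of a model.

    Args:
        model: The model to inspect.
        trainable_only: If True, count only parameters with requires_grad set.

    Returns:
        Total number of parameter elements.
    """
    return sum(
        p.numel() for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
```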
There are a LOT of gaps in the current documentation relative to the functionality in `timm`, so please, document away!
# Installation
Create a Python virtual environment using Python 3.10. Inside the environment, install the following test dependencies:
```
python -m pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
```
Install `torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests
Run the tests using:
```
pytest tests/
```
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
# Questions
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

@@ -24,10 +24,6 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Feb 20, 2023
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_large_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
### Feb 16, 2023
* `safetensors` checkpoint support added
* Add ideas from 'Scaling Vision Transformers to 22 Billion Parameters' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
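A minimal sketch of what the new checkpoint support enables (assuming `safetensors` is installed):
```
import torch
import safetensors.torch

# Save a state_dict without pickle and load it back onto CPU.
state_dict = {'weight': torch.zeros(4, 4)}
safetensors.torch.save_file(state_dict, 'model.safetensors')
restored = safetensors.torch.load_file('model.safetensors', device='cpu')
```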

@@ -17,14 +17,10 @@ import os
import glob
import hashlib
from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
DEFAULT_OUTPUT = "./average.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
@@ -42,7 +38,6 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {}
@@ -68,20 +63,14 @@ def main():
if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
if args.safetensors and not args.output.endswith(".safetensors"):
print(
"Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}"
)
if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(output))
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
pattern = args.input
@@ -98,27 +87,22 @@ def main():
checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:]
if checkpoint_metrics:
print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics]
else:
avg_checkpoints = checkpoints
if avg_checkpoints:
print("Selected checkpoints:")
[print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {}
avg_counts = {}
for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict:
print(f"Error: Checkpoint ({c}) doesn't exist")
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
continue
for k, v in new_state_dict.items():
if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@@ -138,14 +122,16 @@ def main():
final_state_dict[k] = v.to(dtype=torch.float32)
if args.safetensors:
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(final_state_dict, output)
safetensors.torch.save_file(final_state_dict, args.output)
else:
torch.save(final_state_dict, output)
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(output, 'rb') as f:
with open(args.output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
if __name__ == '__main__':
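The core averaging logic above, as a minimal self-contained sketch (assuming all checkpoints share the same keys): accumulate in float64, then cast back to float32.
```
import torch

def average_state_dicts(state_dicts):
    avg = {}
    for sd in state_dicts:
        for k, v in sd.items():
            if k not in avg:
                # Accumulate in float64 for numerical stability.
                avg[k] = v.clone().to(dtype=torch.float64)
            else:
                avg[k] += v.to(dtype=torch.float64)
    n = len(state_dicts)
    return {k: (v / n).to(dtype=torch.float32) for k, v in avg.items()}
```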

@@ -11,14 +11,9 @@ import torch
import argparse
import os
import hashlib
import safetensors.torch
import shutil
import tempfile
from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -27,13 +22,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
def main():
args = parser.parse_args()
@@ -42,24 +37,10 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -74,36 +55,25 @@ def clean_checkpoint(
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))
ext = ''
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base, ext = os.path.splitext(checkpoint_base)
else:
checkpoint_root = ''
checkpoint_base = os.path.split(checkpoint)[1]
checkpoint_base = os.path.splitext(checkpoint_base)[0]
temp_filename = '__' + checkpoint_base
if safe_serialization:
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
else:
torch.save(new_state_dict, temp_filename)
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(temp_filename, 'rb') as f:
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if ext:
final_ext = ext
else:
final_ext = ('.safetensors' if safe_serialization else '.pth')
if no_hash:
final_filename = checkpoint_base + final_ext
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base = os.path.splitext(checkpoint_base)[0]
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename
else:
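For reference, the hash-suffix naming used above in isolation (filenames here are hypothetical):
```
import hashlib

with open('cleaned.pth', 'rb') as f:
    sha_hash = hashlib.sha256(f.read()).hexdigest()
# First 8 hex chars of the file's SHA-256, e.g. 'cleaned-a1b2c3d4.pth'
final_filename = '-'.join(['cleaned', sha_hash[:8]]) + '.pth'
```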

@@ -0,0 +1,2 @@
model-index==0.1.10
jinja2==2.11.3

@@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
setup(
name='timm',
version=__version__,
description='PyTorch Image Models',
description='(Unofficial) PyTorch Image Models',
long_description=long_description,
long_description_content_type='text/markdown',
url='https://github.com/huggingface/pytorch-image-models',
url='https://github.com/rwightman/pytorch-image-models',
author='Ross Wightman',
author_email='ross@huggingface.co',
author_email='hello@rwightman.com',
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
@@ -29,11 +29,11 @@ setup(
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
'License :: OSI Approved :: Apache Software License',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Topic :: Scientific/Engineering',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
@@ -45,7 +45,7 @@ setup(
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
packages=find_packages(exclude=['convert', 'tests', 'results']),
include_package_data=True,
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
python_requires='>=3.7',
install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
python_requires='>=3.6',
)

@@ -88,6 +88,7 @@ class IterableImageDataset(data.IterableDataset):
root,
reader=None,
split='train',
class_map=None,
is_training=False,
batch_size=None,
seed=42,
@@ -102,6 +103,7 @@ class IterableImageDataset(data.IterableDataset):
reader,
root=root,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
seed=seed,

@@ -157,6 +157,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
download=download,
batch_size=batch_size,
@@ -169,6 +170,7 @@ def create_dataset(
root,
reader=name,
split=split,
class_map=class_map,
is_training=is_training,
batch_size=batch_size,
repeats=repeats,

@@ -34,6 +34,7 @@ except ImportError as e:
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@@ -94,6 +95,7 @@ class ReaderTfds(Reader):
root,
name,
split='train',
class_map=None,
is_training=False,
batch_size=None,
download=False,
@@ -151,6 +153,11 @@ class ReaderTfds(Reader):
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable the download flag
if download:
self.builder.download_and_prepare()
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
self.split_info = self.builder.info.splits[split]
self.num_samples = self.split_info.num_examples
@@ -299,6 +306,8 @@ class ReaderTfds(Reader):
target_data = sample[self.target_name]
if self.target_img_mode:
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:

@@ -29,6 +29,7 @@ except ImportError:
wds = None
expand_urls = None
from .class_map import load_class_map
from .reader import Reader
from .shared_count import SharedCount
@@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
info_yaml = os.path.join(root, basename + '.yaml')
err_str = ''
try:
with wds.gopen.gopen(info_json) as f:
with wds.gopen(info_json) as f:
info_dict = json.load(f)
return info_dict
except Exception as e:
err_str = str(e)
try:
with wds.gopen.gopen(info_yaml) as f:
with wds.gopen(info_yaml) as f:
info_dict = yaml.safe_load(f)
return info_dict
except Exception:
@@ -110,8 +111,8 @@ def _parse_split_info(split: str, info: Dict):
filenames=split_filenames,
)
else:
if split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
if 'splits' not in info or split not in info['splits']:
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
split = split
split_info = info['splits'][split]
split_info = _info_convert(split_info)
@@ -290,6 +291,7 @@ class ReaderWds(Reader):
batch_size=None,
repeats=0,
seed=42,
class_map=None,
input_name='jpg',
input_image='RGB',
target_name='cls',
@@ -320,6 +322,12 @@ class ReaderWds(Reader):
self.num_samples = self.split_info.num_samples
if not self.num_samples:
raise RuntimeError(f'Invalid split definition, no samples found.')
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = {}
# Distributed world state
self.dist_rank = 0
@@ -431,7 +439,10 @@ class ReaderWds(Reader):
i = 0
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
for sample in ds:
yield sample[self.image_key], sample[self.target_key]
target = sample[self.target_key]
if self.remap_class:
target = self.class_to_idx[target]
yield sample[self.image_key], target
i += 1
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
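A sketch of how the new `class_map` plumbing is used from the public API after this change (dataset name and paths are hypothetical); the map file lists one class name per line, and a raw target is remapped to its line index:
```
from timm.data import create_dataset

ds = create_dataset(
    'wds/my_dataset',           # hypothetical webdataset name
    root='path/to/shards',      # hypothetical shard location
    split='train',
    class_map='class_map.txt',  # raw targets remapped through this file
)
```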

@@ -17,6 +17,7 @@ from .edgenext import *
from .efficientformer import *
from .efficientformer_v2 import *
from .efficientnet import *
from .focalnet import *
from .gcvit import *
from .ghostnet import *
from .gluon_resnet import *

@@ -7,11 +7,7 @@ import os
from collections import OrderedDict
import torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder
@@ -33,7 +29,6 @@ def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')

@@ -7,21 +7,15 @@ from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8):
from typing import Literal
else:
@@ -51,7 +45,6 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
@@ -171,28 +164,21 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
model_config = model_config or {}
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -234,7 +220,7 @@ def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
safe_serialization: Union[bool, Literal["both"]] = False
):
assert has_hf_hub(True)
save_directory = Path(save_directory)
@@ -243,7 +229,6 @@ def save_for_hf(
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@@ -262,7 +247,7 @@ def push_to_hf_hub(
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
safe_serialization: Union[bool, Literal["both"]] = False
):
"""
Arguments:
@@ -356,7 +341,6 @@ def generate_readme(model_card: dict, model_name: str):
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
@@ -366,5 +350,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
return filename[:-4] + ".safetensors"
if filename.endswith(".bin"):
yield filename[:-4] + ".safetensors"

@@ -93,7 +93,7 @@ class DefaultCfg:
return tag, self.cfgs[tag]
def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
def split_model_name_tag(model_name: str, no_tag=''):
model_name, *tag_list = model_name.split('.', 1)
tag = tag_list[0] if tag_list else no_tag
return model_name, tag
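For reference, the split behaviour (the first example uses a model name from this changeset):
```
split_model_name_tag('convnext_large_mlp.clip_laion2b_ft_320')
# -> ('convnext_large_mlp', 'clip_laion2b_ft_320')
split_model_name_tag('resnet50')
# -> ('resnet50', '')
```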

@@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@@ -16,20 +16,20 @@ __all__ = [
'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']
_module_to_models: Dict[str, Set[str]] = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module: Dict[str, str] = {} # mapping of model names to module names
_model_entrypoints: Dict[str, Callable[..., Any]] = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained: Set[str] = set() # set of model names that have pretrained weight url present
_model_default_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch -> default cfg objects
_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {} # central repo for model arch.tag -> pretrained cfgs
_model_with_tags: Dict[str, List[str]] = defaultdict(list) # shortcut to map each model arch to all model + tag names
_module_to_models = defaultdict(set) # dict of sets to check membership of model in module
_model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
def get_arch_name(model_name: str) -> str:
def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
return split_model_name_tag(model_name)[0]
def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
def register_model(fn):
# lookup containing module
mod = sys.modules[fn.__module__]
module_name_split = fn.__module__.split('.')
@@ -40,7 +40,7 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
if hasattr(mod, '__all__'):
mod.__all__.append(model_name)
else:
mod.__all__ = [model_name] # type: ignore
mod.__all__ = [model_name]
# add entries to registry dict/sets
_model_entrypoints[model_name] = fn
@@ -87,33 +87,28 @@ def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
return fn
def _natural_key(string_: str) -> List[Union[int, str]]:
"""See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(
filter: Union[str, List[str]] = '',
module: str = '',
pretrained: bool = False,
exclude_filters: Union[str, List[str]] = '',
pretrained=False,
exclude_filters: str = '',
name_matches_cfg: bool = False,
include_tags: Optional[bool] = None,
) -> List[str]:
):
""" Return list of available model names, sorted alphabetically
Args:
filter - Wildcard filter string that works with fnmatch
module - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained - Include only models with valid pretrained weights if True
exclude_filters - Wildcard filters to exclude models after including them with filter
name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags - Include pretrained tags in model names (model.tag). If None, defaults
filter (str) - Wildcard filter string that works with fnmatch
module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
pretrained (bool) - Include only models with valid pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
include_tags (Optional[bool]) - Include pretrained tags in model names (model.tag). If None, defaults
set to True when pretrained=True else False (default: None)
Returns:
models - The sorted list of models
Example:
list_models('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
list_models('*resnext*', 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -123,7 +118,7 @@ def list_models(
include_tags = pretrained
if module:
all_models: Iterable[str] = list(_module_to_models[module])
all_models = list(_module_to_models[module])
else:
all_models = _model_entrypoints.keys()
@@ -135,14 +130,14 @@ def list_models(
all_models = models_with_tags
if filter:
models: Set[str] = set()
models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = models.union(include_models)
models = set(models).union(include_models)
else:
models = set(all_models)
models = all_models
if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)):
@@ -150,7 +145,7 @@ def list_models(
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
if len(exclude_models):
models = models.difference(exclude_models)
models = set(models).difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
@@ -158,13 +153,13 @@ def list_models(
if name_matches_cfg:
models = set(_model_pretrained_cfgs).intersection(models)
return sorted(models, key=_natural_key)
return list(sorted(models, key=_natural_key))
def list_pretrained(
filter: Union[str, List[str]] = '',
exclude_filters: str = '',
) -> List[str]:
):
return list_models(
filter=filter,
pretrained=True,
@@ -173,14 +168,14 @@
)
def is_model(model_name: str) -> bool:
def is_model(model_name):
""" Check if a model name exists
"""
arch_name = get_arch_name(model_name)
return arch_name in _model_entrypoints
def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
def model_entrypoint(model_name, module_filter: Optional[str] = None):
"""Fetch a model entrypoint for specified model name
"""
arch_name = get_arch_name(model_name)
@@ -189,32 +184,29 @@ def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Ca
return _model_entrypoints[arch_name]
def list_modules() -> List[str]:
def list_modules():
""" Return list of module names that contain models / model entrypoints
"""
modules = _module_to_models.keys()
return sorted(modules)
return list(sorted(modules))
def is_model_in_modules(
model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
) -> bool:
def is_model_in_modules(model_name, module_names):
"""Check if a model exists within a subset of modules
Args:
model_name - name of model to check
module_names - names of modules to search in
model_name (str) - name of model to check
module_names (tuple, list, set) - names of modules to search in
"""
arch_name = get_arch_name(model_name)
assert isinstance(module_names, (tuple, list, set))
return any(arch_name in _module_to_models[n] for n in module_names)
def is_model_pretrained(model_name: str) -> bool:
def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
def get_pretrained_cfg(model_name, allow_unregistered=True):
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
arch_name, tag = split_model_name_tag(model_name)
@@ -227,7 +219,7 @@ def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Opti
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
def get_pretrained_cfg_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
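A short usage sketch of the registry API touched above (assuming an installed `timm`):
```
import timm

timm.list_models('convnext*')                   # all ConvNeXt architectures
timm.list_models('convnext*', pretrained=True)  # only those with weights, tags included
timm.is_model('focalnet_tiny_srf')              # True once focalnet is registered
```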

@@ -773,16 +773,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
})

@@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
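For context, a minimal round-trip sketch (square 4x4 windows on an 8x8 map) showing that window partition followed by the shape-inferring reverse above is an identity:
```
import torch

def window_partition(x, ws):
    # (B, H, W, C) -> (num_windows * B, ws, ws, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, ws, ws, C)

x = torch.randn(2, 8, 8, 3)
windows = window_partition(x, 4)
C = windows.shape[-1]
y = windows.view(-1, 8 // 4, 8 // 4, 4, 4, C)
y = y.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, 8, 8, C)
assert torch.equal(x, y)
```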

@@ -0,0 +1,637 @@
""" FocalNet
As described in `Focal Modulation Networks` - https://arxiv.org/abs/2203.11926
Significant modifications and refactoring from the original impl at https://github.com/microsoft/FocalNet
This impl is/has:
* fully convolutional, NCHW tensor layout throughout, seemed to have minimal performance impact but more flexible
* re-ordered downsample / layer so that striding always at beginning of layer (stage)
* no input size constraints or input resolution/H/W tracking through the model
* torchscript fixed and a number of quirks cleaned up
* feature extraction support via `features_only=True`
"""
# --------------------------------------------------------
# FocalNets -- Focal Modulation Networks
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Jianwei Yang (jianwyan@microsoft.com)
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, DropPath, LayerNorm2d, trunc_normal_, ClassifierHead, NormMlpClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply
from ._registry import register_model
__all__ = ['FocalNet']
class FocalModulation(nn.Module):
def __init__(
self,
dim,
focal_window,
focal_level,
focal_factor=2,
bias=True,
use_post_norm=False,
normalize_modulator=False,
proj_drop=0.,
norm_layer=LayerNorm2d,
):
super().__init__()
self.dim = dim
self.focal_window = focal_window
self.focal_level = focal_level
self.focal_factor = focal_factor
self.use_post_norm = use_post_norm
self.normalize_modulator = normalize_modulator
self.input_split = [dim, dim, self.focal_level + 1]
self.f = nn.Conv2d(dim, 2 * dim + (self.focal_level + 1), kernel_size=1, bias=bias)
self.h = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
self.act = nn.GELU()
self.proj = nn.Conv2d(dim, dim, kernel_size=1)
self.proj_drop = nn.Dropout(proj_drop)
self.focal_layers = nn.ModuleList()
self.kernel_sizes = []
for k in range(self.focal_level):
kernel_size = self.focal_factor * k + self.focal_window
self.focal_layers.append(nn.Sequential(
nn.Conv2d(dim, dim, kernel_size=kernel_size, groups=dim, padding=kernel_size // 2, bias=False),
nn.GELU(),
))
self.kernel_sizes.append(kernel_size)
self.norm = norm_layer(dim) if self.use_post_norm else nn.Identity()
def forward(self, x):
"""
Args:
x: input features with shape (B, C, H, W)
"""
C = x.shape[1]
# pre linear projection
x = self.f(x)
q, ctx, gates = torch.split(x, self.input_split, 1)
# context aggregation
ctx_all = 0
for l, focal_layer in enumerate(self.focal_layers):
ctx = focal_layer(ctx)
ctx_all = ctx_all + ctx * gates[:, l:l + 1]
ctx_global = self.act(ctx.mean((2, 3), keepdim=True))
ctx_all = ctx_all + ctx_global * gates[:, self.focal_level:]
# normalize context
if self.normalize_modulator:
ctx_all = ctx_all / (self.focal_level + 1)
# focal modulation
x_out = q * self.h(ctx_all)
x_out = self.norm(x_out)
# post linear projection
x_out = self.proj(x_out)
x_out = self.proj_drop(x_out)
return x_out
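# Shape sketch (hypothetical sizes, dim=4, focal_level=2): f(x) yields
# 2*dim + focal_level + 1 = 11 channels, split into q (4), ctx (4) and
# gates (3); the modulated output q * h(ctx_all) matches the input shape.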
class LayerScale2d(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
gamma = self.gamma.view(1, -1, 1, 1)
return x.mul_(gamma) if self.inplace else x * gamma
class FocalNetBlock(nn.Module):
r""" Focal Modulation Network Block.
Args:
dim (int): Number of input channels.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
proj_drop (float, optional): Dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): Number of focal levels.
focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value
use_post_norm (bool): Whether to use layernorm after modulation
"""
def __init__(
self,
dim,
mlp_ratio=4.,
focal_level=1,
focal_window=3,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
layerscale_value=1e-4,
proj_drop=0.,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=LayerNorm2d,
):
super().__init__()
self.dim = dim
self.mlp_ratio = mlp_ratio
self.focal_window = focal_window
self.focal_level = focal_level
self.use_post_norm = use_post_norm
self.norm1 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.modulation = FocalModulation(
dim,
focal_window=focal_window,
focal_level=self.focal_level,
use_post_norm=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
proj_drop=proj_drop,
norm_layer=norm_layer,
)
self.norm1_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.ls1 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) if not use_post_norm else nn.Identity()
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
drop=proj_drop,
use_conv=True,
)
self.norm2_post = norm_layer(dim) if use_post_norm else nn.Identity()
self.ls2 = LayerScale2d(dim, layerscale_value) if layerscale_value is not None else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
# Focal Modulation
x = self.norm1(x)
x = self.modulation(x)
x = self.norm1_post(x)
x = shortcut + self.drop_path1(self.ls1(x))
# FFN
x = x + self.drop_path2(self.ls2(self.norm2_post(self.mlp(self.norm2(x)))))
return x
class BasicLayer(nn.Module):
""" A basic Focal Transformer layer for one stage.
Args:
dim (int): Number of input channels.
depth (int): Number of blocks.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
drop (float, optional): Dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (bool): Downsample layer at start of the layer. Default: True
focal_level (int): Number of focal levels
focal_window (int): Focal window size at first focal level
layerscale_value (float): Initial layerscale value
use_post_norm (bool): Whether to use layer norm after modulation
"""
def __init__(
self,
dim,
out_dim,
depth,
mlp_ratio=4.,
downsample=True,
focal_level=1,
focal_window=1,
use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
layerscale_value=1e-4,
proj_drop=0.,
drop_path=0.,
norm_layer=LayerNorm2d,
):
super().__init__()
self.dim = dim
self.depth = depth
self.grad_checkpointing = False
if downsample:
self.downsample = Downsample(
in_chs=dim,
out_chs=out_dim,
stride=2,
overlap=use_overlap_down,
norm_layer=norm_layer,
)
else:
self.downsample = nn.Identity()
# build blocks
self.blocks = nn.ModuleList([
FocalNetBlock(
dim=out_dim,
mlp_ratio=mlp_ratio,
focal_level=focal_level,
focal_window=focal_window,
use_post_norm=use_post_norm,
use_post_norm_in_modulation=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
layerscale_value=layerscale_value,
proj_drop=proj_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
)
for i in range(depth)])
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
x = self.downsample(x)
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
return x
class Downsample(nn.Module):
r"""
Args:
in_chs (int): Number of input image channels
out_chs (int): Number of linear projection output channels
stride (int): Downsample stride. Default: 4.
norm_layer (nn.Module, optional): Normalization layer. Default: None
"""
def __init__(
self,
in_chs,
out_chs,
stride=4,
overlap=False,
norm_layer=None,
):
super().__init__()
self.stride = stride
padding = 0
kernel_size = stride
if overlap:
assert stride in (2, 4)
if stride == 4:
kernel_size, padding = 7, 2
elif stride == 2:
kernel_size, padding = 3, 1
self.proj = nn.Conv2d(in_chs, out_chs, kernel_size=kernel_size, stride=stride, padding=padding)
self.norm = norm_layer(out_chs) if norm_layer is not None else nn.Identity()
def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
class FocalNet(nn.Module):
r""" Focal Modulation Networks (FocalNets)
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Focal Transformer layer.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level.
Default: [1, 1, 1, 1]
focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
use_overlap_down (bool): Whether to use convolutional embedding.
use_post_norm (bool): Whether to use layernorm after modulation (it helps stabilize training of large models)
layerscale_value (float): Value for layer scale. Default: 1e-4
drop_rate (float): Dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
embed_dim=96,
depths=(2, 2, 6, 2),
mlp_ratio=4.,
focal_levels=(2, 2, 2, 2),
focal_windows=(3, 3, 3, 3),
use_overlap_down=False,
use_post_norm=False,
use_post_norm_in_modulation=False,
normalize_modulator=False,
head_hidden_size=None,
head_init_scale=1.0,
layerscale_value=None,
drop_rate=0.,
proj_drop_rate=0.,
drop_path_rate=0.1,
norm_layer=partial(LayerNorm2d, eps=1e-5),
**kwargs,
):
super().__init__()
self.num_layers = len(depths)
embed_dim = [embed_dim * (2 ** i) for i in range(self.num_layers)]
self.num_classes = num_classes
self.embed_dim = embed_dim
self.num_features = embed_dim[-1]
self.feature_info = []
self.stem = Downsample(
in_chs=in_chans,
out_chs=embed_dim[0],
overlap=use_overlap_down,
norm_layer=norm_layer,
)
in_dim = embed_dim[0]
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
layers = []
for i_layer in range(self.num_layers):
out_dim = embed_dim[i_layer]
layer = BasicLayer(
dim=in_dim,
out_dim=out_dim,
depth=depths[i_layer],
mlp_ratio=mlp_ratio,
downsample=i_layer > 0,
focal_level=focal_levels[i_layer],
focal_window=focal_windows[i_layer],
use_overlap_down=use_overlap_down,
use_post_norm=use_post_norm,
use_post_norm_in_modulation=use_post_norm_in_modulation,
normalize_modulator=normalize_modulator,
layerscale_value=layerscale_value,
proj_drop=proj_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
)
in_dim = out_dim
layers += [layer]
self.feature_info += [dict(num_chs=out_dim, reduction=4 * 2 ** i_layer, module=f'layers.{i_layer}')]
self.layers = nn.Sequential(*layers)
if head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=norm_layer,
)
else:
self.norm = norm_layer(self.num_features)
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=drop_rate
)
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def no_weight_decay(self):
return {''}
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for l in self.layers:
l.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.head.reset(num_classes, global_pool=global_pool)
def forward_features(self, x):
x = self.stem(x)
x = self.layers(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
if name and 'head.fc' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.proj', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = {
"focalnet_tiny_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_srf.pth'),
"focalnet_small_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_srf.pth'),
"focalnet_base_srf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_srf.pth'),
"focalnet_tiny_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_tiny_lrf.pth'),
"focalnet_small_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_small_lrf.pth'),
"focalnet_base_lrf": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_base_lrf.pth'),
"focalnet_large_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_large_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_large_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_xlarge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_xlarge_lrf_384_fl4.pth',
input_size=(3, 384, 384), crop_pct=1.0, num_classes=21842),
"focalnet_huge_fl3": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224.pth',
num_classes=0),
"focalnet_huge_fl4": _cfg(
url='https://projects4jw.blob.core.windows.net/focalnet/release/classification/focalnet_huge_lrf_224_fl4.pth',
num_classes=0),
}
def checkpoint_filter_fn(state_dict, model: FocalNet):
if 'stem.proj.weight' in state_dict:
return state_dict
import re
out_dict = {}
if 'model' in state_dict:
state_dict = state_dict['model']
dest_dict = model.state_dict()
for k, v in state_dict.items():
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
k = k.replace('patch_embed', 'stem')
k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k)
if 'norm' in k and k not in dest_dict:
k = re.sub(r'norm([0-9])', r'norm\1_post', k)
k = k.replace('ln.', 'norm.')
k = k.replace('head', 'head.fc')
if dest_dict[k].shape != v.shape:
v = v.reshape(dest_dict[k].shape)
out_dict[k] = v
return out_dict
def _create_focalnet(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
FocalNet, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
@register_model
def focalnet_tiny_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, **kwargs)
return _create_focalnet('focalnet_tiny_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_small_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, **kwargs)
return _create_focalnet('focalnet_small_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_base_srf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, **kwargs)
return _create_focalnet('focalnet_base_srf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_tiny_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 6, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_tiny_lrf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_small_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=96, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_small_lrf', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_base_lrf(pretrained=False, **kwargs):
model_kwargs = dict(depths=[2, 2, 18, 2], embed_dim=128, focal_levels=[3, 3, 3, 3], **kwargs)
return _create_focalnet('focalnet_base_lrf', pretrained=pretrained, **model_kwargs)
# FocalNet large+ models
@register_model
def focalnet_large_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_large_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=192, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_large_fl4', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_xlarge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[3, 3, 3, 3], focal_windows=[5] * 4,
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_xlarge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=256, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_xlarge_fl4', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_huge_fl3(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[3, 3, 3, 3], focal_windows=[3] * 4,
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl3', pretrained=pretrained, **model_kwargs)
@register_model
def focalnet_huge_fl4(pretrained=False, **kwargs):
model_kwargs = dict(
depths=[2, 2, 18, 2], embed_dim=352, focal_levels=[4, 4, 4, 4],
use_post_norm=True, use_post_norm_in_modulation=True, use_overlap_down=True, layerscale_value=1e-4, **kwargs)
return _create_focalnet('focalnet_huge_fl4', pretrained=pretrained, **model_kwargs)
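A hedged usage sketch for the new model family (assuming a timm build that includes it), exercising the `features_only` support mentioned in the file header:
```
import timm
import torch

model = timm.create_model('focalnet_tiny_srf', pretrained=False, features_only=True)
x = torch.randn(1, 3, 224, 224)
features = model(x)  # one NCHW feature map per stage
print([f.shape for f in features])
```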

@@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]):
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

@@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int):
Returns:
x: (B, H, W, C)
"""
C = windows.shape[-1]
x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
B = int(windows.shape[0] / (H * W / window_size / window_size))
x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

@@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C)
"""
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

@@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, i
x: (B, H, W, C)
"""
H, W = img_size
C = windows.shape[-1]
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x

@@ -1 +1 @@
__version__ = '0.8.13dev0'
__version__ = '0.8.12dev0'
