Compare commits


19 Commits

Author / SHA1 / Message / Date

Benjamin Bossan  a5b01ec04e  Add type annotations to _registry.py  (2 years ago)
Benjamin Bossan  c9406ce608  Some additions to the CONTRIBUTING guide (#1685)  (2 years ago)
Ross Wightman    a32c4eff69  Create CONTRIBUTING.md  (2 years ago)
Ross Wightman    a0772f03e0  Update README.md  (2 years ago)
Ross Wightman    47f1de9bec  Version bump  (2 years ago)
Ross Wightman    11f7b589e5  Update setup.py for huggingface changes.  (2 years ago)
Ross Wightman    4d9c3ae2fb  Add laion2b 320x320 ConvNeXt-Large CLIP weights  (2 years ago)
Ross Wightman    d0b45c9b4d  Make safetensor import option for now. Improve avg/clean checkpoints ext handling a bit (more consistent).  (2 years ago)
Ross Wightman    7d9e321b76  Improve tracing of window attn models with simpler reshape logic  (2 years ago)
Ross Wightman    a3c6685e20  Delete requirements-modelindex.txt  (2 years ago)
Ross Wightman    022403ce0a  Update README  (2 years ago)
Ross Wightman    2e38d53dca  Remove dead line  (2 years ago)
Ross Wightman    f77c04ff36  Torchscript fixes/hacks for rms_norm, refactor ParallelScalingBlock with manual combination of input projections, closer paper match  (2 years ago)
Ross Wightman    122621daef  Add Final annotation to attn_fas to avoid symbol lookup of new scaled_dot_product_attn fn on old PyTorch in jit  (2 years ago)
Ross Wightman    621e1b2182  Add ideas from 'Scaling ViT to 22-B Params', testing PyTorch 2.0 fused F.scaled_dot_product_attention impl in vit, vit_relpos, maxxvit / coatnet.  (2 years ago)
Ross Wightman    a3d528524a  Version 0.8.12dev0  (2 years ago)
testbot          a09d403c24  changed warning to info  (2 years ago)
testbot          8470e29541  Add support to load safetensors weights  (2 years ago)
Ross Wightman    f35d6ea57b  Add multi-tensor (foreach) version of Lion in style of upcoming PyTorch 2.0 optimizers  (2 years ago)

@@ -0,0 +1,112 @@
*This guideline is very much a work-in-progress.*

Contributions to `timm` for code, documentation, and tests are more than welcome!

There haven't been any formal guidelines to date, so please bear with me, and feel free to add to this guide.
# Coding style

Code linting and auto-formatting (black) are not currently in place, but are open to consideration. In the meantime, the style to follow is (mostly) aligned with Google's guide: https://google.github.io/styleguide/pyguide.html.

A few specific differences from Google style (or black):

1. Line length is 120 chars. Going over is okay in some cases (e.g. I prefer not to break URLs across lines).
2. Hanging indents are always preferred; please avoid aligning arguments with closing brackets or braces.
Example from the Google guide, but this is a NO here:
```
# Aligned with opening delimiter.
foo = long_function_name(var_one, var_two,
                         var_three, var_four)
meal = (spam,
        beans)

# Aligned with opening delimiter in a dictionary.
foo = {
    'long_dictionary_key': value1 +
                           value2,
    ...
}
```
This is YES:
```
# 4-space hanging indent; nothing on first line,
# closing parenthesis on a new line.
foo = long_function_name(
    var_one, var_two, var_three,
    var_four
)
meal = (
    spam,
    beans,
)

# 4-space hanging indent in a dictionary.
foo = {
    'long_dictionary_key':
        long_dictionary_value,
    ...
}
```
When there is a discrepancy in a given source file (there are many origins for various bits of code, and not all have been updated to what I consider the current goal), please follow the style in that file.
In general, if you add new code, formatting it with black using the following options should result in a style that is compatible with the rest of the code base:
```
black --skip-string-normalization --line-length 120 <path-to-file>
```
Avoid formatting code that is unrelated to your PR though.
PRs with pure formatting / style fixes will be accepted, but only in isolation from functional changes; best to ask before starting such a change.
# Documentation
As with code style, docstring style is based on the Google guide: https://google.github.io/styleguide/pyguide.html

The goal is to eventually have all major functions and `__init__` methods use PEP484 type annotations.

When type annotations are used for a function, as per the Google pyguide, they should **NOT** be duplicated in the docstrings; please leave annotations as the one source of truth for typing.
There are a LOT of gaps in current documentation relative to the functionality in timm, please, document away!
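For instance, a minimal sketch of the target pattern (a hypothetical helper; types live in the signature only, not the docstring):

```
from typing import List


def count_unique(model_names: List[str], lowercase: bool = True) -> int:
    """Count unique model names.

    Args:
        model_names: Model architecture names to de-duplicate.
        lowercase: Normalize names to lower case before comparing.

    Returns:
        The number of unique names.
    """
    names = [n.lower() for n in model_names] if lowercase else model_names
    return len(set(names))
```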
# Installation
Create a Python virtual environment using Python 3.10. Inside the environment, install the following test dependencies:
```
python -m pip install pytest pytest-timeout pytest-xdist pytest-forked expecttest
```
Install `torch` and `torchvision` using the instructions matching your system as listed on the [PyTorch website](https://pytorch.org/).
Then install the remaining dependencies:
```
python -m pip install -r requirements.txt
python -m pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git
python -m pip install -e .
```
## Unit tests
Run the tests using:
```
pytest tests/
```
Since the whole test suite takes a lot of time to run locally (a few hours), you may want to select a subset of tests relating to the changes you made by using the `-k` option of [`pytest`](https://docs.pytest.org/en/7.1.x/example/markers.html#using-k-expr-to-select-tests-based-on-their-name). Moreover, running tests in parallel (in this example 4 processes) with the `-n` option may help:
```
pytest -k "substring-to-match" -n 4 tests/
```
## Building documentation
Please refer to [this document](https://github.com/huggingface/pytorch-image-models/tree/main/hfdocs).
# Questions
If you have any questions about contribution, where / how to contribute, please ask in the [Discussions](https://github.com/huggingface/pytorch-image-models/discussions/categories/contributing) (there is a `Contributing` topic).

@@ -24,6 +24,17 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Feb 20, 2023
* Add 320x320 `convnext_large_mlp.clip_laion2b_ft_320` and `convnext_large_mlp.clip_laion2b_ft_soup_320` CLIP image tower weights for features & fine-tune
* 0.8.13dev0 pypi release for latest changes w/ move to huggingface org
### Feb 16, 2023
* `safetensor` checkpoint support added
* Add ideas from 'Scaling Vision Transformers to 22 B. Params' (https://arxiv.org/abs/2302.05442) -- qk norm, RmsNorm, parallel block
* Add F.scaled_dot_product_attention support (PyTorch 2.0 only) to `vit_*`, `vit_relpos*`, `coatnet` / `maxxvit` (to start); see the sketch after this list
* Lion optimizer (w/ multi-tensor option) added (https://arxiv.org/abs/2302.06675)
* gradient checkpointing works with `features_only=True`
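As a quick illustration of the fused attention op these blocks now dispatch to (plain PyTorch 2.0 API, shapes assumed):

```
import torch
import torch.nn.functional as F

# (batch, heads, seq_len, head_dim) projections, as produced inside a vit block
q = torch.randn(2, 8, 197, 64)
k = torch.randn(2, 8, 197, 64)
v = torch.randn(2, 8, 197, 64)

# PyTorch 2.0 dispatches to a fused (flash / memory-efficient) kernel when it can
out = F.scaled_dot_product_attention(q, k, v)
```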
### Feb 7, 2023
* New inference benchmark numbers added in [results](results/) folder.
* Add convnext LAION CLIP trained weights and initial set of in1k fine-tunes

@@ -17,20 +17,30 @@ import os
import glob
import hashlib
from timm.models import load_state_dict

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
+DEFAULT_OUTPUT = "./averaged.pth"
+DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
                    help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                    help='checkpoint filter (path wildcard)')
-parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
-                    help='output filename')
+parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
+                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                    help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
                    help='Number of checkpoints to average')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')


def checkpoint_metric(checkpoint_path):
@@ -55,8 +65,23 @@ def main():
    # by default sort by checkpoint metric (if present) and avg top n checkpoints
    args.sort = not args.no_sort

-    if os.path.exists(args.output):
-        print("Error: Output filename ({}) already exists.".format(args.output))
+    if args.safetensors and args.output == DEFAULT_OUTPUT:
+        # Default path changes if using safetensors
+        args.output = DEFAULT_SAFE_OUTPUT
+
+    output, output_ext = os.path.splitext(args.output)
+    if not output_ext:
+        output_ext = ('.safetensors' if args.safetensors else '.pth')
+    output = output + output_ext
+
+    if args.safetensors and not output_ext == ".safetensors":
+        print(
+            "Warning: saving weights as safetensors but output file extension is not "
+            f"set to '.safetensors': {args.output}"
+        )
+
+    if os.path.exists(output):
+        print("Error: Output filename ({}) already exists.".format(output))
        exit(1)

    pattern = args.input
@@ -73,22 +98,27 @@ def main():
            checkpoint_metrics.append((metric, c))
        checkpoint_metrics = list(sorted(checkpoint_metrics))
        checkpoint_metrics = checkpoint_metrics[-args.n:]
+        if checkpoint_metrics:
            print("Selected checkpoints:")
            [print(m, c) for m, c in checkpoint_metrics]
        avg_checkpoints = [c for m, c in checkpoint_metrics]
    else:
        avg_checkpoints = checkpoints
+        if avg_checkpoints:
            print("Selected checkpoints:")
            [print(c) for c in checkpoints]

+    if not avg_checkpoints:
+        print('Error: No checkpoints found to average.')
+        exit(1)

    avg_state_dict = {}
    avg_counts = {}
    for c in avg_checkpoints:
        new_state_dict = load_state_dict(c, args.use_ema)
        if not new_state_dict:
-            print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
+            print(f"Error: Checkpoint ({c}) doesn't exist")
            continue

        for k, v in new_state_dict.items():
            if k not in avg_state_dict:
                avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@@ -107,14 +137,15 @@ def main():
            v = v.clamp(float32_info.min, float32_info.max)
        final_state_dict[k] = v.to(dtype=torch.float32)

-    try:
-        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-    except:
-        torch.save(final_state_dict, args.output)
+    if args.safetensors:
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+        safetensors.torch.save_file(final_state_dict, output)
+    else:
+        torch.save(final_state_dict, output)

-    with open(args.output, 'rb') as f:
+    with open(output, 'rb') as f:
        sha_hash = hashlib.sha256(f.read()).hexdigest()
-    print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
+    print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")


if __name__ == '__main__':
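A hypothetical invocation of the updated averager (paths assumed), writing the result as safetensors:

```
python avg_checkpoints.py --input ./output/train --filter '*.pth.tar' -n 5 --safetensors
```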

@@ -12,8 +12,13 @@ import argparse
import os
import hashlib
import shutil
-from collections import OrderedDict
+import tempfile
from timm.models import load_state_dict

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -22,10 +27,12 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                    help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                    help='use ema version of weights if present')
+parser.add_argument('--no-hash', dest='no_hash', action='store_true',
+                    help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                    help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')

-_TEMP_NAME = './_checkpoint.pth'


def main():
@@ -35,10 +42,24 @@ def main():
        print("Error: Output filename ({}) already exists.".format(args.output))
        exit(1)

-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
+    clean_checkpoint(
+        args.checkpoint,
+        args.output,
+        not args.no_use_ema,
+        args.no_hash,
+        args.clean_aux_bn,
+        safe_serialization=args.safetensors,
+    )


-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
+def clean_checkpoint(
+        checkpoint,
+        output,
+        use_ema=True,
+        no_hash=False,
+        clean_aux_bn=False,
+        safe_serialization: bool = False,
+):
    # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
    if checkpoint and os.path.isfile(checkpoint):
        print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -53,22 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
            new_state_dict[name] = v
        print("=> Loaded state_dict from '{}'".format(checkpoint))

-        try:
-            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(new_state_dict, _TEMP_NAME)
-
-        with open(_TEMP_NAME, 'rb') as f:
-            sha_hash = hashlib.sha256(f.read()).hexdigest()
-
+        ext = ''
        if output:
            checkpoint_root, checkpoint_base = os.path.split(output)
-            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+            checkpoint_base, ext = os.path.splitext(checkpoint_base)
        else:
            checkpoint_root = ''
-            checkpoint_base = os.path.splitext(checkpoint)[0]
+            checkpoint_base = os.path.split(checkpoint)[1]
+            checkpoint_base = os.path.splitext(checkpoint_base)[0]

-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
-        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
+        temp_filename = '__' + checkpoint_base
+        if safe_serialization:
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+            safetensors.torch.save_file(new_state_dict, temp_filename)
+        else:
+            torch.save(new_state_dict, temp_filename)
+
+        with open(temp_filename, 'rb') as f:
+            sha_hash = hashlib.sha256(f.read()).hexdigest()
+
+        if ext:
+            final_ext = ext
+        else:
+            final_ext = ('.safetensors' if safe_serialization else '.pth')
+
+        if no_hash:
+            final_filename = checkpoint_base + final_ext
+        else:
+            final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
+
+        shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
        print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
        return final_filename
    else:
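A hypothetical invocation (paths assumed); `--no-hash` skips the SHA256 suffix and `--safetensors` switches the serialization format:

```
python clean_checkpoint.py --checkpoint ./output/train/model_best.pth.tar --output ./release/model --no-hash --safetensors
```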

@@ -1,2 +0,0 @@
-model-index==0.1.10
-jinja2==2.11.3

@@ -2,3 +2,4 @@ torch>=1.7
torchvision
pyyaml
huggingface_hub
+safetensors>=0.2

@@ -14,12 +14,12 @@ exec(open('timm/version.py').read())
setup(
    name='timm',
    version=__version__,
-    description='(Unofficial) PyTorch Image Models',
+    description='PyTorch Image Models',
    long_description=long_description,
    long_description_content_type='text/markdown',
-    url='https://github.com/rwightman/pytorch-image-models',
+    url='https://github.com/huggingface/pytorch-image-models',
    author='Ross Wightman',
-    author_email='hello@rwightman.com',
+    author_email='ross@huggingface.co',
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
@@ -29,11 +29,11 @@ setup(
        'Intended Audience :: Education',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: Apache Software License',
-        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        'Programming Language :: Python :: 3.10',
+        'Programming Language :: Python :: 3.11',
        'Topic :: Scientific/Engineering',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development',
@@ -45,7 +45,7 @@ setup(
    keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet resnet vision transformer vit',
    packages=find_packages(exclude=['convert', 'tests', 'results']),
    include_package_data=True,
-    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub'],
-    python_requires='>=3.6',
+    install_requires=['torch >= 1.7', 'torchvision', 'pyyaml', 'huggingface_hub', 'safetensors'],
+    python_requires='>=3.7',
)

@@ -7,6 +7,11 @@ import os
from collections import OrderedDict

import torch

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False

import timm.models._builder

@@ -26,7 +31,13 @@ def clean_state_dict(state_dict):

def load_state_dict(checkpoint_path, use_ema=True):
    if checkpoint_path and os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # Check if safetensors or not and load weights accordingly
+        if str(checkpoint_path).endswith(".safetensors"):
+            assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location='cpu')
        state_dict_key = ''
        if isinstance(checkpoint, dict):
            if use_ema and checkpoint.get('state_dict_ema', None) is not None:
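A minimal sketch of the resulting behavior (checkpoint paths assumed):

```
from timm.models import load_state_dict

# routed through safetensors.torch.load_file because of the file extension
state_dict = load_state_dict('averaged.safetensors')

# classic pickled checkpoints still go through torch.load
state_dict = load_state_dict('model_best.pth.tar', use_ema=True)
```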

@@ -2,10 +2,11 @@ import hashlib
import json
import logging
import os
+import sys
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Iterable, Optional, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse

@@ -15,6 +16,17 @@ try:
except ImportError:
    from torch.hub import _get_torch_home as get_dir

+try:
+    import safetensors.torch
+    _has_safetensors = True
+except ImportError:
+    _has_safetensors = False
+
+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal

from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg

@@ -35,6 +47,10 @@ _logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version


def get_cache_dir(child_dir=''):
    """
@@ -150,14 +166,33 @@ def load_model_config_from_hf(model_id: str):
    return pretrained_cfg, model_name


-def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
+def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, filename)
-    state_dict = torch.load(cached_file, map_location='cpu')
-    return state_dict
+    hf_model_id, hf_revision = hf_split(model_id)
+
+    # Look for .safetensors alternatives and load from it if it exists
+    if _has_safetensors:
+        for safe_filename in _get_safe_alternatives(filename):
+            try:
+                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+                _logger.info(
+                    f"[{model_id}] Safe alternative available for '{filename}' "
+                    f"(as '{safe_filename}'). Loading weights using safetensors.")
+                return safetensors.torch.load_file(cached_safe_file, device="cpu")
+            except EntryNotFoundError:
+                pass
+
+    # Otherwise, load using pytorch.load
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    return torch.load(cached_file, map_location='cpu')


-def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
+def save_config_for_hf(
+        model,
+        config_path: str,
+        model_config: Optional[dict] = None
+):
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@@ -195,13 +230,23 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
        json.dump(hf_config, f, indent=2)


-def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
+def save_for_hf(
+        model,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False,
+):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
+    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
+    tensors = model.state_dict()
+    if safe_serialization is True or safe_serialization == "both":
+        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
+        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
+    if safe_serialization is False or safe_serialization == "both":
+        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    config_path = save_directory / 'config.json'
    save_config_for_hf(model, config_path, model_config=model_config)
@@ -217,7 +262,15 @@ def push_to_hf_hub(
    create_pr: bool = False,
    model_config: Optional[dict] = None,
    model_card: Optional[dict] = None,
+    safe_serialization: Union[bool, Literal["both"]] = False,
):
+    """
+    Arguments:
+        (...)
+        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            Can be set to `"both"` in order to push both safe and unsafe weights.
+    """
    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@@ -236,7 +289,7 @@ def push_to_hf_hub(
    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config)
+        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)

        # Add readme if it does not exist
        if not has_readme:
@@ -302,3 +355,16 @@ def generate_readme(model_card: dict, model_name: str):
            for c in citations:
                readme_text += f"```bibtex\n{c}\n```\n"
    return readme_text


def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Returns potential safetensors alternatives for a given filename.

    Use case:
        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"
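A hedged sketch of pushing both weight formats from the new API (repo id hypothetical; requires a logged-in Hub token):

```
import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model('resnet18', pretrained=True)

# writes model.safetensors and pytorch_model.bin plus config.json, then uploads
push_to_hf_hub(model, 'my-user/resnet18-demo', safe_serialization='both')
```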

@@ -93,7 +93,7 @@ class DefaultCfg:
        return tag, self.cfgs[tag]


-def split_model_name_tag(model_name: str, no_tag=''):
+def split_model_name_tag(model_name: str, no_tag: str = '') -> Tuple[str, str]:
    model_name, *tag_list = model_name.split('.', 1)
    tag = tag_list[0] if tag_list else no_tag
    return model_name, tag

@@ -8,7 +8,7 @@ import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
-from typing import List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Sequence, Union, Tuple

from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@@ -16,20 +16,20 @@ __all__ = [
    'list_models', 'list_pretrained', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules',
    'get_pretrained_cfg_value', 'is_model_pretrained', 'get_arch_name']

-_module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
-_model_to_module = {}  # mapping of model names to module names
-_model_entrypoints = {}  # mapping of model names to architecture entrypoint fns
-_model_has_pretrained = set()  # set of model names that have pretrained weight url present
-_model_default_cfgs = dict()  # central repo for model arch -> default cfg objects
-_model_pretrained_cfgs = dict()  # central repo for model arch.tag -> pretrained cfgs
-_model_with_tags = defaultdict(list)  # shortcut to map each model arch to all model + tag names
+_module_to_models: Dict[str, Set[str]] = defaultdict(set)  # dict of sets to check membership of model in module
+_model_to_module: Dict[str, str] = {}  # mapping of model names to module names
+_model_entrypoints: Dict[str, Callable[..., Any]] = {}  # mapping of model names to architecture entrypoint fns
+_model_has_pretrained: Set[str] = set()  # set of model names that have pretrained weight url present
+_model_default_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch -> default cfg objects
+_model_pretrained_cfgs: Dict[str, PretrainedCfg] = {}  # central repo for model arch.tag -> pretrained cfgs
+_model_with_tags: Dict[str, List[str]] = defaultdict(list)  # shortcut to map each model arch to all model + tag names


-def get_arch_name(model_name: str) -> Tuple[str, Optional[str]]:
+def get_arch_name(model_name: str) -> str:
    return split_model_name_tag(model_name)[0]


-def register_model(fn):
+def register_model(fn: Callable[..., Any]) -> Callable[..., Any]:
    # lookup containing module
    mod = sys.modules[fn.__module__]
    module_name_split = fn.__module__.split('.')
@@ -40,7 +40,7 @@ def register_model(fn):
    if hasattr(mod, '__all__'):
        mod.__all__.append(model_name)
    else:
-        mod.__all__ = [model_name]
+        mod.__all__ = [model_name]  # type: ignore

    # add entries to registry dict/sets
    _model_entrypoints[model_name] = fn
@@ -87,28 +87,33 @@ def register_model(fn):
    return fn


-def _natural_key(string_):
+def _natural_key(string_: str) -> List[Union[int, str]]:
+    """See https://blog.codinghorror.com/sorting-for-humans-natural-sort-order/"""
    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


def list_models(
        filter: Union[str, List[str]] = '',
        module: str = '',
-        pretrained=False,
-        exclude_filters: str = '',
+        pretrained: bool = False,
+        exclude_filters: Union[str, List[str]] = '',
        name_matches_cfg: bool = False,
        include_tags: Optional[bool] = None,
-):
+) -> List[str]:
    """ Return list of available model names, sorted alphabetically

    Args:
-        filter (str) - Wildcard filter string that works with fnmatch
-        module (str) - Limit model selection to a specific submodule (ie 'vision_transformer')
-        pretrained (bool) - Include only models with valid pretrained weights if True
-        exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
-        name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
-        include_tags (Optional[boo]) - Include pretrained tags in model names (model.tag). If None, defaults
+        filter - Wildcard filter string that works with fnmatch
+        module - Limit model selection to a specific submodule (ie 'vision_transformer')
+        pretrained - Include only models with valid pretrained weights if True
+        exclude_filters - Wildcard filters to exclude models after including them with filter
+        name_matches_cfg - Include only models w/ model_name matching default_cfg name (excludes some aliases)
+        include_tags - Include pretrained tags in model names (model.tag). If None, defaults
            set to True when pretrained=True else False (default: None)

+    Returns:
+        models - The sorted list of models

    Example:
        model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
        model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
@@ -118,7 +123,7 @@ def list_models(
        include_tags = pretrained

    if module:
-        all_models = list(_module_to_models[module])
+        all_models: Iterable[str] = list(_module_to_models[module])
    else:
        all_models = _model_entrypoints.keys()
@@ -130,14 +135,14 @@ def list_models(
        all_models = models_with_tags

    if filter:
-        models = []
+        models: Set[str] = set()
        include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
        for f in include_filters:
            include_models = fnmatch.filter(all_models, f)  # include these models
            if len(include_models):
-                models = set(models).union(include_models)
+                models = models.union(include_models)
    else:
-        models = all_models
+        models = set(all_models)

    if exclude_filters:
        if not isinstance(exclude_filters, (tuple, list)):
@@ -145,7 +150,7 @@ def list_models(
        for xf in exclude_filters:
            exclude_models = fnmatch.filter(models, xf)  # exclude these models
            if len(exclude_models):
-                models = set(models).difference(exclude_models)
+                models = models.difference(exclude_models)

    if pretrained:
        models = _model_has_pretrained.intersection(models)
@@ -153,13 +158,13 @@ def list_models(
    if name_matches_cfg:
        models = set(_model_pretrained_cfgs).intersection(models)

-    return list(sorted(models, key=_natural_key))
+    return sorted(models, key=_natural_key)


def list_pretrained(
        filter: Union[str, List[str]] = '',
        exclude_filters: str = '',
-):
+) -> List[str]:
    return list_models(
        filter=filter,
        pretrained=True,
@@ -168,14 +173,14 @@ def list_pretrained(
    )


-def is_model(model_name):
+def is_model(model_name: str) -> bool:
    """ Check if a model name exists
    """
    arch_name = get_arch_name(model_name)
    return arch_name in _model_entrypoints


-def model_entrypoint(model_name, module_filter: Optional[str] = None):
+def model_entrypoint(model_name: str, module_filter: Optional[str] = None) -> Callable[..., Any]:
    """Fetch a model entrypoint for specified model name
    """
    arch_name = get_arch_name(model_name)
@@ -184,29 +189,32 @@ def model_entrypoint(model_name, module_filter: Optional[str] = None):
    return _model_entrypoints[arch_name]


-def list_modules():
+def list_modules() -> List[str]:
    """ Return list of module names that contain models / model entrypoints
    """
    modules = _module_to_models.keys()
-    return list(sorted(modules))
+    return sorted(modules)


-def is_model_in_modules(model_name, module_names):
+def is_model_in_modules(
+        model_name: str, module_names: Union[Tuple[str, ...], List[str], Set[str]]
+) -> bool:
    """Check if a model exists within a subset of modules

    Args:
-        model_name (str) - name of model to check
-        module_names (tuple, list, set) - names of modules to search in
+        model_name - name of model to check
+        module_names - names of modules to search in
    """
    arch_name = get_arch_name(model_name)
    assert isinstance(module_names, (tuple, list, set))
    return any(arch_name in _module_to_models[n] for n in module_names)


-def is_model_pretrained(model_name):
+def is_model_pretrained(model_name: str) -> bool:
    return model_name in _model_has_pretrained


-def get_pretrained_cfg(model_name, allow_unregistered=True):
+def get_pretrained_cfg(model_name: str, allow_unregistered: bool = True) -> Optional[PretrainedCfg]:
    if model_name in _model_pretrained_cfgs:
        return deepcopy(_model_pretrained_cfgs[model_name])
    arch_name, tag = split_model_name_tag(model_name)
@@ -219,7 +227,7 @@ def get_pretrained_cfg(model_name, allow_unregistered=True):
        raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')


-def get_pretrained_cfg_value(model_name, cfg_key):
+def get_pretrained_cfg_value(model_name: str, cfg_key: str) -> Optional[Any]:
    """ Get a specific model default_cfg value by key. None if key doesn't exist.
    """
    cfg = get_pretrained_cfg(model_name, allow_unregistered=False)

@@ -773,6 +773,16 @@ default_cfgs = generate_default_cfgs({
        hf_hub_filename='open_clip_pytorch_model.bin',
        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
+    'convnext_large_mlp.clip_laion2b_ft_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
+    'convnext_large_mlp.clip_laion2b_ft_soup_320': _cfg(
+        hf_hub_id='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=768),
})
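A hedged sketch of pulling one of the new towers for feature extraction (downloads weights from the HF Hub):

```
import torch
import timm

# num_classes=0 drops the CLIP projection head and returns pooled features
model = timm.create_model('convnext_large_mlp.clip_laion2b_ft_320', pretrained=True, num_classes=0)
model.eval()
with torch.no_grad():
    features = model(torch.randn(1, 3, 320, 320))
```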

@@ -217,9 +217,9 @@ def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
    Returns:
        x: (B, H, W, C)
    """
-    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
-    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    C = windows.shape[-1]
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
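Why this helps tracing, as a standalone sketch (toy shapes, square windows): the old code materialized a Python int for the batch dim from `windows.shape[0]`, which breaks when that shape is a tracing proxy; inferring it with `-1` keeps the reshape symbolic.

```
import torch

def window_reverse(windows: torch.Tensor, window_size: int, H: int, W: int) -> torch.Tensor:
    # batch dim is inferred via -1 instead of computed from windows.shape[0]
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)

windows = torch.randn(32, 7, 7, 96)       # 8 images x (2x2) windows of size 7, C=96
out = window_reverse(windows, 7, 14, 14)  # -> (8, 14, 14, 96)
assert out.shape == (8, 14, 14, 96)
```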

@@ -243,9 +243,9 @@ def window_partition(x, window_size: Tuple[int, int]):

@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    H, W = img_size
-    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
-    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    C = windows.shape[-1]
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

@@ -126,9 +126,9 @@ def window_reverse(windows, window_size: int, H: int, W: int):
    Returns:
        x: (B, H, W, C)
    """
-    B = int(windows.shape[0] / (H * W / window_size / window_size))
-    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    C = windows.shape[-1]
+    x = windows.view(-1, H // window_size, W // window_size, window_size, window_size, C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

@@ -120,9 +120,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
        x: (B, H, W, C)
    """
    H, W = img_size
-    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
-    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    C = windows.shape[-1]
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

@@ -139,9 +139,9 @@ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
        x: (B, H, W, C)
    """
    H, W = img_size
-    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
-    x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
-    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    C = windows.shape[-1]
+    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x

@@ -16,6 +16,8 @@ Original Impl: https://github.com/google/automl/tree/master/lion
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+from typing import List, Optional
+
import torch
from torch.optim.optimizer import Optimizer
@@ -23,7 +25,15 @@ from torch.optim.optimizer import Optimizer
class Lion(Optimizer):
    r"""Implements Lion algorithm."""

-    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
+    def __init__(
+            self,
+            params,
+            lr=1e-4,
+            betas=(0.9, 0.99),
+            weight_decay=0.0,
+            maximize=False,
+            foreach=None,
+    ):
        """Initialize the hyperparameters.

        Args:
@@ -41,9 +51,21 @@ class Lion(Optimizer):
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
-        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
+        defaults = dict(
+            lr=lr,
+            betas=betas,
+            weight_decay=weight_decay,
+            foreach=foreach,
+            maximize=maximize,
+        )
        super().__init__(params, defaults)

+    def __setstate__(self, state):
+        super().__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('maximize', False)
+            group.setdefault('foreach', None)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -61,27 +83,144 @@ class Lion(Optimizer):
            loss = closure()

        for group in self.param_groups:
+            params_with_grad = []
+            grads = []
+            exp_avgs = []
+            beta1, beta2 = group['betas']
+
            for p in group['params']:
                if p.grad is None:
                    continue
+                params_with_grad.append(p)
+                if p.grad.is_sparse:
+                    raise RuntimeError('Lion does not support sparse gradients')
+                grads.append(p.grad)

-                # Perform stepweight decay
-                p.data.mul_(1 - group['lr'] * group['weight_decay'])
-
-                grad = p.grad
                state = self.state[p]
                # State initialization
                if len(state) == 0:
-                    # Exponential moving average of gradient values
-                    state['exp_avg'] = torch.zeros_like(p)
-
-                exp_avg = state['exp_avg']
-                beta1, beta2 = group['betas']
-
-                # Weight update
-                update = exp_avg * beta1 + grad * (1 - beta1)
-                p.add_(torch.sign(update), alpha=-group['lr'])
-                # Decay the momentum running average coefficient
-                exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2)
+                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+                exp_avgs.append(state['exp_avg'])
+
+            lion(
+                params_with_grad,
+                grads,
+                exp_avgs,
+                beta1=beta1,
+                beta2=beta2,
+                lr=group['lr'],
+                weight_decay=group['weight_decay'],
+                maximize=group['maximize'],
+                foreach=group['foreach'],
+            )

        return loss


def lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        maximize: bool = False,
        foreach: Optional[bool] = None,
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
):
    r"""Functional API that performs Lion algorithm computation.
    """
    if foreach is None:
        # Placeholder for more complex foreach logic to be added when value is not set
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_lion
    else:
        func = _single_tensor_lion

    func(
        params,
        grads,
        exp_avgs,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
    )


def _single_tensor_lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
        maximize: bool,
):
    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            param = torch.view_as_real(param)

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Weight update
        update = exp_avg.mul(beta1).add_(grad, alpha=1 - beta1)
        param.add_(torch.sign(update), alpha=-lr)

        # Decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)


def _multi_tensor_lion(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        exp_avgs: List[torch.Tensor],
        *,
        beta1: float,
        beta2: float,
        lr: float,
        weight_decay: float,
        maximize: bool,
):
    if len(params) == 0:
        return

    if maximize:
        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]

    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]

    # Perform stepweight decay
    torch._foreach_mul_(params, 1 - lr * weight_decay)

    # Weight update
    updates = torch._foreach_mul(exp_avgs, beta1)
    torch._foreach_add_(updates, grads, alpha=1 - beta1)
    updates = [u.sign() for u in updates]
    torch._foreach_add_(params, updates, alpha=-lr)

    # Decay the momentum running average coefficient
    torch._foreach_mul_(exp_avgs, beta2)
    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta2)
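A minimal usage sketch (assuming `Lion` is exported from `timm.optim`):

```
import torch
from timm.optim import Lion

model = torch.nn.Linear(16, 4)
optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01, foreach=True)

loss = model(torch.randn(8, 16)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```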

@@ -36,6 +36,12 @@ except ImportError:

_logger = logging.getLogger(__name__)

+# optimizers to default to multi-tensor
+_DEFAULT_FOREACH = {
+    'lion',
+}


def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
@@ -162,7 +168,8 @@ def optimizer_kwargs(cfg):
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
-        momentum=cfg.momentum)
+        momentum=cfg.momentum,
+    )
    if getattr(cfg, 'opt_eps', None) is not None:
        kwargs['eps'] = cfg.opt_eps
    if getattr(cfg, 'opt_betas', None) is not None:
@@ -171,6 +178,8 @@ def optimizer_kwargs(cfg):
        kwargs['layer_decay'] = cfg.layer_decay
    if getattr(cfg, 'opt_args', None) is not None:
        kwargs.update(cfg.opt_args)
+    if getattr(cfg, 'opt_foreach', None) is not None:
+        kwargs['foreach'] = cfg.opt_foreach
    return kwargs
@@ -191,6 +200,7 @@ def create_optimizer_v2(
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
+        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
@@ -209,6 +219,7 @@ def create_optimizer_v2(
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
+        foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        **kwargs: extra optimizer specific kwargs to pass through

@@ -228,7 +239,8 @@ def create_optimizer_v2(
            model_or_params,
            weight_decay=weight_decay,
            layer_decay=layer_decay,
-            no_weight_decay_list=no_weight_decay)
+            no_weight_decay_list=no_weight_decay,
+        )
        weight_decay = 0.
    elif weight_decay and filter_bias_and_bn:
        parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
@@ -246,9 +258,16 @@ def create_optimizer_v2(
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)

    if lr is not None:
        opt_args.setdefault('lr', lr)

+    if foreach is None:
+        if opt in _DEFAULT_FOREACH:
+            opt_args.setdefault('foreach', True)
+    else:
+        opt_args['foreach'] = foreach

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
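A hedged usage sketch of the new `foreach` plumbing (toy model):

```
import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(16, 4)

# 'lion' is in _DEFAULT_FOREACH, so the multi-tensor path is picked by default
optimizer = create_optimizer_v2(model, opt='lion', lr=1e-4, weight_decay=0.01)

# pass foreach=False to force the single-tensor implementation
optimizer = create_optimizer_v2(model, opt='lion', lr=1e-4, foreach=False)
```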

@@ -1 +1 @@
-__version__ = '0.8.11dev0'
+__version__ = '0.8.13dev0'
