Merge branch 'rwightman:master' into master

pull/1012/head
mrT23 3 years ago committed by GitHub
commit d6701d8a81
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -23,10 +23,11 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### Nov 19, 2021
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
@ -35,6 +36,8 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)

@ -339,15 +339,18 @@ EXCLUDE_FX_FILTERS = []
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'swin_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'*nfnet_f2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
'*evob', '*evos', # remove experimental evonorm models, seem to cause issues with dtype manipulation
]
@ -368,81 +371,89 @@ def test_model_forward_fx(model_name, batch_size):
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.train()
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
model = torch.jit.script(_create_fx_model(model))
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -11,11 +11,11 @@ from typing import Any, Callable, Optional, Tuple
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from .layers import Conv2dSame, Linear
@ -184,12 +184,12 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
if not pretrained_url and not hf_hub_id:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
else:
if pretrained_url:
_logger.info(f'Loading pretrained weights from url ({pretrained_url})')
state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
elif hf_hub_id and has_hf_hub(necessary=True):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:

@ -2,10 +2,11 @@ import json
import logging
import os
from functools import partial
from typing import Union, Optional
from pathlib import Path
from typing import Union
import torch
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
from torch.hub import get_dir
except ImportError:
@ -13,12 +14,12 @@ except ImportError:
from timm import __version__
try:
from huggingface_hub import hf_hub_url
from huggingface_hub import cached_download
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_url = None
cached_download = None
_has_hf_hub = False
_logger = logging.getLogger(__name__)
@ -53,11 +54,11 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub(necessary=False):
if hf_hub_url is None and necessary:
if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return hf_hub_url is not None
return _has_hf_hub
def hf_split(hf_id):
@ -94,3 +95,77 @@ def load_state_dict_from_hf(model_id: str):
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
model_config = model_config or {}
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = model.default_cfg
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
hf_config.update(model_config)
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def push_to_hf_hub(
model,
local_dir,
repo_namespace_or_url=None,
commit_message='Add model',
use_auth_token=True,
git_email=None,
git_user=None,
revision=None,
model_config=None,
):
if repo_namespace_or_url:
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
else:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)
repo_owner = HfApi().whoami(token)['name']
repo_name = Path(local_dir).name
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
# Save model weights and config.
save_for_hf(model, repo.local_dir, model_config=model_config)
# Save a model card if it doesn't exist.
readme_path = Path(repo.local_dir) / 'README.md'
if not readme_path.exists():
readme_path.write_text(readme_text)
return repo.git_remote_url()

@ -34,18 +34,17 @@ class EvoNormBatch2d(nn.Module):
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
_assert(x.dim() == 4, 'expected 4D input')
x_type = x.dtype
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
if self.v is not None:
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))

@ -128,6 +128,13 @@ default_cfgs = dict(
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_224_dino=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_24_224_dino=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmlp_ti16_224=_cfg(),
gmlp_s16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
@ -589,6 +596,33 @@ def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
return model
@register_model
def resmlp_12_224_dino(pretrained=False, **kwargs):
""" ResMLP-12
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
"""
model_args = dict(
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_224_dino', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_24_224_dino(pretrained=False, **kwargs):
""" ResMLP-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
"""
model_args = dict(
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_224_dino', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
""" gMLP-Tiny

@ -37,7 +37,7 @@ class PolyLRScheduler(Scheduler):
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=.5,
k_decay=1.0,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",

Loading…
Cancel
Save