Merge remote-tracking branch 'origin/norm_norm_norm' into bits_and_tpu

pull/1239/head
Ross Wightman 2 years ago
commit cad170e494

@ -30,7 +30,7 @@ jobs:
- name: Install testing dependencies
run: |
python -m pip install --upgrade pip
pip install pytest pytest-timeout expecttest
pip install pytest pytest-timeout pytest-xdist expecttest
- name: Install torch on mac
if: startsWith(matrix.os, 'macOS')
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
@ -48,4 +48,4 @@ jobs:
env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
run: |
pytest -vv --durations=0 ./tests
pytest -vv --forked --durations=0 ./tests

@ -23,6 +23,25 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
* `halonet50ts` - 81.7 @ 256
* `halo2botnet50ts_256` - 82.0 @ 256
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Oct 19, 2021
* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* BCE loss and Repeated Augmentation support for RSB paper

@ -339,9 +339,17 @@ EXCLUDE_FX_FILTERS = []
if 'GITHUB_ACTIONS' in os.environ:
EXCLUDE_FX_FILTERS += [
'beit_large*',
'swin_large*',
'mixer_l*',
'*nfnet_f2*',
'*resnext101_32x32d',
'resnetv2_152x2*',
'resmlp_big*',
'resnetrs270',
'swin_large*',
'vgg*',
'vit_large*',
'vit_base_patch8*',
'xcit_large*',
]
@ -362,81 +370,89 @@ def test_model_forward_fx(model_name, batch_size):
input_size = _get_input_size(model=model, target=TARGET_FWD_FX_SIZE)
if max(input_size) > MAX_FWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
with torch.no_grad():
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
model = _create_fx_model(model)
fx_outputs = tuple(model(inputs).values())
if isinstance(fx_outputs, tuple):
fx_outputs = torch.cat(fx_outputs)
assert torch.all(fx_outputs == outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.train()
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
if 'GITHUB_ACTIONS' not in os.environ:
# FIXME this test is causing GitHub actions to run out of RAM and abruptly kill the test process
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FX_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward_fx(model_name, batch_size):
"""Symbolically trace each model and run single backward pass through the resulting GraphModule"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_FX_SIZE)
if max(input_size) > MAX_BWD_FX_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
model.train()
num_params = sum([x.numel() for x in model.parameters()])
if 'GITHUB_ACTIONS' in os.environ and num_params > 100e6:
pytest.skip("Skipping FX backward test on model with more than 100M params.")
model = _create_fx_model(model, train=True)
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])
assert outputs.shape[-1] == 42
assert num_params == num_grad, 'Some parameters are missing gradients'
assert not torch.isnan(outputs).any(), 'Output included NaNs'
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
EXCLUDE_FX_JIT_FILTERS = [
'deit_*_distilled_patch16_224',
'levit*',
'pit_*_distilled_224',
] + EXCLUDE_FX_FILTERS
model = torch.jit.script(_create_fx_model(model))
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.parametrize(
'model_name', list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_fx_torchscript(model_name, batch_size):
"""Symbolically trace each model, script it, and run single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
model = torch.jit.script(_create_fx_model(model))
with torch.no_grad():
outputs = tuple(model(torch.randn((batch_size, *input_size))).values())
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@ -21,6 +21,10 @@ try:
print("Warning: This version of tfds doesn't have the latest even_splits impl. "
"Please update or use tfds-nightly for better fine-grained split behaviour.")
has_buggy_even_splits = True
# NOTE uncomment below if having file limit issues on dataset build (or alter your OS defaults)
# import resource
# low, high = resource.getrlimit(resource.RLIMIT_NOFILE)
# resource.setrlimit(resource.RLIMIT_NOFILE, (high, high))
except ImportError as e:
print(e)
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")

@ -76,10 +76,10 @@ default_cfgs = {
first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
'lamhalobotnet50ts_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet_a1h_256-c9bc4e74.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'halo2botnet50ts_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h_256-ad9e16fb.pth',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
}

@ -35,7 +35,8 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
from .registry import register_model
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
@ -152,6 +153,12 @@ default_cfgs = {
'regnetz_e8': _cfgr(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0),
'regnetz_b16_evos': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv',
crop_pct=0.94),
'regnetz_d8_evob': _cfgr(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95),
@ -597,6 +604,23 @@ model_cfgs = dict(
),
# experimental EvoNorm configs
regnetz_b16_evos=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
),
stem_chs=32,
stem_pool='',
downsample='',
num_features=1536,
act_layer='silu',
norm_layer=partial(EvoNorm2dS0a, group_size=16),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
),
regnetz_d8_evob=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
@ -610,7 +634,7 @@ model_cfgs = dict(
downsample='',
num_features=1792,
act_layer='silu',
norm_layer='evonormbatch',
norm_layer='evonormb0',
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
@ -628,7 +652,7 @@ model_cfgs = dict(
downsample='',
num_features=1792,
act_layer='silu',
norm_layer=partial(EvoNormSample2d, groups=32),
norm_layer=partial(EvoNorm2dS0a, group_size=16),
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.25),
block_kwargs=dict(bottle_in=True, linear_out=True),
@ -856,6 +880,13 @@ def regnetz_e8(pretrained=False, **kwargs):
return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
@register_model
def regnetz_b16_evos(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)
@register_model
def regnetz_d8_evob(pretrained=False, **kwargs):
"""

@ -14,6 +14,8 @@ except ImportError:
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
@ -24,9 +26,12 @@ _leaf_modules = {
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand recieved Proxy in `size` argument
EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?)
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
}
try:

@ -11,11 +11,11 @@ from typing import Any, Callable, Optional, Tuple
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf, load_state_dict_from_url
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from .layers import Conv2dSame, Linear
@ -184,12 +184,12 @@ def load_pretrained(model, default_cfg=None, num_classes=1000, in_chans=3, filte
if not pretrained_url and not hf_hub_id:
_logger.warning("No pretrained weights exist for this model. Using random initialization.")
return
if hf_hub_id and has_hf_hub(necessary=not pretrained_url):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
else:
if pretrained_url:
_logger.info(f'Loading pretrained weights from url ({pretrained_url})')
state_dict = load_state_dict_from_url(pretrained_url, progress=progress, map_location='cpu')
elif hf_hub_id and has_hf_hub(necessary=True):
_logger.info(f'Loading pretrained weights from Hugging Face hub ({hf_hub_id})')
state_dict = load_state_dict_from_hf(hf_hub_id)
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:

@ -2,10 +2,11 @@ import json
import logging
import os
from functools import partial
from typing import Union, Optional
from pathlib import Path
from typing import Union
import torch
from torch.hub import load_state_dict_from_url, download_url_to_file, urlparse, HASH_REGEX
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
try:
from torch.hub import get_dir
except ImportError:
@ -13,12 +14,12 @@ except ImportError:
from timm import __version__
try:
from huggingface_hub import hf_hub_url
from huggingface_hub import cached_download
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url
cached_download = partial(cached_download, library_name="timm", library_version=__version__)
_has_hf_hub = True
except ImportError:
hf_hub_url = None
cached_download = None
_has_hf_hub = False
_logger = logging.getLogger(__name__)
@ -53,11 +54,11 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub(necessary=False):
if hf_hub_url is None and necessary:
if not _has_hf_hub and necessary:
# if no HF Hub module installed and it is necessary to continue, raise error
raise RuntimeError(
'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
return hf_hub_url is not None
return _has_hf_hub
def hf_split(hf_id):
@ -94,3 +95,77 @@ def load_state_dict_from_hf(model_id: str):
cached_file = _download_from_hf(model_id, 'pytorch_model.bin')
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
model_config = model_config or {}
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = model.default_cfg
hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
hf_config['num_features'] = model_config.pop('num_features', model.num_features)
hf_config['labels'] = model_config.pop('labels', [f"LABEL_{i}" for i in range(hf_config['num_classes'])])
hf_config.update(model_config)
with config_path.open('w') as f:
json.dump(hf_config, f, indent=2)
def push_to_hf_hub(
model,
local_dir,
repo_namespace_or_url=None,
commit_message='Add model',
use_auth_token=True,
git_email=None,
git_user=None,
revision=None,
model_config=None,
):
if repo_namespace_or_url:
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
else:
if isinstance(use_auth_token, str):
token = use_auth_token
else:
token = HfFolder.get_token()
if token is None:
raise ValueError(
"You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
"token as the `use_auth_token` argument."
)
repo_owner = HfApi().whoami(token)['name']
repo_name = Path(local_dir).name
repo_url = f'https://huggingface.co/{repo_owner}/{repo_name}'
repo = Repository(
local_dir,
clone_from=repo_url,
use_auth_token=use_auth_token,
git_user=git_user,
git_email=git_email,
revision=revision,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
with repo.commit(commit_message):
# Save model weights and config.
save_for_hf(model, repo.local_dir, model_config=model_config)
# Save a model card if it doesn't exist.
readme_path = Path(repo.local_dir) / 'README.md'
if not readme_path.exists():
readme_path.write_text(readme_text)
return repo.git_remote_url()

@ -14,7 +14,9 @@ from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible

@ -116,9 +116,6 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
# custom autograd, then fallback
if name in _ACT_FN_ME:
return _ACT_FN_ME[name]
if is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_FN_JIT:
return _ACT_FN_JIT[name]
@ -132,14 +129,12 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
"""
if not name:
return None
if isinstance(name, type):
if not isinstance(name, str):
# callable, module, etc
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
if is_exportable() and name in ('silu', 'swish'):
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
return Swish
if not (is_no_jit() or is_exportable()):
if name in _ACT_LAYER_JIT:
return _ACT_LAYER_JIT[name]

@ -9,36 +9,42 @@ Hacked together by / Copyright 2020 Ross Wightman
import types
import functools
import torch
import torch.nn as nn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .inplace_abn import InplaceAbn
_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type
_NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
groupnorm=GroupNormAct,
evonormb0=EvoNorm2dB0,
evonormb1=EvoNorm2dB1,
evonormb2=EvoNorm2dB2,
evonorms0=EvoNorm2dS0,
evonorms0a=EvoNorm2dS0a,
evonorms1=EvoNorm2dS1,
evonorms1a=EvoNorm2dS1a,
evonorms2=EvoNorm2dS2,
evonorms2a=EvoNorm2dS2a,
frn=FilterResponseNormAct2d,
frntlu=FilterResponseNormTlu2d,
inplaceabn=InplaceAbn,
iabn=InplaceAbn,
)
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn}
def get_norm_act_layer(layer_class):
layer_class = layer_class.replace('_', '').lower()
if layer_class.startswith("batchnorm"):
layer = BatchNormAct2d
elif layer_class.startswith("groupnorm"):
layer = GroupNormAct
elif layer_class == "evonormbatch":
layer = EvoNormBatch2d
elif layer_class == "evonormsample":
layer = EvoNormSample2d
elif layer_class == "iabn" or layer_class == "inplaceabn":
layer = InplaceAbn
else:
assert False, "Invalid norm_act layer (%s)" % layer_class
def get_norm_act_layer(layer_name):
layer_name = layer_name.replace('_', '').lower().split('-')[0]
layer = _NORM_ACT_MAP.get(layer_name, None)
assert layer is not None, "Invalid norm_act layer (%s)" % layer_name
return layer
def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu
def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_name.split('-') # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?

@ -1,82 +1,332 @@
"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch
""" EvoNorm in PyTorch
Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967
@inproceedings{NEURIPS2020,
author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc},
booktitle = {Advances in Neural Information Processing Systems},
editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin},
pages = {13539--13550},
publisher = {Curran Associates, Inc.},
title = {Evolving Normalization-Activation Layers},
url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf},
volume = {33},
year = {2020}
}
An attempt at getting decent performing EvoNorms running in PyTorch.
While currently faster than other impl, still quite a ways off the built-in BN
in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed).
While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm
in terms of memory usage and throughput on GPUs.
Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts.
I'm testing these modules on TPU w/ PyTorch XLA. Promising start but
currently working around some issues with builtin torch/tensor.var/std. Unlike
GPU, similar train speeds for EvoNormS variants and BatchNorm.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .trace_utils import _assert
class EvoNormBatch2d(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None):
super(EvoNormBatch2d, self).__init__()
def instance_std(x, eps: float = 1e-5):
rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
return rms.expand(x.shape)
def instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
return rms.expand(x.shape)
def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
B, C, H, W = x.shape
x_dtype = x.dtype
_assert(C % groups == 0, '')
# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
# This is a workaround for some stability / odd behaviour of .var and .std
# running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.float().reshape(B, groups, C // groups, H, W)
xm = x.mean(dim=(2, 3, 4), keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
else:
var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
# group_std = group_std_tpu # temporary, for TPU / PT XLA
def group_rms(x, groups: int = 32, eps: float = 1e-5):
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return sqm.expand(x.shape).reshape(B, C, H, W)
class EvoNorm2dB0(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.apply_act:
if self.v is not None:
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
x_type = x.dtype
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach() * self.momentum * (n / (n - 1)))
else:
var = self.running_var
left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x)
v = self.v.to(x_dtype).view(v_shape)
right = x * v + instance_std(x, self.eps)
x = x / left.max(right)
return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape)
class EvoNormSample2d(nn.Module):
def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None):
super(EvoNormSample2d, self).__init__()
class EvoNorm2dB1(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups
self.momentum = momentum
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = (x + 1) * instance_rms(x, self.eps)
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dB2(nn.Module):
def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = instance_rms(x, self.eps) - x
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS0(nn.Module):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.v is not None:
nn.init.ones_(self.v)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape
_assert(C % self.groups == 0, '')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
x = x.reshape(B, self.groups, -1)
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W)
return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
v = self.v.view(v_shape).to(dtype=x_dtype)
x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS0a(EvoNorm2dS0):
def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
d = group_std(x, self.groups, self.eps)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
x = x * (x * v).sigmoid_()
x = x / d
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS1(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
else:
self.act = nn.Identity()
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.pre_act_norm = False
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS1a(EvoNorm2dS1):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS2(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__()
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
else:
self.act = nn.Identity()
if group_size:
assert num_features % group_size == 0
self.groups = num_features // group_size
else:
self.groups = groups
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
class EvoNorm2dS2a(EvoNorm2dS2):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)

@ -0,0 +1,68 @@
""" Filter Response Norm in PyTorch
Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
import torch.nn as nn
from .create_act import create_act_layer
from .trace_utils import _assert
def inv_instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype)
return rms.expand(x.shape)
class FilterResponseNormTlu2d(nn.Module):
def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_):
super(FilterResponseNormTlu2d, self).__init__()
self.apply_act = apply_act # apply activation (non-linearity)
self.rms = rms
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
if self.tau is not None:
nn.init.zeros_(self.tau)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = x * inv_instance_rms(x, self.eps)
x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x
class FilterResponseNormAct2d(nn.Module):
def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_):
super(FilterResponseNormAct2d, self).__init__()
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer, inplace=inplace)
else:
self.act = nn.Identity()
self.rms = rms
self.eps = eps
self.weight = nn.Parameter(torch.ones(num_features))
self.bias = nn.Parameter(torch.zeros(num_features))
self.reset_parameters()
def reset_parameters(self):
nn.init.ones_(self.weight)
nn.init.zeros_(self.bias)
def forward(self, x):
_assert(x.dim() == 4, 'expected 4D input')
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = x * inv_instance_rms(x, self.eps)
x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return self.act(x)

@ -128,6 +128,13 @@ default_cfgs = dict(
url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_12_224_dino=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
resmlp_24_224_dino=_cfg(
url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmlp_ti16_224=_cfg(),
gmlp_s16_224=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
@ -589,6 +596,33 @@ def resmlp_big_24_224_in22ft1k(pretrained=False, **kwargs):
return model
@register_model
def resmlp_12_224_dino(pretrained=False, **kwargs):
""" ResMLP-12
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
"""
model_args = dict(
patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_224_dino', pretrained=pretrained, **model_args)
return model
@register_model
def resmlp_24_224_dino(pretrained=False, **kwargs):
""" ResMLP-24
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
Model pretrained via DINO (self-supervised) - https://arxiv.org/abs/2104.14294
"""
model_args = dict(
patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_224_dino', pretrained=pretrained, **model_args)
return model
@register_model
def gmlp_ti16_224(pretrained=False, **kwargs):
""" gMLP-Tiny

@ -38,7 +38,8 @@ from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model
from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\
from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0,\
EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d,\
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
@ -125,7 +126,11 @@ default_cfgs = {
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evob': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos': _cfg(
'resnetv2_50d_evos0': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_evos1': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
'resnetv2_50d_frn': _cfg(
interpolation='bicubic', first_conv='stem.conv1'),
}
@ -660,13 +665,29 @@ def resnetv2_50d_gn(pretrained=False, **kwargs):
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
@register_model
def resnetv2_50d_evos0(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos0', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos1(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos1', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16),
stem_type='deep', avg_down=True, **kwargs)
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d,
'resnetv2_50d_frn', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
stem_type='deep', avg_down=True, **kwargs)

@ -395,7 +395,7 @@ def eca_vovnet39b(pretrained=False, **kwargs):
@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **nkwargs):
return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs)
return create_norm_act('evonorms0', num_features, jit=False, **nkwargs)
return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)

@ -37,7 +37,7 @@ class PolyLRScheduler(Scheduler):
noise_pct=0.67,
noise_std=1.0,
noise_seed=42,
k_decay=.5,
k_decay=1.0,
initialize=True) -> None:
super().__init__(
optimizer, param_group_field="lr",

Loading…
Cancel
Save