Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/1414/head
Ross Wightman 2 years ago
commit 1ba0ec4c18

README.md

@@ -23,8 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 ## What's New
+### March 23, 2022
+* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795)
+* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs.
 ### March 21, 2022
-* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required.
+* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required.
 * Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
 * `regnety_040` - 82.3 @ 224, 82.96 @ 288
 * `regnety_064` - 83.0 @ 224, 83.65 @ 288
@@ -45,7 +49,8 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
 * `resnetrs200` - 83.85 @ 256, 84.44 @ 320
 * HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
 * SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks.
-* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets (
+* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2
+* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
 * PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
 * VOLO models w/ weights adapted from https://github.com/sail-sg/volo
 * Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc
@@ -344,13 +349,16 @@ A full version of the list below with source links can be found in the [document
 * FBNet-V3 - https://arxiv.org/abs/2006.02049
 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646
 * LCNet - https://arxiv.org/abs/2109.15099
+* MobileViT - https://arxiv.org/abs/2110.02178
 * NASNet-A - https://arxiv.org/abs/1707.07012
 * NesT - https://arxiv.org/abs/2105.12723
 * NFNet-F - https://arxiv.org/abs/2102.06171
 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
 * PNasNet - https://arxiv.org/abs/1712.00559
+* PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
 * RegNet - https://arxiv.org/abs/2003.13678
+* RegNetZ - https://arxiv.org/abs/2103.06877
 * RepVGG - https://arxiv.org/abs/2101.03697
 * ResMLP - https://arxiv.org/abs/2105.03404
 * ResNet/ResNeXt
@@ -367,12 +375,15 @@ A full version of the list below with source links can be found in the [document
 * ReXNet - https://arxiv.org/abs/2007.00992
 * SelecSLS - https://arxiv.org/abs/1907.00837
 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
+* Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725
 * Swin Transformer - https://arxiv.org/abs/2103.14030
+* Swin Transformer V2 - https://arxiv.org/abs/2111.09883
 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
 * TResNet - https://arxiv.org/abs/2003.13630
 * Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
 * Visformer - https://arxiv.org/abs/2104.12533
 * Vision Transformer - https://arxiv.org/abs/2010.11929
+* VOLO (Vision Outlooker) - https://arxiv.org/abs/2106.13112
 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
 * Xception - https://arxiv.org/abs/1610.02357
 * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611
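
As a quick sanity check of the model families added above, a minimal sketch (assuming a timm install that includes this merge; model names follow the notes and list above):

```python
import timm

# a few of the families referenced in the notes above
print(timm.list_models('swin_s3_*'))
print(timm.list_models('convnext_*in22ft1k'))

# one of the new weights from the March 23 notes
model = timm.create_model('convnext_tiny_hnf', pretrained=True).eval()
```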

timm/models/convnext.py

@@ -44,8 +44,14 @@ default_cfgs = dict(
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
     convnext_nano_hnf=_cfg(url=''),
-    convnext_tiny_hnf=_cfg(url=''),
+    convnext_tiny_hnf=_cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
+        crop_pct=0.95),
+    convnext_tiny_in22ft1k=_cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
+    convnext_small_in22ft1k=_cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
     convnext_base_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
     convnext_large_in22ft1k=_cfg(
@@ -53,6 +59,12 @@ default_cfgs = dict(
     convnext_xlarge_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
+    convnext_tiny_384_in22ft1k=_cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    convnext_small_384_in22ft1k=_cfg(
+        url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
     convnext_base_384_in22ft1k=_cfg(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
@@ -63,6 +75,10 @@ default_cfgs = dict(
         url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
         input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
+    convnext_tiny_in22k=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
+    convnext_small_in22k=_cfg(
+        url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
     convnext_base_in22k=_cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
     convnext_large_in22k=_cfg(
@@ -322,6 +338,8 @@ def _init_weights(module, name=None, head_init_scale=1.0):
 def checkpoint_filter_fn(state_dict, model):
     """ Remap FB checkpoints -> timm """
+    if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
+        return state_dict  # non-FB checkpoint
     if 'model' in state_dict:
         state_dict = state_dict['model']
     out_dict = {}
@@ -401,6 +419,20 @@ def convnext_large(pretrained=False, **kwargs):
     return model
 
+
+@register_model
+def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_small_in22ft1k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
+    return model
+
 
 @register_model
 def convnext_base_in22ft1k(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@@ -422,6 +454,20 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
     return model
 
+
+@register_model
+def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_small_384_in22ft1k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args)
+    return model
+
 
 @register_model
 def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@@ -443,6 +489,20 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
     return model
 
+
+@register_model
+def convnext_tiny_in22k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args)
+    return model
+
+
+@register_model
+def convnext_small_in22k(pretrained=False, **kwargs):
+    model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
+    model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args)
+    return model
+
 
 @register_model
 def convnext_base_in22k(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@@ -462,6 +522,3 @@ def convnext_xlarge_in22k(pretrained=False, **kwargs):
     model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
     model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
     return model
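
A usage sketch for the ConvNeXt variants registered above (head sizes follow the `num_classes` values in `default_cfgs`; downloading the FB weights is assumed to work):

```python
import timm
import torch

in22k = timm.create_model('convnext_tiny_in22k', pretrained=True).eval()    # 21841-class head
ft1k = timm.create_model('convnext_tiny_in22ft1k', pretrained=True).eval()  # 1000-class head

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    print(in22k(x).shape, ft1k(x).shape)  # torch.Size([1, 21841]) torch.Size([1, 1000])
```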

timm/models/fx_features.py

@@ -15,25 +15,17 @@ except ImportError:
     has_fx_feature_extraction = False
 
 # Layers we want to treat as leaf modules
-from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
-from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
-from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
+from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
 from .layers.non_local_attn import BilinearAttnTransform
 from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 
 # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
 # BUT modules from timm.models should use the registration mechanism below
 _leaf_modules = {
-    BatchNormAct2d,  # reason: flow control for jit scripting
     BilinearAttnTransform,  # reason: flow control t <= 1
-    BlurPool2d,  # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
     # Reason: get_same_padding has a max which raises a control flow error
     Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
     CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
-    DropPath,  # reason: TypeError: rand recieved Proxy in `size` argument
-    EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,  # to(dtype) use that causes tracing failure (on scripted models only?)
-    EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
 }
 
 try:
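
The NOTE above points to a registration mechanism for leaf modules defined outside `timm.models.layers`; a hedged sketch of how that is typically used (assuming the `register_notrace_module` helper exported by this file; the layer itself is hypothetical):

```python
import torch
import torch.nn as nn
from timm.models.fx_features import register_notrace_module

@register_notrace_module  # keep this module as a single node during FX symbolic tracing
class MyDataDependentGate(nn.Module):
    """Hypothetical layer with tensor-dependent control flow that FX cannot trace through."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2 if x.mean() > 0 else x
```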

timm/models/helpers.py

@@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .fx_features import FeatureGraphNet
 from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
-from .layers import Conv2dSame, Linear
+from .layers import Conv2dSame, Linear, BatchNormAct2d
 from .registry import get_pretrained_cfg
@@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
                 bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                 groups=g, stride=old_module.stride)
             set_layer(new_module, n, new_conv)
-        if isinstance(old_module, nn.BatchNorm2d):
+        elif isinstance(old_module, BatchNormAct2d):
+            new_bn = BatchNormAct2d(
+                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
+                affine=old_module.affine, track_running_stats=True)
+            new_bn.drop = old_module.drop
+            new_bn.act = old_module.act
+            set_layer(new_module, n, new_bn)
+        elif isinstance(old_module, nn.BatchNorm2d):
             new_bn = nn.BatchNorm2d(
                 num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                 affine=old_module.affine, track_running_stats=True)
             set_layer(new_module, n, new_bn)
-        if isinstance(old_module, nn.Linear):
+        elif isinstance(old_module, nn.Linear):
             # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
             num_features = state_dict[n + '.weight'][1]
             new_fc = Linear(
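
One reason the chain above switches from `if` to `elif` and tests `BatchNormAct2d` first: `BatchNormAct2d` subclasses `nn.BatchNorm2d`, so the plain BatchNorm branch would also match it and the rebuilt layer would silently lose its fused `drop`/`act` submodules. A small illustration (class names as exported from `timm.models.layers`):

```python
import torch.nn as nn
from timm.models.layers import BatchNormAct2d

bn_act = BatchNormAct2d(64, act_layer=nn.ReLU)
print(isinstance(bn_act, nn.BatchNorm2d))   # True - the generic branch would match too
print(isinstance(bn_act, BatchNormAct2d))   # True - so the subclass must be checked first
```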

timm/models/layers/blur_pool.py

@@ -39,4 +39,4 @@ class BlurPool2d(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = F.pad(x, self.padding, 'reflect')
-        return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
+        return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)
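
Replacing `groups=x.shape[1]` with `groups=self.channels` removes the tensor-shape dependency that previously forced `BlurPool2d` onto the FX leaf-module list (see the fx_features hunk above). A rough check, assuming the `BlurPool2d(channels, filt_size, stride)` constructor in `timm.models.layers`:

```python
import torch
from torch import fx
from timm.models.layers import BlurPool2d

pool = BlurPool2d(channels=64, filt_size=3, stride=2)
traced = fx.symbolic_trace(pool)           # groups is now a plain int, so symbolic tracing succeeds
out = traced(torch.randn(1, 64, 56, 56))
print(out.shape)                           # torch.Size([1, 64, 28, 28])
```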

timm/models/layers/drop.py

@@ -107,13 +107,13 @@ class DropBlock2d(nn.Module):
     def __init__(
             self,
-            drop_prob=0.1,
-            block_size=7,
-            gamma_scale=1.0,
-            with_noise=False,
-            inplace=False,
-            batchwise=False,
-            fast=True):
+            drop_prob: float = 0.1,
+            block_size: int = 7,
+            gamma_scale: float = 1.0,
+            with_noise: bool = False,
+            inplace: bool = False,
+            batchwise: bool = False,
+            fast: bool = True):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
@@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
 class DropPath(nn.Module):
     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
     """
-    def __init__(self, drop_prob=None, scale_by_keep=True):
+    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
         self.scale_by_keep = scale_by_keep
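
The annotations above only tighten the signatures; behaviour is unchanged. For reference, a small sketch of `DropPath` as it is used in the ViT blocks later in this diff:

```python
import torch
from timm.models.layers import DropPath

dp = DropPath(drop_prob=0.1, scale_by_keep=True)
dp.train()

x = torch.randn(8, 197, 384)   # (batch, tokens, dim), e.g. a ViT residual branch
y = dp(x)                      # ~10% of samples get the branch zeroed, the rest scaled by 1/0.9
print(y.shape)                 # torch.Size([8, 197, 384])
```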

timm/models/layers/evo_norm.py

@@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
     _assert(C % groups == 0, '')
     x_dtype = x.dtype
     x = x.reshape(B, groups, C // groups, H, W)
-    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
+    rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
     return rms.expand(x.shape).reshape(B, C, H, W)
@@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module):
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
-                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
             else:
                 var = self.running_var
-            var = var.to(dtype=x_dtype).view(v_shape)
+            var = var.to(x_dtype).view(v_shape)
             left = var.add(self.eps).sqrt_()
             right = (x + 1) * instance_rms(x, self.eps)
             x = x / left.max(right)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dB2(nn.Module):
@@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module):
                 n = x.numel() / x.shape[1]
                 self.running_var.copy_(
                     self.running_var * (1 - self.momentum) +
-                    var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
+                    var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
             else:
                 var = self.running_var
-            var = var.to(dtype=x_dtype).view(v_shape)
+            var = var.to(x_dtype).view(v_shape)
             left = var.add(self.eps).sqrt_()
             right = instance_rms(x, self.eps) - x
             x = x / left.max(right)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS0(nn.Module):
@@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         if self.v is not None:
-            v = self.v.view(v_shape).to(dtype=x_dtype)
+            v = self.v.view(v_shape).to(x_dtype)
             x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS0a(EvoNorm2dS0):
@@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0):
         v_shape = (1, -1, 1, 1)
         d = group_std(x, self.groups, self.eps)
         if self.v is not None:
-            v = self.v.view(v_shape).to(dtype=x_dtype)
+            v = self.v.view(v_shape).to(x_dtype)
             x = x * (x * v).sigmoid()
         x = x / d
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS1(nn.Module):
@@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module):
         v_shape = (1, -1, 1, 1)
         if self.apply_act:
             x = self.act(x) / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS1a(EvoNorm2dS1):
@@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         x = self.act(x) / group_std(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS2(nn.Module):
@@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module):
         v_shape = (1, -1, 1, 1)
         if self.apply_act:
             x = self.act(x) / group_rms(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
 
 class EvoNorm2dS2a(EvoNorm2dS2):
@@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2):
         x_dtype = x.dtype
         v_shape = (1, -1, 1, 1)
         x = self.act(x) / group_rms(x, self.groups, self.eps)
-        return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
+        return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
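
The only change in this file is `.to(dtype=x_dtype)` becoming `.to(x_dtype)`; per the comment removed from the leaf-module list earlier in this diff, the keyword form was interfering with tracing. A quick scripting/equivalence check, assuming the `EvoNorm2dS0(num_features, groups=...)` constructor exported from `timm.models.layers`:

```python
import torch
from timm.models.layers import EvoNorm2dS0

norm = EvoNorm2dS0(64, groups=8).eval()
x = torch.randn(2, 64, 28, 28)

scripted = torch.jit.script(norm)                 # still scriptable after the change
torch.testing.assert_close(norm(x), scripted(x))  # and numerically identical
```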

timm/models/layers/norm_act.py

@@ -6,6 +6,7 @@ import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
+from .trace_utils import _assert
 from .create_act import get_act_layer
@@ -29,9 +30,10 @@ class BatchNormAct2d(nn.BatchNorm2d):
         else:
             self.act = nn.Identity()
 
-    def _forward_jit(self, x):
-        """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
-        """
+    def forward(self, x):
+        # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
+        _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
         # exponential_average_factor is set to self.momentum
         # (when it is available) only so that it gets updated
         # in ONNX graph when this node is exported to ONNX.
@@ -63,7 +65,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
         passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
         used for normalization (i.e. in eval mode when buffers are not None).
         """
-        return F.batch_norm(
+        x = F.batch_norm(
             x,
             # If buffers are not to be tracked, ensure that they won't be updated
             self.running_mean if not self.training or self.track_running_stats else None,
@@ -74,17 +76,6 @@ class BatchNormAct2d(nn.BatchNorm2d):
             exponential_average_factor,
             self.eps,
         )
-
-    @torch.jit.ignore
-    def _forward_python(self, x):
-        return super(BatchNormAct2d, self).forward(x)
-
-    def forward(self, x):
-        # FIXME cannot call parent forward() and maintain jit.script compatibility?
-        if torch.jit.is_scripting():
-            x = self._forward_jit(x)
-        else:
-            x = self._forward_python(x)
         x = self.drop(x)
         x = self.act(x)
         return x
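
With the fold-in above, `BatchNormAct2d` has a single `forward` again (norm, then `drop`, then `act`), relying on `_assert` rather than a scripted/eager split. A brief sketch (constructor arguments as used elsewhere in timm; scripting is the case the old split existed for):

```python
import torch
import torch.nn as nn
from timm.models.layers import BatchNormAct2d

bn_act = BatchNormAct2d(32, act_layer=nn.SiLU).eval()
x = torch.randn(4, 32, 14, 14)

y = bn_act(x)                          # batch norm -> (identity) drop -> SiLU, one code path
scripted = torch.jit.script(bn_act)    # the same path is used under torchscript
torch.testing.assert_close(y, scripted(x))
```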

timm/models/regnet.py

@@ -155,10 +155,10 @@ default_cfgs = dict(
     regnety_040s_gn=_cfg(url=''),
     regnetv_040=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
-        first_conv='stem'),
+        first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
     regnetv_064=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
-        first_conv='stem'),
+        first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
 
     regnetz_005=_cfg(url=''),
     regnetz_040=_cfg(
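
The two RegNetV cfgs above now advertise a 288x288 test-time size at `crop_pct=1.0`. A sketch of how that surfaces through timm's data-config resolution (assuming the `use_test_size` flag that the validation script uses):

```python
import timm
from timm.data import resolve_data_config

model = timm.create_model('regnetv_064', pretrained=True)
cfg = resolve_data_config({}, model=model, use_test_size=True)
print(cfg['input_size'], cfg['crop_pct'])   # expected: (3, 288, 288) 1.0, per the cfg above
```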

timm/models/swin_transformer.py

@@ -4,6 +4,9 @@ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shi
 Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
 
+S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from
+- https://github.com/microsoft/Cream/tree/main/AutoFormerV2
+
 Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
 """
 # --------------------------------------------------------
@@ -669,7 +672,7 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
 @register_model
 def swin_s3_tiny_224(pretrained=False, **kwargs):
-    """ Swin-S3-T @ 224x224, ImageNet-1k
+    """ Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
@@ -679,7 +682,7 @@ def swin_s3_small_224(pretrained=False, **kwargs):
 @register_model
 def swin_s3_small_224(pretrained=False, **kwargs):
-    """ Swin-S3-S @ 224x224, trained ImageNet-1k
+    """ Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
@@ -689,7 +692,7 @@ def swin_s3_base_224(pretrained=False, **kwargs):
 @register_model
 def swin_s3_base_224(pretrained=False, **kwargs):
-    """ Swin-S3-B @ 224x224, trained ImageNet-1k
+    """ Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
     """
     model_kwargs = dict(
         patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),
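
A minimal sketch instantiating one of the Swin-S3 models documented above (weights adapted from the AutoFormerV2 release; assumes the pretrained download succeeds):

```python
import timm
import torch

model = timm.create_model('swin_s3_tiny_224', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])
```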

timm/models/vision_transformer.py

@@ -170,6 +170,11 @@ default_cfgs = {
             '/vit_base_patch16_224_1k_miil_84_4.pth',
         mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
     ),
+
+    # experimental
+    'vit_small_patch16_36x1_224': _cfg(url=''),
+    'vit_small_patch16_18x2_224': _cfg(url=''),
+    'vit_base_patch16_18x2_224': _cfg(url=''),
 }
@@ -201,28 +206,81 @@ class Attention(nn.Module):
         return x
 
+
+class LayerScale(nn.Module):
+    def __init__(self, dim, init_values=1e-5, inplace=False):
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x):
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
+
+
 class Block(nn.Module):
 
     def __init__(
-            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
+            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
             drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
         super().__init__()
         self.norm1 = norm_layer(dim)
         self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
         self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
 
     def forward(self, x):
-        x = x + self.drop_path1(self.attn(self.norm1(x)))
-        x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
+        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
         return x
 
+
+class ParallelBlock(nn.Module):
+
+    def __init__(
+            self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
+            drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+        super().__init__()
+        self.num_parallel = num_parallel
+        self.attns = nn.ModuleList()
+        self.ffns = nn.ModuleList()
+        for _ in range(num_parallel):
+            self.attns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+            self.ffns.append(nn.Sequential(OrderedDict([
+                ('norm', norm_layer(dim)),
+                ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
+                ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
+                ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
+            ])))
+
+    def _forward_jit(self, x):
+        x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
+        x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
+        return x
+
+    @torch.jit.ignore
+    def _forward(self, x):
+        x = x + sum(attn(x) for attn in self.attns)
+        x = x + sum(ffn(x) for ffn in self.ffns)
+        return x
+
+    def forward(self, x):
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            return self._forward_jit(x)
+        else:
+            return self._forward(x)
+
+
 class VisionTransformer(nn.Module):
     """ Vision Transformer
@@ -233,8 +291,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
-            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
+            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
+            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -248,10 +306,11 @@ class VisionTransformer(nn.Module):
             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
             qkv_bias (bool): enable bias for qkv if True
             representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
-            weight_init: (str): weight init scheme
             drop_rate (float): dropout rate
             attn_drop_rate (float): attention dropout rate
             drop_path_rate (float): stochastic depth rate
+            weight_init: (str): weight init scheme
+            init_values: (float): layer-scale init values
             embed_layer (nn.Module): patch embedding layer
             norm_layer: (nn.Module): normalization layer
             act_layer: (nn.Module): MLP activation layer
@@ -277,9 +336,9 @@ class VisionTransformer(nn.Module):
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.Sequential(*[
-            Block(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
+            block_fn(
+                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
 
         use_fc_norm = self.global_pool == 'avg'
         self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
     model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
+    """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
+    Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
+    model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
+    return model
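
A sketch of the two ways the new pieces can be exercised: via the registered experimental configs, or by passing `init_values` / `block_fn` straight to `VisionTransformer` (argument names as in the constructor changes above; no pretrained weights exist for these yet, per the empty cfg urls):

```python
import timm
from timm.models.vision_transformer import VisionTransformer, ParallelBlock

# registered experimental config: 18 x 2 parallel blocks with LayerScale
model = timm.create_model('vit_small_patch16_18x2_224', pretrained=False)

# hand-built equivalent: LayerScale enabled via init_values, parallel blocks via block_fn
vit = VisionTransformer(
    embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
```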

timm/scheduler/plateau_lr.py

@@ -30,7 +30,16 @@ class PlateauLRScheduler(Scheduler):
             noise_seed=None,
             initialize=True,
     ):
-        super().__init__(optimizer, 'lr', initialize=initialize)
+        super().__init__(
+            optimizer,
+            'lr',
+            noise_range_t=noise_range_t,
+            noise_type=noise_type,
+            noise_pct=noise_pct,
+            noise_std=noise_std,
+            noise_seed=noise_seed,
+            initialize=initialize,
+        )
 
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
             self.optimizer,
@@ -43,11 +52,6 @@ class PlateauLRScheduler(Scheduler):
             min_lr=lr_min
         )
 
-        self.noise_range_t = noise_range_t
-        self.noise_pct = noise_pct
-        self.noise_type = noise_type
-        self.noise_std = noise_std
-        self.noise_seed = noise_seed if noise_seed is not None else 42
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
         if self.warmup_t:
@@ -84,7 +88,6 @@ class PlateauLRScheduler(Scheduler):
         if self._is_apply_noise(epoch):
             self._apply_noise(epoch)
 
     def _apply_noise(self, epoch):
         noise = self._calculate_noise(epoch)
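
With the noise arguments now forwarded to the base `Scheduler`, construction looks like the sketch below (keyword names as in the signature above; the decay/patience/noise values are only illustrative):

```python
import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

opt = torch.optim.SGD(torch.nn.Linear(8, 2).parameters(), lr=0.1)
sched = PlateauLRScheduler(
    opt, decay_rate=0.5, patience_t=3, noise_range_t=(10, 20), noise_pct=0.67)

for epoch in range(25):
    val_metric = 0.75                # stand-in for a monitored validation metric
    sched.step(epoch, val_metric)    # LR noise only applies inside the 10-20 epoch window
```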
