diff --git a/README.md b/README.md index ad5bcc45..e79845b3 100644 --- a/README.md +++ b/README.md @@ -23,8 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### March 23, 2022 +* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) +* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. + ### March 21, 2022 -* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required. +* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required. * Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights) * `regnety_040` - 82.3 @ 224, 82.96 @ 288 * `regnety_064` - 83.0 @ 224, 83.65 @ 288 @@ -45,7 +49,8 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * `resnetrs200` - 83.85 @ 256, 84.44 @ 320 * HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon) * SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks. -* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets ( +* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2 +* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets * PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer * VOLO models w/ weights adapted from https://github.com/sail-sg/volo * Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc @@ -344,13 +349,16 @@ A full version of the list below with source links can be found in the [document * FBNet-V3 - https://arxiv.org/abs/2006.02049 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 * LCNet - https://arxiv.org/abs/2109.15099 +* MobileViT - https://arxiv.org/abs/2110.02178 * NASNet-A - https://arxiv.org/abs/1707.07012 * NesT - https://arxiv.org/abs/2105.12723 * NFNet-F - https://arxiv.org/abs/2102.06171 * NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692 * PNasNet - https://arxiv.org/abs/1712.00559 +* PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418 * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * RegNet - https://arxiv.org/abs/2003.13678 +* RegNetZ - https://arxiv.org/abs/2103.06877 * RepVGG - https://arxiv.org/abs/2101.03697 * ResMLP - https://arxiv.org/abs/2105.03404 * ResNet/ResNeXt @@ -367,12 +375,15 @@ A full version of the list below with source links can be found in the [document * ReXNet - https://arxiv.org/abs/2007.00992 * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 +* Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725 * Swin Transformer - https://arxiv.org/abs/2103.14030 +* Swin Transformer V2 - https://arxiv.org/abs/2111.09883 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * TResNet - https://arxiv.org/abs/2003.13630 * Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf * Visformer - https://arxiv.org/abs/2104.12533 * Vision Transformer - https://arxiv.org/abs/2010.11929 +* VOLO (Vision Outlooker) - https://arxiv.org/abs/2106.13112 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * Xception - https://arxiv.org/abs/1610.02357 * Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611 diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 0a2df3de..9fd4525a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -44,8 +44,14 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), - convnext_tiny_hnf=_cfg(url=''), - + convnext_tiny_hnf=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + crop_pct=0.95), + + convnext_tiny_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'), + convnext_small_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'), convnext_base_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'), convnext_large_in22ft1k=_cfg( @@ -53,6 +59,12 @@ default_cfgs = dict( convnext_xlarge_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'), + convnext_tiny_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_small_384_in22ft1k=_cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), convnext_base_384_in22ft1k=_cfg( url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), @@ -63,6 +75,10 @@ default_cfgs = dict( url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), + convnext_tiny_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), + convnext_small_in22k=_cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), convnext_base_in22k=_cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), convnext_large_in22k=_cfg( @@ -322,6 +338,8 @@ def _init_weights(module, name=None, head_init_scale=1.0): def checkpoint_filter_fn(state_dict, model): """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint if 'model' in state_dict: state_dict = state_dict['model'] out_dict = {} @@ -401,6 +419,20 @@ def convnext_large(pretrained=False, **kwargs): return model +@register_model +def convnext_tiny_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_small_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_base_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) @@ -422,6 +454,20 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): return model +@register_model +def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_small_384_in22ft1k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_base_384_in22ft1k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) @@ -443,6 +489,20 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): return model +@register_model +def convnext_tiny_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) + return model + + +@register_model +def convnext_small_in22k(pretrained=False, **kwargs): + model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_base_in22k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) @@ -462,6 +522,3 @@ def convnext_xlarge_in22k(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) return model - - - diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index cbb51980..a9c05b0a 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -15,25 +15,17 @@ except ImportError: has_fx_feature_extraction = False # Layers we went to treat as leaf modules -from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath -from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2 -from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame from .layers.non_local_attn import BilinearAttnTransform from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here # BUT modules from timm.models should use the registration mechanism below _leaf_modules = { - BatchNormAct2d, # reason: flow control for jit scripting BilinearAttnTransform, # reason: flow control t <= 1 - BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] # Reason: get_same_padding has a max which raises a control flow error Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) - DropPath, # reason: TypeError: rand recieved Proxy in `size` argument - EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?) - EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, - } try: diff --git a/timm/models/helpers.py b/timm/models/helpers.py index eda09680..bbedd7a8 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .fx_features import FeatureGraphNet from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf -from .layers import Conv2dSame, Linear +from .layers import Conv2dSame, Linear, BatchNormAct2d from .registry import get_pretrained_cfg @@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string): bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation, groups=g, stride=old_module.stride) set_layer(new_module, n, new_conv) - if isinstance(old_module, nn.BatchNorm2d): + elif isinstance(old_module, BatchNormAct2d): + new_bn = BatchNormAct2d( + state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, + affine=old_module.affine, track_running_stats=True) + new_bn.drop = old_module.drop + new_bn.act = old_module.act + set_layer(new_module, n, new_bn) + elif isinstance(old_module, nn.BatchNorm2d): new_bn = nn.BatchNorm2d( num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum, affine=old_module.affine, track_running_stats=True) set_layer(new_module, n, new_bn) - if isinstance(old_module, nn.Linear): + elif isinstance(old_module, nn.Linear): # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer? num_features = state_dict[n + '.weight'][1] new_fc = Linear( diff --git a/timm/models/layers/blur_pool.py b/timm/models/layers/blur_pool.py index ca4ce756..e73d8863 100644 --- a/timm/models/layers/blur_pool.py +++ b/timm/models/layers/blur_pool.py @@ -39,4 +39,4 @@ class BlurPool2d(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: x = F.pad(x, self.padding, 'reflect') - return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1]) + return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels) diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 14945efc..ae065277 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -107,13 +107,13 @@ class DropBlock2d(nn.Module): def __init__( self, - drop_prob=0.1, - block_size=7, - gamma_scale=1.0, - with_noise=False, - inplace=False, - batchwise=False, - fast=True): + drop_prob: float = 0.1, + block_size: int = 7, + gamma_scale: float = 1.0, + with_noise: bool = False, + inplace: bool = False, + batchwise: bool = False, + fast: bool = True): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale @@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ - def __init__(self, drop_prob=None, scale_by_keep=True): + def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): super(DropPath, self).__init__() self.drop_prob = drop_prob self.scale_by_keep = scale_by_keep diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 42636236..b643302c 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5): _assert(C % groups == 0, '') x_dtype = x.dtype x = x.reshape(B, groups, C // groups, H, W) - rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) + rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype) return rms.expand(x.shape).reshape(B, C, H, W) @@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module): n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + - var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var - var = var.to(dtype=x_dtype).view(v_shape) + var = var.to(x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = (x + 1) * instance_rms(x, self.eps) x = x / left.max(right) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dB2(nn.Module): @@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module): n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + - var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) + var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1))) else: var = self.running_var - var = var.to(dtype=x_dtype).view(v_shape) + var = var.to(x_dtype).view(v_shape) left = var.add(self.eps).sqrt_() right = instance_rms(x, self.eps) - x x = x / left.max(right) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS0(nn.Module): @@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module): x_dtype = x.dtype v_shape = (1, -1, 1, 1) if self.v is not None: - v = self.v.view(v_shape).to(dtype=x_dtype) + v = self.v.view(v_shape).to(x_dtype) x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS0a(EvoNorm2dS0): @@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0): v_shape = (1, -1, 1, 1) d = group_std(x, self.groups, self.eps) if self.v is not None: - v = self.v.view(v_shape).to(dtype=x_dtype) + v = self.v.view(v_shape).to(x_dtype) x = x * (x * v).sigmoid() x = x / d - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS1(nn.Module): @@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module): v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS1a(EvoNorm2dS1): @@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1): x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_std(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS2(nn.Module): @@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module): v_shape = (1, -1, 1, 1) if self.apply_act: x = self.act(x) / group_rms(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) class EvoNorm2dS2a(EvoNorm2dS2): @@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2): x_dtype = x.dtype v_shape = (1, -1, 1, 1) x = self.act(x) / group_rms(x, self.groups, self.eps) - return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 5ddb07af..0f386260 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -6,6 +6,7 @@ import torch from torch import nn as nn from torch.nn import functional as F +from .trace_utils import _assert from .create_act import get_act_layer @@ -29,9 +30,10 @@ class BatchNormAct2d(nn.BatchNorm2d): else: self.act = nn.Identity() - def _forward_jit(self, x): - """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function - """ + def forward(self, x): + # cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing + _assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)') + # exponential_average_factor is set to self.momentum # (when it is available) only so that it gets updated # in ONNX graph when this node is exported to ONNX. @@ -63,7 +65,7 @@ class BatchNormAct2d(nn.BatchNorm2d): passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are used for normalization (i.e. in eval mode when buffers are not None). """ - return F.batch_norm( + x = F.batch_norm( x, # If buffers are not to be tracked, ensure that they won't be updated self.running_mean if not self.training or self.track_running_stats else None, @@ -74,17 +76,6 @@ class BatchNormAct2d(nn.BatchNorm2d): exponential_average_factor, self.eps, ) - - @torch.jit.ignore - def _forward_python(self, x): - return super(BatchNormAct2d, self).forward(x) - - def forward(self, x): - # FIXME cannot call parent forward() and maintain jit.script compatibility? - if torch.jit.is_scripting(): - x = self._forward_jit(x) - else: - x = self._forward_python(x) x = self.drop(x) x = self.act(x) return x diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 87ea32a6..3e22bf56 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -155,10 +155,10 @@ default_cfgs = dict( regnety_040s_gn=_cfg(url=''), regnetv_040=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth', - first_conv='stem'), + first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)), regnetv_064=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth', - first_conv='stem'), + first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)), regnetz_005=_cfg(url=''), regnetz_040=_cfg( diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index b8262749..ef87dc88 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -4,6 +4,9 @@ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shi Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below +S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from + - https://github.com/microsoft/Cream/tree/main/AutoFormerV2 + Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman """ # -------------------------------------------------------- @@ -669,7 +672,7 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs): @register_model def swin_s3_tiny_224(pretrained=False, **kwargs): - """ Swin-S3-T @ 224x224, ImageNet-1k + """ Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725 """ model_kwargs = dict( patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2), @@ -679,7 +682,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs): @register_model def swin_s3_small_224(pretrained=False, **kwargs): - """ Swin-S3-S @ 224x224, trained ImageNet-1k + """ Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725 """ model_kwargs = dict( patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2), @@ -689,7 +692,7 @@ def swin_s3_small_224(pretrained=False, **kwargs): @register_model def swin_s3_base_224(pretrained=False, **kwargs): - """ Swin-S3-B @ 224x224, trained ImageNet-1k + """ Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725 """ model_kwargs = dict( patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2), diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 79778ab1..17faba53 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -170,6 +170,11 @@ default_cfgs = { '/vit_base_patch16_224_1k_miil_84_4.pth', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', ), + + # experimental + 'vit_small_patch16_36x1_224': _cfg(url=''), + 'vit_small_patch16_18x2_224': _cfg(url=''), + 'vit_base_patch16_18x2_224': _cfg(url=''), } @@ -201,28 +206,81 @@ class Attention(nn.Module): return x +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + class Block(nn.Module): def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): - x = x + self.drop_path1(self.attn(self.norm1(x))) - x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x)))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) return x +class ParallelBlock(nn.Module): + + def __init__( + self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.num_parallel = num_parallel + self.attns = nn.ModuleList() + self.ffns = nn.ModuleList() + for _ in range(num_parallel): + self.attns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + self.ffns.append(nn.Sequential(OrderedDict([ + ('norm', norm_layer(dim)), + ('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)), + ('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()), + ('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity()) + ]))) + + def _forward_jit(self, x): + x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0) + x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0) + return x + + @torch.jit.ignore + def _forward(self, x): + x = x + sum(attn(x) for attn in self.attns) + x = x + sum(ffn(x) for ffn in self.ffns) + return x + + def forward(self, x): + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return self._forward_jit(x) + else: + return self._forward(x) + + class VisionTransformer(nn.Module): """ Vision Transformer @@ -233,8 +291,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None): + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -248,10 +306,11 @@ class VisionTransformer(nn.Module): mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - weight_init: (str): weight init scheme drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate + weight_init: (str): weight init scheme + init_values: (float): layer-scale init values embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer @@ -277,9 +336,9 @@ class VisionTransformer(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ - Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, - attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) use_fc_norm = self.global_pool == 'avg' self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() @@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) return model + + +@register_model +def vit_small_patch16_36x1_224(pretrained=False, **kwargs): + """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_patch16_18x2_224(pretrained=False, **kwargs): + """ ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_18x2_224(pretrained=False, **kwargs): + """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. + Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) + model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/scheduler/plateau_lr.py b/timm/scheduler/plateau_lr.py index fbfc531f..cacfab3c 100644 --- a/timm/scheduler/plateau_lr.py +++ b/timm/scheduler/plateau_lr.py @@ -30,7 +30,16 @@ class PlateauLRScheduler(Scheduler): noise_seed=None, initialize=True, ): - super().__init__(optimizer, 'lr', initialize=initialize) + super().__init__( + optimizer, + 'lr', + noise_range_t=noise_range_t, + noise_type=noise_type, + noise_pct=noise_pct, + noise_std=noise_std, + noise_seed=noise_seed, + initialize=initialize, + ) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer, @@ -43,11 +52,6 @@ class PlateauLRScheduler(Scheduler): min_lr=lr_min ) - self.noise_range_t = noise_range_t - self.noise_pct = noise_pct - self.noise_type = noise_type - self.noise_std = noise_std - self.noise_seed = noise_seed if noise_seed is not None else 42 self.warmup_t = warmup_t self.warmup_lr_init = warmup_lr_init if self.warmup_t: @@ -84,7 +88,6 @@ class PlateauLRScheduler(Scheduler): if self._is_apply_noise(epoch): self._apply_noise(epoch) - def _apply_noise(self, epoch): noise = self._calculate_noise(epoch)