Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/1414/head
Ross Wightman 2 years ago
commit 1ba0ec4c18

@ -23,8 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### March 23, 2022
* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795)
* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs.
### March 21, 2022
* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch `0.5.x` or a previous 0.5.x release can be used if stability is required.
* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required.
* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
* `regnety_040` - 82.3 @ 224, 82.96 @ 288
* `regnety_064` - 83.0 @ 224, 83.65 @ 288
@ -45,7 +49,8 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
* `resnetrs200` - 83.85 @ 256, 84.44 @ 320
* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks.
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets (
* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
* VOLO models w/ weights adapted from https://github.com/sail-sg/volo
* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc
@ -344,13 +349,16 @@ A full version of the list below with source links can be found in the [document
* FBNet-V3 - https://arxiv.org/abs/2006.02049
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646
* LCNet - https://arxiv.org/abs/2109.15099
* MobileViT - https://arxiv.org/abs/2110.02178
* NASNet-A - https://arxiv.org/abs/1707.07012
* NesT - https://arxiv.org/abs/2105.12723
* NFNet-F - https://arxiv.org/abs/2102.06171
* NF-RegNet / NF-ResNet - https://arxiv.org/abs/2101.08692
* PNasNet - https://arxiv.org/abs/1712.00559
* PoolFormer (MetaFormer) - https://arxiv.org/abs/2111.11418
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* RegNet - https://arxiv.org/abs/2003.13678
* RegNetZ - https://arxiv.org/abs/2103.06877
* RepVGG - https://arxiv.org/abs/2101.03697
* ResMLP - https://arxiv.org/abs/2105.03404
* ResNet/ResNeXt
@ -367,12 +375,15 @@ A full version of the list below with source links can be found in the [document
* ReXNet - https://arxiv.org/abs/2007.00992
* SelecSLS - https://arxiv.org/abs/1907.00837
* Selective Kernel Networks - https://arxiv.org/abs/1903.06586
* Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725
* Swin Transformer - https://arxiv.org/abs/2103.14030
* Swin Transformer V2 - https://arxiv.org/abs/2111.09883
* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
* TResNet - https://arxiv.org/abs/2003.13630
* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
* Visformer - https://arxiv.org/abs/2104.12533
* Vision Transformer - https://arxiv.org/abs/2010.11929
* VOLO (Vision Outlooker) - https://arxiv.org/abs/2106.13112
* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
* Xception - https://arxiv.org/abs/1610.02357
* Xception (Modified Aligned, Gluon) - https://arxiv.org/abs/1802.02611

@ -44,8 +44,14 @@ default_cfgs = dict(
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''),
convnext_tiny_hnf=_cfg(url=''),
convnext_tiny_hnf=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95),
convnext_tiny_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth'),
convnext_small_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth'),
convnext_base_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth'),
convnext_large_in22ft1k=_cfg(
@ -53,6 +59,12 @@ default_cfgs = dict(
convnext_xlarge_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth'),
convnext_tiny_384_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
convnext_small_384_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
convnext_base_384_in22ft1k=_cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
@ -63,6 +75,10 @@ default_cfgs = dict(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
convnext_tiny_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
convnext_small_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
convnext_base_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
convnext_large_in22k=_cfg(
@ -322,6 +338,8 @@ def _init_weights(module, name=None, head_init_scale=1.0):
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
return state_dict # non-FB checkpoint
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
@ -401,6 +419,20 @@ def convnext_large(pretrained=False, **kwargs):
return model
@register_model
def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_small_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_base_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@ -422,6 +454,20 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
return model
@register_model
def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_small_384_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_base_384_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@ -443,6 +489,20 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
return model
@register_model
def convnext_tiny_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_small_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_base_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs)
@ -462,6 +522,3 @@ def convnext_xlarge_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args)
return model

@ -15,25 +15,17 @@ except ImportError:
has_fx_feature_extraction = False
# Layers we went to treat as leaf modules
from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath
from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2
from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
from .layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from .layers.non_local_attn import BilinearAttnTransform
from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
BatchNormAct2d, # reason: flow control for jit scripting
BilinearAttnTransform, # reason: flow control t <= 1
BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1]
# Reason: get_same_padding has a max which raises a control flow error
Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
DropPath, # reason: TypeError: rand recieved Proxy in `size` argument
EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?)
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a,
}
try:

@ -20,7 +20,7 @@ from torch.utils.checkpoint import checkpoint
from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
from .fx_features import FeatureGraphNet
from .hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from .layers import Conv2dSame, Linear
from .layers import Conv2dSame, Linear, BatchNormAct2d
from .registry import get_pretrained_cfg
@ -374,12 +374,19 @@ def adapt_model_from_string(parent_module, model_string):
bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
groups=g, stride=old_module.stride)
set_layer(new_module, n, new_conv)
if isinstance(old_module, nn.BatchNorm2d):
elif isinstance(old_module, BatchNormAct2d):
new_bn = BatchNormAct2d(
state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
new_bn.drop = old_module.drop
new_bn.act = old_module.act
set_layer(new_module, n, new_bn)
elif isinstance(old_module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(
num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
affine=old_module.affine, track_running_stats=True)
set_layer(new_module, n, new_bn)
if isinstance(old_module, nn.Linear):
elif isinstance(old_module, nn.Linear):
# FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
num_features = state_dict[n + '.weight'][1]
new_fc = Linear(

@ -39,4 +39,4 @@ class BlurPool2d(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding, 'reflect')
return F.conv2d(x, self.filt, stride=self.stride, groups=x.shape[1])
return F.conv2d(x, self.filt, stride=self.stride, groups=self.channels)

@ -107,13 +107,13 @@ class DropBlock2d(nn.Module):
def __init__(
self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
drop_prob: float = 0.1,
block_size: int = 7,
gamma_scale: float = 1.0,
with_noise: bool = False,
inplace: bool = False,
batchwise: bool = False,
fast: bool = True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
@ -157,7 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
self.scale_by_keep = scale_by_keep

@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(x_dtype)
return rms.expand(x.shape).reshape(B, C, H, W)
@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = (x + 1) * instance_rms(x, self.eps)
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dB2(nn.Module):
@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module):
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1)))
var.detach().to(self.running_var.dtype) * self.momentum * (n / (n - 1)))
else:
var = self.running_var
var = var.to(dtype=x_dtype).view(v_shape)
var = var.to(x_dtype).view(v_shape)
left = var.add(self.eps).sqrt_()
right = instance_rms(x, self.eps) - x
x = x / left.max(right)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS0(nn.Module):
@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS0a(EvoNorm2dS0):
@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0):
v_shape = (1, -1, 1, 1)
d = group_std(x, self.groups, self.eps)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
v = self.v.view(v_shape).to(x_dtype)
x = x * (x * v).sigmoid()
x = x / d
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS1(nn.Module):
@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS1a(EvoNorm2dS1):
@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_std(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS2(nn.Module):
@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module):
v_shape = (1, -1, 1, 1)
if self.apply_act:
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)
class EvoNorm2dS2a(EvoNorm2dS2):
@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2):
x_dtype = x.dtype
v_shape = (1, -1, 1, 1)
x = self.act(x) / group_rms(x, self.groups, self.eps)
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
return x * self.weight.view(v_shape).to(x_dtype) + self.bias.view(v_shape).to(x_dtype)

@ -6,6 +6,7 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from .trace_utils import _assert
from .create_act import get_act_layer
@ -29,9 +30,10 @@ class BatchNormAct2d(nn.BatchNorm2d):
else:
self.act = nn.Identity()
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
def forward(self, x):
# cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
_assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
@ -63,7 +65,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
x = F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
@ -74,17 +76,6 @@ class BatchNormAct2d(nn.BatchNorm2d):
exponential_average_factor,
self.eps,
)
@torch.jit.ignore
def _forward_python(self, x):
return super(BatchNormAct2d, self).forward(x)
def forward(self, x):
# FIXME cannot call parent forward() and maintain jit.script compatibility?
if torch.jit.is_scripting():
x = self._forward_jit(x)
else:
x = self._forward_python(x)
x = self.drop(x)
x = self.act(x)
return x

@ -155,10 +155,10 @@ default_cfgs = dict(
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_040_ra3-c248f51f.pth',
first_conv='stem'),
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetv_064=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetv_064_ra3-530616c2.pth',
first_conv='stem'),
first_conv='stem', crop_pct=1.0, test_input_size=(3, 288, 288)),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(

@ -4,6 +4,9 @@ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shi
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
S3 (AutoFormerV2, https://arxiv.org/abs/2111.14725) Swin weights from
- https://github.com/microsoft/Cream/tree/main/AutoFormerV2
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# --------------------------------------------------------
@ -669,7 +672,7 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
@register_model
def swin_s3_tiny_224(pretrained=False, **kwargs):
""" Swin-S3-T @ 224x224, ImageNet-1k
""" Swin-S3-T @ 224x224, ImageNet-1k. https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 6, 2),
@ -679,7 +682,7 @@ def swin_s3_tiny_224(pretrained=False, **kwargs):
@register_model
def swin_s3_small_224(pretrained=False, **kwargs):
""" Swin-S3-S @ 224x224, trained ImageNet-1k
""" Swin-S3-S @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(14, 14, 14, 7), embed_dim=96, depths=(2, 2, 18, 2),
@ -689,7 +692,7 @@ def swin_s3_small_224(pretrained=False, **kwargs):
@register_model
def swin_s3_base_224(pretrained=False, **kwargs):
""" Swin-S3-B @ 224x224, trained ImageNet-1k
""" Swin-S3-B @ 224x224, trained ImageNet-1k. https://arxiv.org/abs/2111.14725
"""
model_kwargs = dict(
patch_size=4, window_size=(7, 7, 14, 7), embed_dim=96, depths=(2, 2, 30, 2),

@ -170,6 +170,11 @@ default_cfgs = {
'/vit_base_patch16_224_1k_miil_84_4.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
),
# experimental
'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''),
}
@ -201,28 +206,81 @@ class Attention(nn.Module):
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Block(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path1(self.attn(self.norm1(x)))
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x))))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class ParallelBlock(nn.Module):
def __init__(
self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.num_parallel = num_parallel
self.attns = nn.ModuleList()
self.ffns = nn.ModuleList()
for _ in range(num_parallel):
self.attns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('attn', Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
self.ffns.append(nn.Sequential(OrderedDict([
('norm', norm_layer(dim)),
('mlp', Mlp(dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)),
('ls', LayerScale(dim, init_values=init_values) if init_values else nn.Identity()),
('drop_path', DropPath(drop_path) if drop_path > 0. else nn.Identity())
])))
def _forward_jit(self, x):
x = x + torch.stack([attn(x) for attn in self.attns]).sum(dim=0)
x = x + torch.stack([ffn(x) for ffn in self.ffns]).sum(dim=0)
return x
@torch.jit.ignore
def _forward(self, x):
x = x + sum(attn(x) for attn in self.attns)
x = x + sum(ffn(x) for ffn in self.ffns)
return x
def forward(self, x):
if torch.jit.is_scripting() or torch.jit.is_tracing():
return self._forward_jit(x)
else:
return self._forward(x)
class VisionTransformer(nn.Module):
""" Vision Transformer
@ -233,8 +291,8 @@ class VisionTransformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None):
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None,
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
"""
Args:
img_size (int, tuple): input image size
@ -248,10 +306,11 @@ class VisionTransformer(nn.Module):
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
weight_init: (str): weight init scheme
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init: (str): weight init scheme
init_values: (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
@ -277,9 +336,9 @@ class VisionTransformer(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
@ -941,3 +1000,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Small w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
return model

@ -30,7 +30,16 @@ class PlateauLRScheduler(Scheduler):
noise_seed=None,
initialize=True,
):
super().__init__(optimizer, 'lr', initialize=initialize)
super().__init__(
optimizer,
'lr',
noise_range_t=noise_range_t,
noise_type=noise_type,
noise_pct=noise_pct,
noise_std=noise_std,
noise_seed=noise_seed,
initialize=initialize,
)
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
@ -43,11 +52,6 @@ class PlateauLRScheduler(Scheduler):
min_lr=lr_min
)
self.noise_range_t = noise_range_t
self.noise_pct = noise_pct
self.noise_type = noise_type
self.noise_std = noise_std
self.noise_seed = noise_seed if noise_seed is not None else 42
self.warmup_t = warmup_t
self.warmup_lr_init = warmup_lr_init
if self.warmup_t:
@ -84,7 +88,6 @@ class PlateauLRScheduler(Scheduler):
if self._is_apply_noise(epoch):
self._apply_noise(epoch)
def _apply_noise(self, epoch):
noise = self._calculate_noise(epoch)

Loading…
Cancel
Save