Refactoring, cleanup, improved test coverage.

* Add eca_nfnet_l2 weights, 84.7 @ 384x384
* All 'non-std' (ie transformer / mlp) models have classifier / default_cfg test added
* Fix #694 reset_classifer / num_features / forward_features / num_classes=0 consistency for transformer / mlp models
* Add direct loading of npz to vision transformer (pure transformer so far, hybrid to come)
* Rename vit_deit* to deit_*
* Remove some deprecated vit hybrid model defs
* Clean up classifier flatten for conv classifiers and unusual cases (mobilenetv3/ghostnet)
* Remove explicit model fns for levit conv, just pass in arg
cleanup_xla_model_fixes
Ross Wightman 4 years ago
parent ba2ca4b464
commit 8880f696b6

@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities # transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*'] 'convit_*', 'levit*', 'visformer*', 'deit*']
NUM_NON_STD = len(NON_STD_FILTERS) NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures # exclude models that cause specific test failures
@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size):
state_dict = model.state_dict() state_dict = model.state_dict()
cfg = model.default_cfg cfg = model.default_cfg
classifier = cfg['classifier']
pool_size = cfg['pool_size'] pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size'] input_size = model.default_cfg['input_size']
@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size):
assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
# check classifier name matches default_cfg # check classifier name matches default_cfg
assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
classifier = classifier,
for c in classifier:
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
# check first conv(s) names match default_cfg
first_conv = cfg['first_conv']
if isinstance(first_conv, str):
first_conv = (first_conv,)
assert isinstance(first_conv, (tuple, list))
for fc in first_conv:
assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
state_dict = model.state_dict()
cfg = model.default_cfg
input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled)
outputs = model.forward_features(input_tensor)
if isinstance(outputs, tuple):
outputs = outputs[0]
assert outputs.shape[1] == model.num_features
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
outputs = model.forward(input_tensor)
if isinstance(outputs, tuple):
outputs = outputs[0]
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
# check classifier name matches default_cfg
classifier = cfg['classifier']
if not isinstance(classifier, (tuple, list)):
classifier = classifier,
for c in classifier:
assert c + ".weight" in state_dict.keys(), f'{c} not in model params'
# check first conv(s) names match default_cfg # check first conv(s) names match default_cfg
first_conv = cfg['first_conv'] first_conv = cfg['first_conv']

@ -74,11 +74,11 @@ default_cfgs = dict(
class ClassAttn(nn.Module): class ClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to do CA # with slight modifications to do CA
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=qkv_bias) self.q = nn.Linear(dim, dim, bias=qkv_bias)
self.k = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(dim, dim, bias=qkv_bias)
@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add CA and LayerScale # with slight modifications to add CA and LayerScale
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
mlp_block=Mlp, init_values=1e-4): mlp_block=Mlp, init_values=1e-4):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = attn_block( self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module):
class TalkingHeadAttn(nn.Module): class TalkingHeadAttn(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module):
# taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# with slight modifications to add layerScale # with slight modifications to add layerScale
def __init__( def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
mlp_block=Mlp, init_values=1e-4): mlp_block=Mlp, init_values=1e-4):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = attn_block( self.attn = attn_block(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
@ -202,7 +202,7 @@ class Cait(nn.Module):
# with slight modifications to adapt to our cait models # with slight modifications to adapt to our cait models
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), norm_layer=partial(nn.LayerNorm, eps=1e-6),
global_pool=None, global_pool=None,
@ -235,14 +235,14 @@ class Cait(nn.Module):
dpr = [drop_path_rate for i in range(depth)] dpr = [drop_path_rate for i in range(depth)]
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
block_layers( block_layers(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale) act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale)
for i in range(depth)]) for i in range(depth)])
self.blocks_token_only = nn.ModuleList([ self.blocks_token_only = nn.ModuleList([
block_layers_token( block_layers_token(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias,
drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer,
act_layer=act_layer, attn_block=attn_block_token_only, act_layer=act_layer, attn_block=attn_block_token_only,
mlp_block=mlp_block_token_only, init_values=init_scale) mlp_block=mlp_block_token_only, init_values=init_scale)
@ -270,6 +270,13 @@ class Cait(nn.Module):
def no_weight_decay(self): def no_weight_decay(self):
return {'pos_embed', 'cls_token'} return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
B = x.shape[0] B = x.shape[0]
x = self.patch_embed(x) x = self.patch_embed(x)
@ -293,7 +300,6 @@ class Cait(nn.Module):
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x = self.head(x) x = self.head(x)
return x return x

@ -335,6 +335,8 @@ class CoaT(nn.Module):
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers self.return_interm_layers = return_interm_layers
self.out_features = out_features self.out_features = out_features
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.num_classes = num_classes self.num_classes = num_classes
# Patch embeddings. # Patch embeddings.
@ -441,10 +443,10 @@ class CoaT(nn.Module):
# CoaT series: Aggregate features of last three scales for classification. # CoaT series: Aggregate features of last three scales for classification.
assert embed_dims[1] == embed_dims[2] == embed_dims[3] assert embed_dims[1] == embed_dims[2] == embed_dims[3]
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
self.head = nn.Linear(embed_dims[3], num_classes) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
else: else:
# CoaT-Lite series: Use feature of last scale for classification. # CoaT-Lite series: Use feature of last scale for classification.
self.head = nn.Linear(embed_dims[3], num_classes) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# Initialize weights. # Initialize weights.
trunc_normal_(self.cls_token1, std=.02) trunc_normal_(self.cls_token1, std=.02)
@ -471,7 +473,7 @@ class CoaT(nn.Module):
def reset_classifier(self, num_classes, global_pool=''): def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def insert_cls(self, x, cls_token): def insert_cls(self, x, cls_token):
""" Insert CLS token. """ """ Insert CLS token. """

@ -57,13 +57,13 @@ default_cfgs = {
class GPSA(nn.Module): class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
locality_strength=1.): locality_strength=1.):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
self.dim = dim self.dim = dim
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.locality_strength = locality_strength self.locality_strength = locality_strength
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
@ -142,11 +142,11 @@ class GPSA(nn.Module):
class MHSA(nn.Module): class MHSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
@ -191,19 +191,16 @@ class MHSA(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.use_gpsa = use_gpsa self.use_gpsa = use_gpsa
if self.use_gpsa: if self.use_gpsa:
self.attn = GPSA( self.attn = GPSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs)
proj_drop=drop, **kwargs)
else: else:
self.attn = MHSA( self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, **kwargs)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) mlp_hidden_dim = int(dim * mlp_ratio)
@ -220,7 +217,7 @@ class ConViT(nn.Module):
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True): local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
super().__init__() super().__init__()
@ -250,13 +247,13 @@ class ConViT(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
Block( Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=True, use_gpsa=True,
locality_strength=locality_strength) locality_strength=locality_strength)
if i < local_up_to_layer else if i < local_up_to_layer else
Block( Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=False) use_gpsa=False)
for i in range(depth)]) for i in range(depth)])

@ -288,6 +288,8 @@ class DLA(nn.Module):
self.num_features = channels[-1] self.num_features = channels[-1]
self.global_pool, self.fc = create_classifier( self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@ -314,6 +316,7 @@ class DLA(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool, self.fc = create_classifier( self.global_pool, self.fc = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.base_layer(x) x = self.base_layer(x)
@ -331,8 +334,7 @@ class DLA(nn.Module):
if self.drop_rate > 0.: if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training) x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.fc(x) x = self.fc(x)
if not self.global_pool.is_identity(): x = self.flatten(x)
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x return x

@ -237,6 +237,7 @@ class DPN(nn.Module):
# Using 1x1 conv for the FC layer to allow the extra pooling scheme # Using 1x1 conv for the FC layer to allow the extra pooling scheme
self.global_pool, self.classifier = create_classifier( self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def get_classifier(self): def get_classifier(self):
return self.classifier return self.classifier
@ -245,6 +246,7 @@ class DPN(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool, self.classifier = create_classifier( self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
return self.features(x) return self.features(x)
@ -255,8 +257,7 @@ class DPN(nn.Module):
if self.drop_rate > 0.: if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training) x = F.dropout(x, p=self.drop_rate, training=self.training)
x = self.classifier(x) x = self.classifier(x)
if not self.global_pool.is_identity(): x = self.flatten(x)
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x return x

@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module):
class GhostNet(nn.Module): class GhostNet(nn.Module):
def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32): def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'):
super(GhostNet, self).__init__() super(GhostNet, self).__init__()
# setting of inverted residual blocks # setting of inverted residual blocks
assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
@ -178,9 +178,10 @@ class GhostNet(nn.Module):
# building last several layers # building last several layers
self.num_features = out_chs = 1280 self.num_features = out_chs = 1280
self.global_pool = SelectAdaptivePool2d(pool_type='avg') self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
self.act2 = nn.ReLU(inplace=True) self.act2 = nn.ReLU(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(out_chs, num_classes) self.classifier = Linear(out_chs, num_classes)
def get_classifier(self): def get_classifier(self):
@ -190,6 +191,7 @@ class GhostNet(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation # cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
@ -204,8 +206,7 @@ class GhostNet(nn.Module):
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
if not self.global_pool.is_identity(): x = self.flatten(x)
x = x.view(x.size(0), -1)
if self.dropout > 0.: if self.dropout > 0.:
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
x = self.classifier(x) x = self.classifier(x)

@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False):
def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True):
if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
# numpy checkpoint, try to load via model specific load_pretrained fn
if hasattr(model, 'load_pretrained'):
model.load_pretrained(checkpoint_path)
else:
raise NotImplementedError('Model cannot load numpy checkpoint')
return
state_dict = load_state_dict(checkpoint_path, use_ema) state_dict = load_state_dict(checkpoint_path, use_ema)
model.load_state_dict(state_dict, strict=strict) model.load_state_dict(state_dict, strict=strict)
@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False):
return [p for p in model.parameters()][:-2] return [p for p in model.parameters()][:-2]
else: else:
return model.parameters() return model.parameters()
def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
def named_modules(module: nn.Module, name='', depth_first=True, include_root=False):
if not depth_first and include_root:
yield name, module
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
yield from named_modules(
module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
yield name, module

@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module):
self.flatten = flatten self.flatten = flatten
def forward(self, x): def forward(self, x):
return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) return x.mean((2, 3), keepdim=not self.flatten)
class AdaptiveAvgMaxPool2d(nn.Module): class AdaptiveAvgMaxPool2d(nn.Module):
@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module):
def __init__(self, output_size=1, pool_type='fast', flatten=False): def __init__(self, output_size=1, pool_type='fast', flatten=False):
super(SelectAdaptivePool2d, self).__init__() super(SelectAdaptivePool2d, self).__init__()
self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing
self.flatten = flatten self.flatten = nn.Flatten(1) if flatten else nn.Identity()
if pool_type == '': if pool_type == '':
self.pool = nn.Identity() # pass through self.pool = nn.Identity() # pass through
elif pool_type == 'fast': elif pool_type == 'fast':
assert output_size == 1 assert output_size == 1
self.pool = FastAdaptiveAvgPool2d(self.flatten) self.pool = FastAdaptiveAvgPool2d(flatten)
self.flatten = False self.flatten = nn.Identity()
elif pool_type == 'avg': elif pool_type == 'avg':
self.pool = nn.AdaptiveAvgPool2d(output_size) self.pool = nn.AdaptiveAvgPool2d(output_size)
elif pool_type == 'avgmax': elif pool_type == 'avgmax':
@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module):
assert False, 'Invalid pool type: %s' % pool_type assert False, 'Invalid pool type: %s' % pool_type
def is_identity(self): def is_identity(self):
return self.pool_type == '' return not self.pool_type
def forward(self, x): def forward(self, x):
x = self.pool(x) x = self.pool(x)
if self.flatten: x = self.flatten(x)
x = x.flatten(1)
return x return x
def feat_mult(self): def feat_mult(self):

@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
return global_pool, num_pooled_features return global_pool, num_pooled_features
def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): def _create_fc(num_features, num_classes, use_conv=False):
if num_classes <= 0: if num_classes <= 0:
fc = nn.Identity() # pass-through (no classifier) fc = nn.Identity() # pass-through (no classifier)
elif use_conv: elif use_conv:
@ -45,11 +45,12 @@ class ClassifierHead(nn.Module):
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten_after_fc = use_conv and pool_type self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def forward(self, x): def forward(self, x):
x = self.global_pool(x) x = self.global_pool(x)
if self.drop_rate: if self.drop_rate:
x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = F.dropout(x, p=float(self.drop_rate), training=self.training)
x = self.fc(x) x = self.fc(x)
x = self.flatten(x)
return x return x

@ -40,6 +40,12 @@ class GluMlp(nn.Module):
self.fc2 = nn.Linear(hidden_features // 2, out_features) self.fc2 = nn.Linear(hidden_features // 2, out_features)
self.drop = nn.Dropout(drop) self.drop = nn.Dropout(drop)
def init_weights(self):
# override init of fc1 w/ gate portion set to weight near zero, bias=1
fc1_mid = self.fc1.bias.shape[0] // 2
nn.init.ones_(self.fc1.bias[fc1_mid:])
nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)
x, gates = x.chunk(2, dim=-1) x, gates = x.chunk(2, dim=-1)

@ -84,63 +84,33 @@ __all__ = ['Levit']
@register_model @register_model
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): def levit_128s(pretrained=False, use_conv=False, **kwargs):
return create_levit( return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): def levit_128(pretrained=False, use_conv=False, **kwargs):
return create_levit( return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): def levit_192(pretrained=False, use_conv=False, **kwargs):
return create_levit( return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): def levit_256(pretrained=False, use_conv=False, **kwargs):
return create_levit( return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model @register_model
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): def levit_384(pretrained=False, use_conv=False, **kwargs):
return create_levit( return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential): class ConvNorm(nn.Sequential):
@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module):
class Levit(nn.Module): class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage """ Vision Transformer with support for patch or hybrid CNN input stage
NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
w/ train scripts that don't take tuple outputs,
""" """
def __init__( def __init__(
@ -447,7 +420,8 @@ class Levit(nn.Module):
attn_act_layer='hard_swish', attn_act_layer='hard_swish',
distillation=True, distillation=True,
use_conv=False, use_conv=False,
drop_path=0): drop_rate=0.,
drop_path_rate=0.):
super().__init__() super().__init__()
act_layer = get_act_layer(act_layer) act_layer = get_act_layer(act_layer)
attn_act_layer = get_act_layer(attn_act_layer) attn_act_layer = get_act_layer(attn_act_layer)
@ -486,7 +460,7 @@ class Levit(nn.Module):
Attention( Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv), resolution=resolution, use_conv=use_conv),
drop_path)) drop_path_rate))
if mr > 0: if mr > 0:
h = int(ed * mr) h = int(ed * mr)
self.blocks.append( self.blocks.append(
@ -494,7 +468,7 @@ class Levit(nn.Module):
ln_layer(ed, h, resolution=resolution), ln_layer(ed, h, resolution=resolution),
act_layer(), act_layer(),
ln_layer(h, ed, bn_weight_init=0, resolution=resolution), ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path)) ), drop_path_rate))
if do[0] == 'Subsample': if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1 resolution_ = (resolution - 1) // do[5] + 1
@ -511,26 +485,45 @@ class Levit(nn.Module):
ln_layer(embed_dim[i + 1], h, resolution=resolution), ln_layer(embed_dim[i + 1], h, resolution=resolution),
act_layer(), act_layer(),
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path)) ), drop_path_rate))
self.blocks = nn.Sequential(*self.blocks) self.blocks = nn.Sequential(*self.blocks)
# Classifier head # Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = None
if distillation: if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x} return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x): def get_classifier(self):
if self.head_dist is None:
return self.head
else:
return self.head, self.head_dist
def reset_classifier(self, num_classes, global_pool='', distillation=None):
self.num_classes = num_classes
self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation is not None:
self.distillation = distillation
if self.distillation:
self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
if not self.use_conv: if not self.use_conv:
x = x.flatten(2).transpose(1, 2) x = x.flatten(2).transpose(1, 2)
x = self.blocks(x) x = self.blocks(x)
x = x.mean((-2, -1)) if self.use_conv else x.mean(1) x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
return x
def forward(self, x):
x = self.forward_features(x)
if self.head_dist is not None: if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x) x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting(): if self.training and not torch.jit.is_scripting():

@ -45,7 +45,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply
from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
from .registry import register_model from .registry import register_model
@ -169,6 +169,11 @@ class SpatialGatingUnit(nn.Module):
self.norm = norm_layer(gate_dim) self.norm = norm_layer(gate_dim)
self.proj = nn.Linear(seq_len, seq_len) self.proj = nn.Linear(seq_len, seq_len)
def init_weights(self):
# special init for the projection gate, called as override by base model init
nn.init.normal_(self.proj.weight, std=1e-6)
nn.init.ones_(self.proj.bias)
def forward(self, x): def forward(self, x):
u, v = x.chunk(2, dim=-1) u, v = x.chunk(2, dim=-1)
v = self.norm(v) v = self.norm(v)
@ -205,7 +210,7 @@ class MlpMixer(nn.Module):
in_chans=3, in_chans=3,
patch_size=16, patch_size=16,
num_blocks=8, num_blocks=8,
hidden_dim=512, embed_dim=512,
mlp_ratio=(0.5, 4.0), mlp_ratio=(0.5, 4.0),
block_layer=MixerBlock, block_layer=MixerBlock,
mlp_layer=Mlp, mlp_layer=Mlp,
@ -218,59 +223,71 @@ class MlpMixer(nn.Module):
): ):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.stem = PatchEmbed( self.stem = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
norm_layer=norm_layer if stem_norm else None) embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None)
# FIXME drop_path (stochastic depth scaling rule or all the same?) # FIXME drop_path (stochastic depth scaling rule or all the same?)
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
block_layer( block_layer(
hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer,
act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate)
for _ in range(num_blocks)]) for _ in range(num_blocks)])
self.norm = norm_layer(hidden_dim) self.norm = norm_layer(embed_dim)
self.head = nn.Linear(hidden_dim, self.num_classes) # zero init self.head = nn.Linear(embed_dim, self.num_classes) # zero init
self.init_weights(nlhb=nlhb) self.init_weights(nlhb=nlhb)
def init_weights(self, nlhb=False): def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0. head_bias = -math.log(self.num_classes) if nlhb else 0.
for n, m in self.named_modules(): named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
_init_weights(m, n, head_bias=head_bias)
def forward(self, x): def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
x = self.blocks(x) x = self.blocks(x)
x = self.norm(x) x = self.norm(x)
x = x.mean(dim=1) x = x.mean(dim=1)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x) x = self.head(x)
return x return x
def _init_weights(m, n: str, head_bias: float = 0.): def _init_weights(module: nn.Module, name: str, head_bias: float = 0.):
""" Mixer weight initialization (trying to match Flax defaults) """ Mixer weight initialization (trying to match Flax defaults)
""" """
if isinstance(m, nn.Linear): if isinstance(module, nn.Linear):
if n.startswith('head'): if name.startswith('head'):
nn.init.zeros_(m.weight) nn.init.zeros_(module.weight)
nn.init.constant_(m.bias, head_bias) nn.init.constant_(module.bias, head_bias)
elif n.endswith('gate.proj'):
nn.init.normal_(m.weight, std=1e-4)
nn.init.ones_(m.bias)
else: else:
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(module.weight)
if m.bias is not None: if module.bias is not None:
if 'mlp' in n: if 'mlp' in name:
nn.init.normal_(m.bias, std=1e-6) nn.init.normal_(module.bias, std=1e-6)
else: else:
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
elif isinstance(m, nn.Conv2d): elif isinstance(module, nn.Conv2d):
lecun_normal_(m.weight) lecun_normal_(module.weight)
if m.bias is not None: if module.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
elif isinstance(m, nn.LayerNorm): elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
nn.init.zeros_(m.bias) nn.init.ones_(module.weight)
nn.init.ones_(m.weight) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
# NOTE if a parent module contains init_weights method, it can override the init of the
# child modules as this will be called in depth-first order.
module.init_weights()
def _create_mixer(variant, pretrained=False, **kwargs): def _create_mixer(variant, pretrained=False, **kwargs):
@ -289,7 +306,7 @@ def mixer_s32_224(pretrained=False, **kwargs):
""" Mixer-S/32 224x224 """ Mixer-S/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs) model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
return model return model
@ -299,7 +316,7 @@ def mixer_s16_224(pretrained=False, **kwargs):
""" Mixer-S/16 224x224 """ Mixer-S/16 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs) model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
return model return model
@ -309,7 +326,7 @@ def mixer_b32_224(pretrained=False, **kwargs):
""" Mixer-B/32 224x224 """ Mixer-B/32 224x224
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs) model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
return model return model
@ -319,7 +336,7 @@ def mixer_b16_224(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights. """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
return model return model
@ -329,7 +346,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights. """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args)
return model return model
@ -339,7 +356,7 @@ def mixer_l32_224(pretrained=False, **kwargs):
""" Mixer-L/32 224x224. """ Mixer-L/32 224x224.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs) model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
return model return model
@ -349,7 +366,7 @@ def mixer_l16_224(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-1k pretrained weights. """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
return model return model
@ -359,35 +376,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs):
""" Mixer-L/16 224x224. ImageNet-21k pretrained weights. """ Mixer-L/16 224x224. ImageNet-21k pretrained weights.
Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
""" """
model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def mixer_b16_224_miil(pretrained=False, **kwargs): def mixer_b16_224_miil(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights. """ Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights. """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
""" """
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def gmixer_12_224(pretrained=False, **kwargs): def gmixer_12_224(pretrained=False, **kwargs):
""" Glu-Mixer-12 224x224 (short & fat) """ Glu-Mixer-12 224x224 (short & fat)
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
""" """
model_args = dict( model_args = dict(
patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0), patch_size=16, num_blocks=12, embed_dim=512, mlp_ratio=(1.0, 6.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
return model return model
@ -399,7 +419,7 @@ def gmixer_24_224(pretrained=False, **kwargs):
Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer
""" """
model_args = dict( model_args = dict(
patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0), patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 6.0),
mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
return model return model
@ -411,7 +431,7 @@ def resmlp_12_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
return model return model
@ -422,7 +442,7 @@ def resmlp_24_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
return model return model
@ -434,7 +454,7 @@ def resmlp_36_224(pretrained=False, **kwargs):
Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
return model return model
@ -446,7 +466,7 @@ def gmlp_ti16_224(pretrained=False, **kwargs):
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs) mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
return model return model
@ -458,7 +478,7 @@ def gmlp_s16_224(pretrained=False, **kwargs):
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs) mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
return model return model
@ -470,7 +490,7 @@ def gmlp_b16_224(pretrained=False, **kwargs):
Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
""" """
model_args = dict( model_args = dict(
patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
mlp_layer=GatedMlp, **kwargs) mlp_layer=GatedMlp, **kwargs)
model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
return model return model

@ -119,6 +119,7 @@ class MobileNetV3(nn.Module):
num_pooled_chs = head_chs * self.global_pool.feat_mult() num_pooled_chs = head_chs * self.global_pool.feat_mult()
self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias)
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
efficientnet_init_weights(self) efficientnet_init_weights(self)
@ -137,6 +138,7 @@ class MobileNetV3(nn.Module):
self.num_classes = num_classes self.num_classes = num_classes
# cannot meaningfully change pooling of efficient head after creation # cannot meaningfully change pooling of efficient head after creation
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
@ -151,8 +153,7 @@ class MobileNetV3(nn.Module):
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
if not self.global_pool.is_identity(): x = self.flatten(x)
x = x.flatten(1)
if self.drop_rate > 0.: if self.drop_rate > 0.:
x = F.dropout(x, p=self.drop_rate, training=self.training) x = F.dropout(x, p=self.drop_rate, training=self.training)
return self.classifier(x) return self.classifier(x)

@ -111,11 +111,11 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
eca_nfnet_l2=_dcfg( eca_nfnet_l2=_dcfg(
url='', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth',
pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
eca_nfnet_l3=_dcfg( eca_nfnet_l3=_dcfg(
url='', url='',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0),
nf_regnet_b0=_dcfg( nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@ -210,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski
return cfg return cfg
model_cfgs = dict( model_cfgs = dict(
# NFNet-F models w/ GELU compatible with DeepMind weights # NFNet-F models w/ GELU compatible with DeepMind weights
dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)),

@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module):
] ]
self.transformers = SequentialTuple(*transformers) self.transformers = SequentialTuple(*transformers)
self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
self.embed_dim = base_dims[-1] * heads[-1] self.num_features = self.embed_dim = base_dims[-1] * heads[-1]
# Classifier head # Classifier head
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ self.head_dist = None
if num_classes > 0 and distilled else nn.Identity() if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.cls_token, std=.02)
@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module):
return {'pos_embed', 'cls_token'} return {'pos_embed', 'cls_token'}
def get_classifier(self): def get_classifier(self):
return self.head if self.head_dist is not None:
return self.head, self.head_dist
else:
return self.head
def reset_classifier(self, num_classes, global_pool=''): def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ if self.head_dist is not None:
if num_classes > 0 and self.num_tokens == 2 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module):
cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
x, cls_tokens = self.transformers((x, cls_tokens)) x, cls_tokens = self.transformers((x, cls_tokens))
cls_tokens = self.norm(cls_tokens) cls_tokens = self.norm(cls_tokens)
return cls_tokens if self.head_dist is not None:
return cls_tokens[:, 0], cls_tokens[:, 1]
else:
return cls_tokens[:, 0]
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x_cls = self.head(x[:, 0]) if self.head_dist is not None:
if self.num_tokens > 1: x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple
x_dist = self.head_dist(x[:, 1])
if self.training and not torch.jit.is_scripting(): if self.training and not torch.jit.is_scripting():
return x_cls, x_dist return x, x_dist
else: else:
return (x_cls + x_dist) / 2 return (x + x_dist) / 2
else: else:
return x_cls return self.head(x)
def checkpoint_filter_fn(state_dict, model): def checkpoint_filter_fn(state_dict, model):

@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
""" """
if module: if module:
models = list(_module_to_models[module]) all_models = list(_module_to_models[module])
else: else:
models = _model_entrypoints.keys() all_models = _model_entrypoints.keys()
if filter: if filter:
models = fnmatch.filter(models, filter) # include these models models = []
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
for f in include_filters:
include_models = fnmatch.filter(all_models, f) # include these models
if len(include_models):
models = set(models).union(include_models)
else:
models = all_models
if exclude_filters: if exclude_filters:
if not isinstance(exclude_filters, (tuple, list)): if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters] exclude_filters = [exclude_filters]

@ -638,12 +638,15 @@ class ResNet(nn.Module):
self.num_features = 512 * block.expansion self.num_features = 512 * block.expansion
self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.init_weights(zero_init_last_bn=zero_init_last_bn)
def init_weights(self, zero_init_last_bn=True):
for n, m in self.named_modules(): for n, m in self.named_modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1.) nn.init.ones_(m.weight)
nn.init.constant_(m.bias, 0.) nn.init.zeros_(m.bias)
if zero_init_last_bn: if zero_init_last_bn:
for m in self.modules(): for m in self.modules():
if hasattr(m, 'zero_init_last_bn'): if hasattr(m, 'zero_init_last_bn'):

@ -35,9 +35,9 @@ import torch.nn as nn
from functools import partial from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .registry import register_model from .registry import register_model
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
def _cfg(url='', **kwargs): def _cfg(url='', **kwargs):
@ -86,20 +86,10 @@ default_cfgs = {
url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz',
num_classes=21843), num_classes=21843),
'resnetv2_50': _cfg(
# trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'),
# 'resnetv2_50x1_bits': _cfg( 'resnetv2_50d': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'), input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'),
# 'resnetv2_50x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'),
# 'resnetv2_101x1_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
# 'resnetv2_101x3_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'),
# 'resnetv2_152x2_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'),
# 'resnetv2_152x4_bits': _cfg(
# url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'),
} }
@ -111,13 +101,6 @@ def make_div(v, divisor=8):
return new_v return new_v
def tf2th(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:
conv_weights = conv_weights.transpose([3, 2, 0, 1])
return torch.from_numpy(conv_weights)
class PreActBottleneck(nn.Module): class PreActBottleneck(nn.Module):
"""Pre-activation (v2) bottleneck block. """Pre-activation (v2) bottleneck block.
@ -152,6 +135,9 @@ class PreActBottleneck(nn.Module):
self.conv3 = conv_layer(mid_chs, out_chs, 1) self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last_bn(self):
nn.init.zeros_(self.norm3.weight)
def forward(self, x): def forward(self, x):
x_preact = self.norm1(x) x_preact = self.norm1(x)
@ -198,6 +184,9 @@ class Bottleneck(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.act3 = act_layer(inplace=True) self.act3 = act_layer(inplace=True)
def zero_init_last_bn(self):
nn.init.zeros_(self.norm3.weight)
def forward(self, x): def forward(self, x):
# shortcut branch # shortcut branch
shortcut = x shortcut = x
@ -276,7 +265,7 @@ class ResNetStage(nn.Module):
def create_resnetv2_stem( def create_resnetv2_stem(
in_chs, out_chs=64, stem_type='', preact=True, in_chs, out_chs=64, stem_type='', preact=True,
conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)): conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
stem = OrderedDict() stem = OrderedDict()
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
@ -285,14 +274,17 @@ def create_resnetv2_stem(
# A 3 deep 3x3 conv stack as in ResNet V1D models # A 3 deep 3x3 conv stack as in ResNet V1D models
mid_chs = out_chs // 2 mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['norm1'] = norm_layer(mid_chs)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
stem['norm2'] = norm_layer(mid_chs)
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
if not preact:
stem['norm3'] = norm_layer(out_chs)
else: else:
# The usual 7x7 stem conv # The usual 7x7 stem conv
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
if not preact:
if not preact: stem['norm'] = norm_layer(out_chs)
stem['norm'] = norm_layer(out_chs)
if 'fixed' in stem_type: if 'fixed' in stem_type:
# 'fixed' SAME padding approximation that is used in BiT models # 'fixed' SAME padding approximation that is used in BiT models
@ -312,11 +304,12 @@ class ResNetV2(nn.Module):
"""Implementation of Pre-activation (v2) ResNet mode. """Implementation of Pre-activation (v2) ResNet mode.
""" """
def __init__(self, layers, channels=(256, 512, 1024, 2048), def __init__(
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, self, layers, channels=(256, 512, 1024, 2048),
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8), width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.): act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last_bn=True):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
@ -354,12 +347,14 @@ class ResNetV2(nn.Module):
self.head = ClassifierHead( self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True)
for n, m in self.named_modules(): self.init_weights(zero_init_last_bn=zero_init_last_bn)
if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)):
nn.init.normal_(m.weight, mean=0.0, std=0.01) def init_weights(self, zero_init_last_bn=True):
nn.init.zeros_(m.bias) named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') @torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
_load_weights(self, checkpoint_path, prefix)
def get_classifier(self): def get_classifier(self):
return self.head.fc return self.head.fc
@ -378,41 +373,59 @@ class ResNetV2(nn.Module):
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)
x = self.head(x) x = self.head(x)
if not self.head.global_pool.is_identity():
x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled)
return x return x
def load_pretrained(self, checkpoint_path, prefix='resnet/'):
import numpy as np def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True):
weights = np.load(checkpoint_path) if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)):
with torch.no_grad(): nn.init.normal_(module.weight, mean=0.0, std=0.01)
stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']) nn.init.zeros_(module.bias)
if self.stem.conv.weight.shape[1] == 1: elif isinstance(module, nn.Conv2d):
self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True)) nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
# FIXME handle > 3 in_chans? if module.bias is not None:
else: nn.init.zeros_(module.bias)
self.stem.conv.weight.copy_(stem_conv_w) elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)):
self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) nn.init.ones_(module.weight)
self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) nn.init.zeros_(module.bias)
if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'):
self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) module.zero_init_last_bn()
self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(self.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()): @torch.no_grad()
convname = 'standardized_conv2d' def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'):
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' import numpy as np
block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel']))
block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel'])) def t2p(conv_weights):
block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel'])) """Possibly convert HWIO to OIHW."""
block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma'])) if conv_weights.ndim == 4:
block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma'])) conv_weights = conv_weights.transpose([3, 2, 0, 1])
block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma'])) return torch.from_numpy(conv_weights)
block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta']))
block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta'])) weights = np.load(checkpoint_path)
block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta'])) stem_conv_w = adapt_input_conv(
if block.downsample is not None: model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel']))
w = weights[f'{block_prefix}a/proj/{convname}/kernel'] model.stem.conv.weight.copy_(stem_conv_w)
block.downsample.conv.weight.copy_(tf2th(w)) model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma']))
model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta']))
if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]:
model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel']))
model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias']))
for i, (sname, stage) in enumerate(model.stages.named_children()):
for j, (bname, block) in enumerate(stage.blocks.named_children()):
cname = 'standardized_conv2d'
block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/'
block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel']))
block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel']))
block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel']))
block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma']))
block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma']))
block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma']))
block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta']))
block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta']))
block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta']))
if block.downsample is not None:
w = weights[f'{block_prefix}a/proj/{cname}/kernel']
block.downsample.conv.weight.copy_(t2p(w))
def _create_resnetv2(variant, pretrained=False, **kwargs): def _create_resnetv2(variant, pretrained=False, **kwargs):
@ -425,130 +438,99 @@ def _create_resnetv2(variant, pretrained=False, **kwargs):
**kwargs) **kwargs)
def _create_resnetv2_bit(variant, pretrained=False, **kwargs):
return _create_resnetv2(
variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs)
@register_model @register_model
def resnetv2_50x1_bitm(pretrained=False, **kwargs): def resnetv2_50x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_50x1_bitm', pretrained=pretrained, 'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs)
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_50x3_bitm(pretrained=False, **kwargs): def resnetv2_50x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_50x3_bitm', pretrained=pretrained, 'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs)
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_101x1_bitm(pretrained=False, **kwargs): def resnetv2_101x1_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_101x1_bitm', pretrained=pretrained, 'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs)
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_101x3_bitm(pretrained=False, **kwargs): def resnetv2_101x3_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_101x3_bitm', pretrained=pretrained, 'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs)
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_152x2_bitm(pretrained=False, **kwargs): def resnetv2_152x2_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_152x2_bitm', pretrained=pretrained, 'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs)
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_152x4_bitm(pretrained=False, **kwargs): def resnetv2_152x4_bitm(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_152x4_bitm', pretrained=pretrained, 'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
@register_model @register_model
def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) layers=[3, 4, 6, 3], width_factor=1, **kwargs)
@register_model @register_model
def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) layers=[3, 4, 6, 3], width_factor=3, **kwargs)
@register_model @register_model
def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2(
'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) layers=[3, 4, 23, 3], width_factor=1, **kwargs)
@register_model @register_model
def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) layers=[3, 4, 23, 3], width_factor=3, **kwargs)
@register_model @register_model
def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) layers=[3, 8, 36, 3], width_factor=2, **kwargs)
@register_model @register_model
def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs):
return _create_resnetv2( return _create_resnetv2_bit(
'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843),
layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) layers=[3, 8, 36, 3], width_factor=4, **kwargs)
# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M. @register_model
def resnetv2_50(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs)
# @register_model
# def resnetv2_50x1_bits(pretrained=False, **kwargs): @register_model
# return _create_resnetv2( def resnetv2_50d(pretrained=False, **kwargs):
# 'resnetv2_50x1_bits', pretrained=pretrained, return _create_resnetv2(
# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) 'resnetv2_50d', pretrained=pretrained,
# layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d,
# stem_type='deep', avg_down=True, **kwargs)
# @register_model
# def resnetv2_50x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_50x3_bits', pretrained=pretrained,
# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x1_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x1_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_101x3_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_101x3_bits', pretrained=pretrained,
# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x2_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x2_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs)
#
#
# @register_model
# def resnetv2_152x4_bits(pretrained=False, **kwargs):
# return _create_resnetv2(
# 'resnetv2_152x4_bits', pretrained=pretrained,
# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs)
#

@ -126,19 +126,18 @@ class WindowAttention(nn.Module):
window_size (tuple[int]): The height and width of the window. window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads. num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0
""" """
def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
self.window_size = window_size # Wh, Ww self.window_size = window_size # Wh, Ww
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
# define a parameter table of relative position bias # define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter( self.relative_position_bias_table = nn.Parameter(
@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module):
shift_size (int): Shift size for SW-MSA. shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0 drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0
@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module):
""" """
def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm): act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__() super().__init__()
self.dim = dim self.dim = dim
@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module):
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = WindowAttention( self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias,
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) attn_drop=attn_drop, proj_drop=drop)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
@ -369,7 +367,6 @@ class BasicLayer(nn.Module):
window_size (int): Local window size. window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0 drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
@ -379,7 +376,7 @@ class BasicLayer(nn.Module):
""" """
def __init__(self, dim, input_resolution, depth, num_heads, window_size, def __init__(self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
super().__init__() super().__init__()
@ -390,14 +387,11 @@ class BasicLayer(nn.Module):
# build blocks # build blocks
self.blocks = nn.ModuleList([ self.blocks = nn.ModuleList([
SwinTransformerBlock(dim=dim, input_resolution=input_resolution, SwinTransformerBlock(
num_heads=num_heads, window_size=window_size, dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2, shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio,
mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop,
qkv_bias=qkv_bias, qk_scale=qk_scale, drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer)
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer)
for i in range(depth)]) for i in range(depth)])
# patch merging layer # patch merging layer
@ -436,7 +430,6 @@ class SwinTransformer(nn.Module):
window_size (int): Window size. Default: 7 window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
drop_rate (float): Dropout rate. Default: 0 drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1 drop_path_rate (float): Stochastic depth rate. Default: 0.1
@ -448,7 +441,7 @@ class SwinTransformer(nn.Module):
def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000,
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True, norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
use_checkpoint=False, weight_init='', **kwargs): use_checkpoint=False, weight_init='', **kwargs):
@ -491,8 +484,9 @@ class SwinTransformer(nn.Module):
num_heads=num_heads[i_layer], num_heads=num_heads[i_layer],
window_size=window_size, window_size=window_size,
mlp_ratio=self.mlp_ratio, mlp_ratio=self.mlp_ratio,
qkv_bias=qkv_bias, qk_scale=qk_scale, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer, norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
@ -520,6 +514,13 @@ class SwinTransformer(nn.Module):
def no_weight_decay_keywords(self): def no_weight_decay_keywords(self):
return {'relative_position_bias_table'} return {'relative_position_bias_table'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
if self.absolute_pos_embed is not None: if self.absolute_pos_embed is not None:

@ -278,6 +278,8 @@ class Twins(nn.Module):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
self.depths = depths self.depths = depths
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
prev_chs = in_chans prev_chs = in_chans
@ -303,10 +305,10 @@ class Twins(nn.Module):
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
self.norm = norm_layer(embed_dims[-1]) self.norm = norm_layer(self.num_features)
# classification head # classification head
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# init weights # init weights
self.apply(self._init_weights) self.apply(self._init_weights)
@ -320,7 +322,7 @@ class Twins(nn.Module):
def reset_classifier(self, num_classes, global_pool=''): def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def _init_weights(self, m): def _init_weights(self, m):
if isinstance(m, nn.Linear): if isinstance(m, nn.Linear):

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier
from .registry import register_model from .registry import register_model
@ -140,14 +140,14 @@ class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None):
super().__init__() super().__init__()
img_size = to_2tuple(img_size)
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim self.embed_dim = embed_dim
self.init_channels = init_channels self.init_channels = init_channels
self.img_size = img_size self.img_size = img_size
self.vit_stem = vit_stem self.vit_stem = vit_stem
self.pool = pool
self.conv_init = conv_init self.conv_init = conv_init
if isinstance(depth, (list, tuple)): if isinstance(depth, (list, tuple)):
self.stage_num1, self.stage_num2, self.stage_num3 = depth self.stage_num1, self.stage_num2, self.stage_num3 = depth
@ -164,31 +164,31 @@ class Visformer(nn.Module):
self.patch_embed1 = PatchEmbed( self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 16 img_size = [x // 16 for x in img_size]
else: else:
if self.init_channels is None: if self.init_channels is None:
self.stem = None self.stem = None
self.patch_embed1 = PatchEmbed( self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 8 img_size = [x // 8 for x in img_size]
else: else:
self.stem = nn.Sequential( self.stem = nn.Sequential(
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels), nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True) nn.ReLU(inplace=True)
) )
img_size //= 2 img_size = [x // 2 for x in img_size]
self.patch_embed1 = PatchEmbed( self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 4 img_size = [x // 4 for x in img_size]
if self.pos_embed: if self.pos_embed:
if self.vit_stem: if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
else: else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size))
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([ self.stage1 = nn.ModuleList([
Block( Block(
@ -199,14 +199,14 @@ class Visformer(nn.Module):
for i in range(self.stage_num1) for i in range(self.stage_num1)
]) ])
#stage2 # stage2
if not self.vit_stem: if not self.vit_stem:
self.patch_embed2 = PatchEmbed( self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 2 img_size = [x // 2 for x in img_size]
if self.pos_embed: if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
self.stage2 = nn.ModuleList([ self.stage2 = nn.ModuleList([
Block( Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@ -221,9 +221,9 @@ class Visformer(nn.Module):
self.patch_embed3 = PatchEmbed( self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size //= 2 img_size = [x // 2 for x in img_size]
if self.pos_embed: if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
self.stage3 = nn.ModuleList([ self.stage3 = nn.ModuleList([
Block( Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
@ -234,11 +234,10 @@ class Visformer(nn.Module):
]) ])
# head # head
if self.pool: self.num_features = embed_dim if self.vit_stem else embed_dim * 2
self.global_pooling = nn.AdaptiveAvgPool2d(1) self.norm = norm_layer(self.num_features)
head_dim = embed_dim if self.vit_stem else embed_dim * 2 self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
self.norm = norm_layer(head_dim) self.head = nn.Linear(self.num_features, num_classes)
self.head = nn.Linear(head_dim, num_classes)
# weights init # weights init
if self.pos_embed: if self.pos_embed:
@ -267,7 +266,14 @@ class Visformer(nn.Module):
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0.) nn.init.constant_(m.bias, 0.)
def forward(self, x): def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool='avg'):
self.num_classes = num_classes
self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
def forward_features(self, x):
if self.stem is not None: if self.stem is not None:
x = self.stem(x) x = self.stem(x)
@ -297,14 +303,13 @@ class Visformer(nn.Module):
for b in self.stage3: for b in self.stage3:
x = b(x) x = b(x)
# head
x = self.norm(x) x = self.norm(x)
if self.pool: return x
x = self.global_pooling(x)
else:
x = x[:, :, 0, 0]
x = self.head(x.view(x.size(0), -1)) def forward(self, x):
x = self.forward_features(x)
x = self.global_pool(x)
x = self.head(x)
return x return x
@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
@register_model @register_model
def visformer_tiny(pretrained=False, **kwargs): def visformer_tiny(pretrained=False, **kwargs):
model_cfg = dict( model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs) embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs):
@register_model @register_model
def visformer_small(pretrained=False, **kwargs): def visformer_small(pretrained=False, **kwargs):
model_cfg = dict( model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs) embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)

@ -28,7 +28,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg from .helpers import build_model_with_cfg, named_apply, adapt_input_conv
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_
from .registry import register_model from .registry import register_model
@ -47,9 +47,18 @@ def _cfg(url='', **kwargs):
default_cfgs = { default_cfgs = {
# patch models (my experiments) # FIXME weights coming
'vit_tiny_patch16_224': _cfg(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'vit_small_patch16_224': _cfg( 'vit_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
),
'vit_small_patch32_224': _cfg(
url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
), ),
# patch models (weights ported from official Google JAX impl) # patch models (weights ported from official Google JAX impl)
@ -97,29 +106,29 @@ default_cfgs = {
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
# deit models (FB weights) # deit models (FB weights)
'vit_deit_tiny_patch16_224': _cfg( 'deit_tiny_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
'vit_deit_small_patch16_224': _cfg( 'deit_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
'vit_deit_base_patch16_224': _cfg( 'deit_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
'vit_deit_base_patch16_384': _cfg( 'deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg( 'deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
classifier=('head', 'head_dist')), classifier=('head', 'head_dist')),
'vit_deit_small_distilled_patch16_224': _cfg( 'deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
classifier=('head', 'head_dist')), classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_224': _cfg( 'deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
classifier=('head', 'head_dist')), classifier=('head', 'head_dist')),
'vit_deit_base_distilled_patch16_384': _cfg( 'deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
# ViT ImageNet-21K-P pretraining # ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil_in21k': _cfg( 'vit_base_patch16_224_miil_in21k': _cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
@ -133,11 +142,11 @@ default_cfgs = {
class Attention(nn.Module): class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__() super().__init__()
self.num_heads = num_heads self.num_heads = num_heads
head_dim = dim // num_heads head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5 self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
@ -161,12 +170,11 @@ class Attention(nn.Module):
class Block(nn.Module): class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__() super().__init__()
self.norm1 = norm_layer(dim) self.norm1 = norm_layer(dim)
self.attn = Attention( self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
@ -190,7 +198,7 @@ class VisionTransformer(nn.Module):
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
act_layer=None, weight_init=''): act_layer=None, weight_init=''):
""" """
@ -204,7 +212,6 @@ class VisionTransformer(nn.Module):
num_heads (int): number of attention heads num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True qkv_bias (bool): enable bias for qkv if True
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
distilled (bool): model includes a distillation token and head as in DeiT models distilled (bool): model includes a distillation token and head as in DeiT models
drop_rate (float): dropout rate drop_rate (float): dropout rate
@ -233,8 +240,8 @@ class VisionTransformer(nn.Module):
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.Sequential(*[ self.blocks = nn.Sequential(*[
Block( Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)]) for i in range(depth)])
self.norm = norm_layer(embed_dim) self.norm = norm_layer(embed_dim)
@ -254,16 +261,17 @@ class VisionTransformer(nn.Module):
if distilled: if distilled:
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
# Weight init self.init_weights(weight_init)
assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
if self.dist_token is not None: if self.dist_token is not None:
trunc_normal_(self.dist_token, std=.02) trunc_normal_(self.dist_token, std=.02)
if weight_init.startswith('jax'): if mode.startswith('jax'):
# leave cls token as zeros to match jax impl # leave cls token as zeros to match jax impl
for n, m in self.named_modules(): named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
_init_vit_weights(m, n, head_bias=head_bias, jax_impl=True)
else: else:
trunc_normal_(self.cls_token, std=.02) trunc_normal_(self.cls_token, std=.02)
self.apply(_init_vit_weights) self.apply(_init_vit_weights)
@ -272,6 +280,10 @@ class VisionTransformer(nn.Module):
# this fn left here for compat with downstream users # this fn left here for compat with downstream users
_init_vit_weights(m) _init_vit_weights(m)
@torch.jit.ignore()
def load_pretrained(self, checkpoint_path, prefix=''):
_load_weights(self, checkpoint_path, prefix)
@torch.jit.ignore @torch.jit.ignore
def no_weight_decay(self): def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'dist_token'} return {'pos_embed', 'cls_token', 'dist_token'}
@ -317,39 +329,92 @@ class VisionTransformer(nn.Module):
return x return x
def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
""" ViT weight initialization """ ViT weight initialization
* When called without n, head_bias, jax_impl args it will behave exactly the same * When called without n, head_bias, jax_impl args it will behave exactly the same
as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
* When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
""" """
if isinstance(m, nn.Linear): if isinstance(module, nn.Linear):
if n.startswith('head'): if name.startswith('head'):
nn.init.zeros_(m.weight) nn.init.zeros_(module.weight)
nn.init.constant_(m.bias, head_bias) nn.init.constant_(module.bias, head_bias)
elif n.startswith('pre_logits'): elif name.startswith('pre_logits'):
lecun_normal_(m.weight) lecun_normal_(module.weight)
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
else: else:
if jax_impl: if jax_impl:
nn.init.xavier_uniform_(m.weight) nn.init.xavier_uniform_(module.weight)
if m.bias is not None: if module.bias is not None:
if 'mlp' in n: if 'mlp' in name:
nn.init.normal_(m.bias, std=1e-6) nn.init.normal_(module.bias, std=1e-6)
else: else:
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
else: else:
trunc_normal_(m.weight, std=.02) trunc_normal_(module.weight, std=.02)
if m.bias is not None: if module.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
elif jax_impl and isinstance(m, nn.Conv2d): elif jax_impl and isinstance(module, nn.Conv2d):
# NOTE conv was left to pytorch default in my original init # NOTE conv was left to pytorch default in my original init
lecun_normal_(m.weight) lecun_normal_(module.weight)
if m.bias is not None: if module.bias is not None:
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
elif isinstance(m, nn.LayerNorm): elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
nn.init.zeros_(m.bias) nn.init.zeros_(module.bias)
nn.init.ones_(m.weight) nn.init.ones_(module.weight)
@torch.no_grad()
def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
""" Load weights from .npz checkpoints for official Google Brain Flax implementation
"""
import numpy as np
def _n2p(w, t=True):
if t and w.ndim == 4:
w = w.transpose([3, 2, 0, 1])
elif t and w.ndim == 3:
w = w.transpose([2, 0, 1])
elif t and w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w)
w = np.load(checkpoint_path)
if not prefix:
prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix
input_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(input_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False))
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T,
_n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T,
_n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T]))
block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1),
_n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1),
_n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel']))
block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias']))
block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel']))
block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
@ -417,23 +482,34 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
return model return model
@register_model
def vit_tiny_patch16_224(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch16_224(pretrained=False, **kwargs):
""" My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3. """ ViT-Small (ViT-S/16)
NOTE: NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
* this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6
* this model does not have a bias for QKV (unlike the official ViT and DeiT models)
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3.,
qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs)
if pretrained:
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
model_kwargs.setdefault('qk_scale', 768 ** -0.5)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_small_patch32_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_patch16_224(pretrained=False, **kwargs): def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
@ -569,86 +645,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
@register_model @register_model
def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): def deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_small_patch16_224(pretrained=False, **kwargs): def deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_base_patch16_224(pretrained=False, **kwargs): def deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_base_patch16_384(pretrained=False, **kwargs): def deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer( model = _create_vision_transformer(
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs): def deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer( model = _create_vision_transformer(
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs): def deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer( model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model return model
@register_model @register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit. ImageNet-1k weights from https://github.com/facebookresearch/deit.
""" """
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer( model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model return model

@ -46,8 +46,8 @@ default_cfgs = {
input_size=(3, 384, 384), crop_pct=1.0), input_size=(3, 384, 384), crop_pct=1.0),
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
'vit_tiny_r_s16_p8_224': _cfg(), 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
'vit_small_r_s16_p8_224': _cfg(), 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_p2_224': _cfg(),
'vit_small_r20_s16_224': _cfg(), 'vit_small_r20_s16_224': _cfg(),
'vit_small_r26_s32_224': _cfg(), 'vit_small_r26_s32_224': _cfg(),
@ -57,10 +57,14 @@ default_cfgs = {
'vit_large_r50_s32_224': _cfg(), 'vit_large_r50_s32_224': _cfg(),
# hybrid models (using timm resnet backbones) # hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'vit_small_resnet26d_224': _cfg(
'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'vit_small_resnet50d_s16_224': _cfg(
'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
} }
@ -140,12 +144,6 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
@register_model @register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs): def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@ -158,12 +156,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
# NOTE this is forwarding to model def above for backwards compatibility
return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
@register_model @register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.

Loading…
Cancel
Save