From 8880f696b6b8368a76296126476ea020fc7c814c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 12 Jun 2021 16:40:02 -0700 Subject: [PATCH] Refactoring, cleanup, improved test coverage. * Add eca_nfnet_l2 weights, 84.7 @ 384x384 * All 'non-std' (ie transformer / mlp) models have classifier / default_cfg test added * Fix #694 reset_classifer / num_features / forward_features / num_classes=0 consistency for transformer / mlp models * Add direct loading of npz to vision transformer (pure transformer so far, hybrid to come) * Rename vit_deit* to deit_* * Remove some deprecated vit hybrid model defs * Clean up classifier flatten for conv classifiers and unusual cases (mobilenetv3/ghostnet) * Remove explicit model fns for levit conv, just pass in arg --- tests/test_models.py | 55 ++++- timm/models/cait.py | 30 ++- timm/models/coat.py | 8 +- timm/models/convit.py | 23 +- timm/models/dla.py | 6 +- timm/models/dpn.py | 5 +- timm/models/ghostnet.py | 9 +- timm/models/helpers.py | 29 +++ timm/models/layers/adaptive_avgmax_pool.py | 13 +- timm/models/layers/classifier.py | 5 +- timm/models/layers/mlp.py | 6 + timm/models/levit.py | 87 ++++--- timm/models/mlp_mixer.py | 116 +++++---- timm/models/mobilenetv3.py | 5 +- timm/models/nfnet.py | 7 +- timm/models/pit.py | 32 ++- timm/models/registry.py | 13 +- timm/models/resnet.py | 7 +- timm/models/resnetv2.py | 266 ++++++++++----------- timm/models/swin_transformer.py | 43 ++-- timm/models/twins.py | 8 +- timm/models/visformer.py | 63 ++--- timm/models/vision_transformer.py | 228 ++++++++++++------ timm/models/vision_transformer_hybrid.py | 28 +-- 24 files changed, 637 insertions(+), 455 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 5a31935e..ac156806 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -17,7 +17,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', - 'convit_*', 'levit*', 'visformer*'] + 'convit_*', 'levit*', 'visformer*', 'deit*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -120,7 +120,6 @@ def test_model_default_cfgs(model_name, batch_size): state_dict = model.state_dict() cfg = model.default_cfg - classifier = cfg['classifier'] pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] @@ -149,7 +148,57 @@ def test_model_default_cfgs(model_name, batch_size): assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2] # check classifier name matches default_cfg - assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params' + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' + + # check first conv(s) names match default_cfg + first_conv = cfg['first_conv'] + if isinstance(first_conv, str): + first_conv = (first_conv,) + assert isinstance(first_conv, (tuple, list)) + for fc in first_conv: + assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params' + + +@pytest.mark.timeout(300) +@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS)) +@pytest.mark.parametrize('batch_size', [1]) +def test_model_default_cfgs_non_std(model_name, batch_size): + """Run a single forward pass with each model""" + model = create_model(model_name, pretrained=False) + model.eval() + state_dict = model.state_dict() + cfg = model.default_cfg + + input_size = _get_input_size(model_name=model_name, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") + + input_tensor = torch.randn((batch_size, *input_size)) + + # test forward_features (always unpooled) + outputs = model.forward_features(input_tensor) + if isinstance(outputs, tuple): + outputs = outputs[0] + assert outputs.shape[1] == model.num_features + + # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features + model.reset_classifier(0) + outputs = model.forward(input_tensor) + if isinstance(outputs, tuple): + outputs = outputs[0] + assert len(outputs.shape) == 2 + assert outputs.shape[1] == model.num_features + + # check classifier name matches default_cfg + classifier = cfg['classifier'] + if not isinstance(classifier, (tuple, list)): + classifier = classifier, + for c in classifier: + assert c + ".weight" in state_dict.keys(), f'{c} not in model params' # check first conv(s) names match default_cfg first_conv = cfg['first_conv'] diff --git a/timm/models/cait.py b/timm/models/cait.py index aa2e5f07..69b4ba06 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -74,11 +74,11 @@ default_cfgs = dict( class ClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to do CA - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.q = nn.Linear(dim, dim, bias=qkv_bias) self.k = nn.Linear(dim, dim, bias=qkv_bias) @@ -110,13 +110,13 @@ class LayerScaleBlockClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add CA and LayerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -134,14 +134,14 @@ class LayerScaleBlockClassAttn(nn.Module): class TalkingHeadAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -177,13 +177,13 @@ class LayerScaleBlock(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add layerScale def __init__( - self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) self.attn = attn_block( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -202,7 +202,7 @@ class Cait(nn.Module): # with slight modifications to adapt to our cait models def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., + num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), global_pool=None, @@ -235,14 +235,14 @@ class Cait(nn.Module): dpr = [drop_path_rate for i in range(depth)] self.blocks = nn.ModuleList([ block_layers( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale) for i in range(depth)]) self.blocks_token_only = nn.ModuleList([ block_layers_token( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, act_layer=act_layer, attn_block=attn_block_token_only, mlp_block=mlp_block_token_only, init_values=init_scale) @@ -270,6 +270,13 @@ class Cait(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) @@ -293,7 +300,6 @@ class Cait(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.head(x) - return x diff --git a/timm/models/coat.py b/timm/models/coat.py index 9eb384d8..f071715a 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -335,6 +335,8 @@ class CoaT(nn.Module): crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers self.out_features = out_features + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] self.num_classes = num_classes # Patch embeddings. @@ -441,10 +443,10 @@ class CoaT(nn.Module): # CoaT series: Aggregate features of last three scales for classification. assert embed_dims[1] == embed_dims[2] == embed_dims[3] self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) - self.head = nn.Linear(embed_dims[3], num_classes) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() else: # CoaT-Lite series: Use feature of last scale for classification. - self.head = nn.Linear(embed_dims[3], num_classes) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # Initialize weights. trunc_normal_(self.cls_token1, std=.02) @@ -471,7 +473,7 @@ class CoaT(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def insert_cls(self, x, cls_token): """ Insert CLS token. """ diff --git a/timm/models/convit.py b/timm/models/convit.py index b15b46d8..0593ec1c 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -57,13 +57,13 @@ default_cfgs = { class GPSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., locality_strength=1.): super().__init__() self.num_heads = num_heads self.dim = dim head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.locality_strength = locality_strength self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias) @@ -142,11 +142,11 @@ class GPSA(nn.Module): class MHSA(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -191,19 +191,16 @@ class MHSA(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs): super().__init__() self.norm1 = norm_layer(dim) self.use_gpsa = use_gpsa if self.use_gpsa: self.attn = GPSA( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, **kwargs) else: - self.attn = MHSA( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, - proj_drop=drop, **kwargs) + self.attn = MHSA(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) @@ -220,7 +217,7 @@ class ConViT(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., + num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None, local_up_to_layer=3, locality_strength=1., use_pos_embed=True): super().__init__() @@ -250,13 +247,13 @@ class ConViT(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.ModuleList([ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_gpsa=True, locality_strength=locality_strength) if i < local_up_to_layer else Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, use_gpsa=False) for i in range(depth)]) diff --git a/timm/models/dla.py b/timm/models/dla.py index f0f25b0b..f6e4dd28 100644 --- a/timm/models/dla.py +++ b/timm/models/dla.py @@ -288,6 +288,8 @@ class DLA(nn.Module): self.num_features = channels[-1] self.global_pool, self.fc = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + for m in self.modules(): if isinstance(m, nn.Conv2d): n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels @@ -314,6 +316,7 @@ class DLA(nn.Module): self.num_classes = num_classes self.global_pool, self.fc = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def forward_features(self, x): x = self.base_layer(x) @@ -331,8 +334,7 @@ class DLA(nn.Module): if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.fc(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + x = self.flatten(x) return x diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 90ef11cc..c4e380b1 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -237,6 +237,7 @@ class DPN(nn.Module): # Using 1x1 conv for the FC layer to allow the extra pooling scheme self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def get_classifier(self): return self.classifier @@ -245,6 +246,7 @@ class DPN(nn.Module): self.num_classes = num_classes self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool, use_conv=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() def forward_features(self, x): return self.features(x) @@ -255,8 +257,7 @@ class DPN(nn.Module): if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) x = self.classifier(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) + x = self.flatten(x) return x diff --git a/timm/models/ghostnet.py b/timm/models/ghostnet.py index 48dee6ec..a73047c5 100644 --- a/timm/models/ghostnet.py +++ b/timm/models/ghostnet.py @@ -133,7 +133,7 @@ class GhostBottleneck(nn.Module): class GhostNet(nn.Module): - def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32): + def __init__(self, cfgs, num_classes=1000, width=1.0, dropout=0.2, in_chans=3, output_stride=32, global_pool='avg'): super(GhostNet, self).__init__() # setting of inverted residual blocks assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported' @@ -178,9 +178,10 @@ class GhostNet(nn.Module): # building last several layers self.num_features = out_chs = 1280 - self.global_pool = SelectAdaptivePool2d(pool_type='avg') + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True) self.act2 = nn.ReLU(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(out_chs, num_classes) def get_classifier(self): @@ -190,6 +191,7 @@ class GhostNet(nn.Module): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.pool_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -204,8 +206,7 @@ class GhostNet(nn.Module): def forward(self, x): x = self.forward_features(x) - if not self.global_pool.is_identity(): - x = x.view(x.size(0), -1) + x = self.flatten(x) if self.dropout > 0.: x = F.dropout(x, p=self.dropout, training=self.training) x = self.classifier(x) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index adfef550..662a7a48 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -45,6 +45,13 @@ def load_state_dict(checkpoint_path, use_ema=False): def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): + if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): + # numpy checkpoint, try to load via model specific load_pretrained fn + if hasattr(model, 'load_pretrained'): + model.load_pretrained(checkpoint_path) + else: + raise NotImplementedError('Model cannot load numpy checkpoint') + return state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -477,3 +484,25 @@ def model_parameters(model, exclude_head=False): return [p for p in model.parameters()][:-2] else: return model.parameters() + + +def named_apply(fn: Callable, module: nn.Module, name='', depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def named_modules(module: nn.Module, name='', depth_first=True, include_root=False): + if not depth_first and include_root: + yield name, module + for child_name, child_module in module.named_children(): + child_name = '.'.join((name, child_name)) if name else child_name + yield from named_modules( + module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + yield name, module diff --git a/timm/models/layers/adaptive_avgmax_pool.py b/timm/models/layers/adaptive_avgmax_pool.py index d2bb9f72..ebc6ada8 100644 --- a/timm/models/layers/adaptive_avgmax_pool.py +++ b/timm/models/layers/adaptive_avgmax_pool.py @@ -55,7 +55,7 @@ class FastAdaptiveAvgPool2d(nn.Module): self.flatten = flatten def forward(self, x): - return x.mean((2, 3)) if self.flatten else x.mean((2, 3), keepdim=True) + return x.mean((2, 3), keepdim=not self.flatten) class AdaptiveAvgMaxPool2d(nn.Module): @@ -82,13 +82,13 @@ class SelectAdaptivePool2d(nn.Module): def __init__(self, output_size=1, pool_type='fast', flatten=False): super(SelectAdaptivePool2d, self).__init__() self.pool_type = pool_type or '' # convert other falsy values to empty string for consistent TS typing - self.flatten = flatten + self.flatten = nn.Flatten(1) if flatten else nn.Identity() if pool_type == '': self.pool = nn.Identity() # pass through elif pool_type == 'fast': assert output_size == 1 - self.pool = FastAdaptiveAvgPool2d(self.flatten) - self.flatten = False + self.pool = FastAdaptiveAvgPool2d(flatten) + self.flatten = nn.Identity() elif pool_type == 'avg': self.pool = nn.AdaptiveAvgPool2d(output_size) elif pool_type == 'avgmax': @@ -101,12 +101,11 @@ class SelectAdaptivePool2d(nn.Module): assert False, 'Invalid pool type: %s' % pool_type def is_identity(self): - return self.pool_type == '' + return not self.pool_type def forward(self, x): x = self.pool(x) - if self.flatten: - x = x.flatten(1) + x = self.flatten(x) return x def feat_mult(self): diff --git a/timm/models/layers/classifier.py b/timm/models/layers/classifier.py index 516cc6c9..2b745413 100644 --- a/timm/models/layers/classifier.py +++ b/timm/models/layers/classifier.py @@ -20,7 +20,7 @@ def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): return global_pool, num_pooled_features -def _create_fc(num_features, num_classes, pool_type='avg', use_conv=False): +def _create_fc(num_features, num_classes, use_conv=False): if num_classes <= 0: fc = nn.Identity() # pass-through (no classifier) elif use_conv: @@ -45,11 +45,12 @@ class ClassifierHead(nn.Module): self.drop_rate = drop_rate self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) - self.flatten_after_fc = use_conv and pool_type + self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() def forward(self, x): x = self.global_pool(x) if self.drop_rate: x = F.dropout(x, p=float(self.drop_rate), training=self.training) x = self.fc(x) + x = self.flatten(x) return x diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index 4739ba74..05d07652 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -40,6 +40,12 @@ class GluMlp(nn.Module): self.fc2 = nn.Linear(hidden_features // 2, out_features) self.drop = nn.Dropout(drop) + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + def forward(self, x): x = self.fc1(x) x, gates = x.chunk(2, dim=-1) diff --git a/timm/models/levit.py b/timm/models/levit.py index 2180254a..fa35f41f 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -84,63 +84,33 @@ __all__ = ['Levit'] @register_model -def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): +def levit_128s(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_128s', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_128(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_128', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_192(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_192', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_256(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_256', pretrained=pretrained, use_conv=use_conv, **kwargs) @register_model -def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): +def levit_384(pretrained=False, use_conv=False, **kwargs): return create_levit( - 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): - return create_levit( - 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) - - -@register_model -def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): - return create_levit( - 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + 'levit_384', pretrained=pretrained, use_conv=use_conv, **kwargs) class ConvNorm(nn.Sequential): @@ -427,6 +397,9 @@ class AttentionSubsample(nn.Module): class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage + + NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems + w/ train scripts that don't take tuple outputs, """ def __init__( @@ -447,7 +420,8 @@ class Levit(nn.Module): attn_act_layer='hard_swish', distillation=True, use_conv=False, - drop_path=0): + drop_rate=0., + drop_path_rate=0.): super().__init__() act_layer = get_act_layer(act_layer) attn_act_layer = get_act_layer(attn_act_layer) @@ -486,7 +460,7 @@ class Levit(nn.Module): Attention( ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution, use_conv=use_conv), - drop_path)) + drop_path_rate)) if mr > 0: h = int(ed * mr) self.blocks.append( @@ -494,7 +468,7 @@ class Levit(nn.Module): ln_layer(ed, h, resolution=resolution), act_layer(), ln_layer(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path)) + ), drop_path_rate)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) resolution_ = (resolution - 1) // do[5] + 1 @@ -511,26 +485,45 @@ class Levit(nn.Module): ln_layer(embed_dim[i + 1], h, resolution=resolution), act_layer(), ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path)) + ), drop_path_rate)) self.blocks = nn.Sequential(*self.blocks) # Classifier head self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None if distillation: self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() - else: - self.head_dist = None @torch.jit.ignore def no_weight_decay(self): return {x for x in self.state_dict().keys() if 'attention_biases' in x} - def forward(self, x): + def get_classifier(self): + if self.head_dist is None: + return self.head + else: + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool='', distillation=None): + self.num_classes = num_classes + self.head = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + if distillation is not None: + self.distillation = distillation + if self.distillation: + self.head_dist = NormLinear(self.embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + else: + self.head_dist = None + + def forward_features(self, x): x = self.patch_embed(x) if not self.use_conv: x = x.flatten(2).transpose(1, 2) x = self.blocks(x) x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + return x + + def forward(self, x): + x = self.forward_features(x) if self.head_dist is not None: x, x_dist = self.head(x), self.head_dist(x) if self.training and not torch.jit.is_scripting(): diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 6f53264a..ea6de824 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -45,7 +45,7 @@ import torch import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg, overlay_external_default_cfg, named_apply from .layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple from .registry import register_model @@ -169,6 +169,11 @@ class SpatialGatingUnit(nn.Module): self.norm = norm_layer(gate_dim) self.proj = nn.Linear(seq_len, seq_len) + def init_weights(self): + # special init for the projection gate, called as override by base model init + nn.init.normal_(self.proj.weight, std=1e-6) + nn.init.ones_(self.proj.bias) + def forward(self, x): u, v = x.chunk(2, dim=-1) v = self.norm(v) @@ -205,7 +210,7 @@ class MlpMixer(nn.Module): in_chans=3, patch_size=16, num_blocks=8, - hidden_dim=512, + embed_dim=512, mlp_ratio=(0.5, 4.0), block_layer=MixerBlock, mlp_layer=Mlp, @@ -218,59 +223,71 @@ class MlpMixer(nn.Module): ): super().__init__() self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.stem = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=hidden_dim, - norm_layer=norm_layer if stem_norm else None) + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=norm_layer if stem_norm else None) # FIXME drop_path (stochastic depth scaling rule or all the same?) self.blocks = nn.Sequential(*[ block_layer( - hidden_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, + embed_dim, self.stem.num_patches, mlp_ratio, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, drop=drop_rate, drop_path=drop_path_rate) for _ in range(num_blocks)]) - self.norm = norm_layer(hidden_dim) - self.head = nn.Linear(hidden_dim, self.num_classes) # zero init + self.norm = norm_layer(embed_dim) + self.head = nn.Linear(embed_dim, self.num_classes) # zero init self.init_weights(nlhb=nlhb) def init_weights(self, nlhb=False): head_bias = -math.log(self.num_classes) if nlhb else 0. - for n, m in self.named_modules(): - _init_weights(m, n, head_bias=head_bias) + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first - def forward(self, x): + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): x = self.stem(x) x = self.blocks(x) x = self.norm(x) x = x.mean(dim=1) + return x + + def forward(self, x): + x = self.forward_features(x) x = self.head(x) return x -def _init_weights(m, n: str, head_bias: float = 0.): +def _init_weights(module: nn.Module, name: str, head_bias: float = 0.): """ Mixer weight initialization (trying to match Flax defaults) """ - if isinstance(m, nn.Linear): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.endswith('gate.proj'): - nn.init.normal_(m.weight, std=1e-4) - nn.init.ones_(m.bias) + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) else: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, std=1e-6) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) else: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + # NOTE if a parent module contains init_weights method, it can override the init of the + # child modules as this will be called in depth-first order. + module.init_weights() def _create_mixer(variant, pretrained=False, **kwargs): @@ -289,7 +306,7 @@ def mixer_s32_224(pretrained=False, **kwargs): """ Mixer-S/32 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=8, hidden_dim=512, **kwargs) + model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs) model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args) return model @@ -299,7 +316,7 @@ def mixer_s16_224(pretrained=False, **kwargs): """ Mixer-S/16 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=8, hidden_dim=512, **kwargs) + model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs) model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args) return model @@ -309,7 +326,7 @@ def mixer_b32_224(pretrained=False, **kwargs): """ Mixer-B/32 224x224 Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args) return model @@ -319,7 +336,7 @@ def mixer_b16_224(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args) return model @@ -329,7 +346,7 @@ def mixer_b16_224_in21k(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_in21k', pretrained=pretrained, **model_args) return model @@ -339,7 +356,7 @@ def mixer_l32_224(pretrained=False, **kwargs): """ Mixer-L/32 224x224. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=32, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args) return model @@ -349,7 +366,7 @@ def mixer_l16_224(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-1k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args) return model @@ -359,35 +376,38 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs): """ Mixer-L/16 224x224. ImageNet-21k pretrained weights. Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601 """ - model_args = dict(patch_size=16, num_blocks=24, hidden_dim=1024, **kwargs) + model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs) model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args) return model + @register_model def mixer_b16_224_miil(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-21k pretrained weights. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args) return model + @register_model def mixer_b16_224_miil_in21k(pretrained=False, **kwargs): """ Mixer-B/16 224x224. ImageNet-1k pretrained weights. Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs) + model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs) model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args) return model + @register_model def gmixer_12_224(pretrained=False, **kwargs): """ Glu-Mixer-12 224x224 (short & fat) Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=20, num_blocks=12, hidden_dim=512, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=12, embed_dim=512, mlp_ratio=(1.0, 6.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args) return model @@ -399,7 +419,7 @@ def gmixer_24_224(pretrained=False, **kwargs): Experiment by Ross Wightman, adding (Si)GLU to MLP-Mixer """ model_args = dict( - patch_size=20, num_blocks=24, hidden_dim=384, mlp_ratio=(1.0, 6.0), + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 6.0), mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs) model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args) return model @@ -411,7 +431,7 @@ def resmlp_12_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=12, hidden_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) + patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args) return model @@ -422,7 +442,7 @@ def resmlp_24_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=24, hidden_dim=384, mlp_ratio=4, + patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4, block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args) return model @@ -434,7 +454,7 @@ def resmlp_36_224(pretrained=False, **kwargs): Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404 """ model_args = dict( - patch_size=16, num_blocks=36, hidden_dim=384, mlp_ratio=4, + patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4, block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs) model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args) return model @@ -446,7 +466,7 @@ def gmlp_ti16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args) return model @@ -458,7 +478,7 @@ def gmlp_s16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args) return model @@ -470,7 +490,7 @@ def gmlp_b16_224(pretrained=False, **kwargs): Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050 """ model_args = dict( - patch_size=16, num_blocks=30, hidden_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, + patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock, mlp_layer=GatedMlp, **kwargs) model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e85112e6..f810eb82 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -119,6 +119,7 @@ class MobileNetV3(nn.Module): num_pooled_chs = head_chs * self.global_pool.feat_mult() self.conv_head = create_conv2d(num_pooled_chs, self.num_features, 1, padding=pad_type, bias=head_bias) self.act2 = act_layer(inplace=True) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() efficientnet_init_weights(self) @@ -137,6 +138,7 @@ class MobileNetV3(nn.Module): self.num_classes = num_classes # cannot meaningfully change pooling of efficient head after creation self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -151,8 +153,7 @@ class MobileNetV3(nn.Module): def forward(self, x): x = self.forward_features(x) - if not self.global_pool.is_identity(): - x = x.flatten(1) + x = self.flatten(x) if self.drop_rate > 0.: x = F.dropout(x, p=self.drop_rate, training=self.training) return self.classifier(x) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index fc0a20c2..4e0f2b21 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -111,11 +111,11 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), eca_nfnet_l2=_dcfg( - url='', - pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l2_ra3-da781a61.pth', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), eca_nfnet_l3=_dcfg( url='', - pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), + pool_size=(11, 11), input_size=(3, 352, 352), test_input_size=(3, 448, 448), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -210,6 +210,7 @@ def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', ski return cfg + model_cfgs = dict( # NFNet-F models w/ GELU compatible with DeepMind weights dm_nfnet_f0=_dm_nfnet_cfg(depths=(1, 2, 6, 3)), diff --git a/timm/models/pit.py b/timm/models/pit.py index 9c350861..460824e2 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -186,12 +186,13 @@ class PoolingVisionTransformer(nn.Module): ] self.transformers = SequentialTuple(*transformers) self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6) - self.embed_dim = base_dims[-1] * heads[-1] + self.num_features = self.embed_dim = base_dims[-1] * heads[-1] # Classifier head self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and distilled else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) @@ -207,13 +208,16 @@ class PoolingVisionTransformer(nn.Module): return {'pos_embed', 'cls_token'} def get_classifier(self): - return self.head + if self.head_dist is not None: + return self.head, self.head_dist + else: + return self.head def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ - if num_classes > 0 and self.num_tokens == 2 else nn.Identity() + if self.head_dist is not None: + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) @@ -221,19 +225,21 @@ class PoolingVisionTransformer(nn.Module): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x, cls_tokens = self.transformers((x, cls_tokens)) cls_tokens = self.norm(cls_tokens) - return cls_tokens + if self.head_dist is not None: + return cls_tokens[:, 0], cls_tokens[:, 1] + else: + return cls_tokens[:, 0] def forward(self, x): x = self.forward_features(x) - x_cls = self.head(x[:, 0]) - if self.num_tokens > 1: - x_dist = self.head_dist(x[:, 1]) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple if self.training and not torch.jit.is_scripting(): - return x_cls, x_dist + return x, x_dist else: - return (x_cls + x_dist) / 2 + return (x + x_dist) / 2 else: - return x_cls + return self.head(x) def checkpoint_filter_fn(state_dict, model): diff --git a/timm/models/registry.py b/timm/models/registry.py index 6927b6d6..f92219b2 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -65,11 +65,18 @@ def list_models(filter='', module='', pretrained=False, exclude_filters='', name model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module """ if module: - models = list(_module_to_models[module]) + all_models = list(_module_to_models[module]) else: - models = _model_entrypoints.keys() + all_models = _model_entrypoints.keys() if filter: - models = fnmatch.filter(models, filter) # include these models + models = [] + include_filters = filter if isinstance(filter, (tuple, list)) else [filter] + for f in include_filters: + include_models = fnmatch.filter(all_models, f) # include these models + if len(include_models): + models = set(models).union(include_models) + else: + models = all_models if exclude_filters: if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2f02f12a..66baa37a 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -638,12 +638,15 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - nn.init.constant_(m.weight, 1.) - nn.init.constant_(m.bias, 0.) + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) if zero_init_last_bn: for m in self.modules(): if hasattr(m, 'zero_init_last_bn'): diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 250695a8..84b16bb2 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -35,9 +35,9 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .registry import register_model -from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d +from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d def _cfg(url='', **kwargs): @@ -86,20 +86,10 @@ default_cfgs = { url='https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz', num_classes=21843), - - # trained on imagenet-1k, NOTE not overly interesting set of weights, leaving disabled for now - # 'resnetv2_50x1_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R50x1.npz'), - # 'resnetv2_50x3_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R50x3.npz'), - # 'resnetv2_101x1_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'), - # 'resnetv2_101x3_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R101x3.npz'), - # 'resnetv2_152x2_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R152x2.npz'), - # 'resnetv2_152x4_bits': _cfg( - # url='https://storage.googleapis.com/bit_models/BiT-S-R152x4.npz'), + 'resnetv2_50': _cfg( + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic'), + 'resnetv2_50d': _cfg( + input_size=(3, 224, 224), crop_pct=0.875, interpolation='bicubic', first_conv='stem.conv1'), } @@ -111,13 +101,6 @@ def make_div(v, divisor=8): return new_v -def tf2th(conv_weights): - """Possibly convert HWIO to OIHW.""" - if conv_weights.ndim == 4: - conv_weights = conv_weights.transpose([3, 2, 0, 1]) - return torch.from_numpy(conv_weights) - - class PreActBottleneck(nn.Module): """Pre-activation (v2) bottleneck block. @@ -152,6 +135,9 @@ class PreActBottleneck(nn.Module): self.conv3 = conv_layer(mid_chs, out_chs, 1) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + def zero_init_last_bn(self): + nn.init.zeros_(self.norm3.weight) + def forward(self, x): x_preact = self.norm1(x) @@ -198,6 +184,9 @@ class Bottleneck(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() self.act3 = act_layer(inplace=True) + def zero_init_last_bn(self): + nn.init.zeros_(self.norm3.weight) + def forward(self, x): # shortcut branch shortcut = x @@ -276,7 +265,7 @@ class ResNetStage(nn.Module): def create_resnetv2_stem( in_chs, out_chs=64, stem_type='', preact=True, - conv_layer=partial(StdConv2d, eps=1e-8), norm_layer=partial(GroupNormAct, num_groups=32)): + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') @@ -285,14 +274,17 @@ def create_resnetv2_stem( # A 3 deep 3x3 conv stack as in ResNet V1D models mid_chs = out_chs // 2 stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) + stem['norm1'] = norm_layer(mid_chs) stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) + stem['norm2'] = norm_layer(mid_chs) stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + if not preact: + stem['norm3'] = norm_layer(out_chs) else: # The usual 7x7 stem conv stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) - - if not preact: - stem['norm'] = norm_layer(out_chs) + if not preact: + stem['norm'] = norm_layer(out_chs) if 'fixed' in stem_type: # 'fixed' SAME padding approximation that is used in BiT models @@ -312,11 +304,12 @@ class ResNetV2(nn.Module): """Implementation of Pre-activation (v2) ResNet mode. """ - def __init__(self, layers, channels=(256, 512, 1024, 2048), - num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, - act_layer=nn.ReLU, conv_layer=partial(StdConv2d, eps=1e-8), - norm_layer=partial(GroupNormAct, num_groups=32), drop_rate=0., drop_path_rate=0.): + def __init__( + self, layers, channels=(256, 512, 1024, 2048), + num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True, + act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + drop_rate=0., drop_path_rate=0., zero_init_last_bn=True): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate @@ -354,12 +347,14 @@ class ResNetV2(nn.Module): self.head = ClassifierHead( self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate, use_conv=True) - for n, m in self.named_modules(): - if isinstance(m, nn.Linear) or ('.fc' in n and isinstance(m, nn.Conv2d)): - nn.init.normal_(m.weight, mean=0.0, std=0.01) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + self.init_weights(zero_init_last_bn=zero_init_last_bn) + + def init_weights(self, zero_init_last_bn=True): + named_apply(partial(_init_weights, zero_init_last_bn=zero_init_last_bn), self) + + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix='resnet/'): + _load_weights(self, checkpoint_path, prefix) def get_classifier(self): return self.head.fc @@ -378,41 +373,59 @@ class ResNetV2(nn.Module): def forward(self, x): x = self.forward_features(x) x = self.head(x) - if not self.head.global_pool.is_identity(): - x = x.flatten(1) # conv classifier, flatten if pooling isn't pass-through (disabled) return x - def load_pretrained(self, checkpoint_path, prefix='resnet/'): - import numpy as np - weights = np.load(checkpoint_path) - with torch.no_grad(): - stem_conv_w = tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']) - if self.stem.conv.weight.shape[1] == 1: - self.stem.conv.weight.copy_(stem_conv_w.sum(dim=1, keepdim=True)) - # FIXME handle > 3 in_chans? - else: - self.stem.conv.weight.copy_(stem_conv_w) - self.norm.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma'])) - self.norm.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta'])) - if self.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: - self.head.fc.weight.copy_(tf2th(weights[f'{prefix}head/conv2d/kernel'])) - self.head.fc.bias.copy_(tf2th(weights[f'{prefix}head/conv2d/bias'])) - for i, (sname, stage) in enumerate(self.stages.named_children()): - for j, (bname, block) in enumerate(stage.blocks.named_children()): - convname = 'standardized_conv2d' - block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' - block.conv1.weight.copy_(tf2th(weights[f'{block_prefix}a/{convname}/kernel'])) - block.conv2.weight.copy_(tf2th(weights[f'{block_prefix}b/{convname}/kernel'])) - block.conv3.weight.copy_(tf2th(weights[f'{block_prefix}c/{convname}/kernel'])) - block.norm1.weight.copy_(tf2th(weights[f'{block_prefix}a/group_norm/gamma'])) - block.norm2.weight.copy_(tf2th(weights[f'{block_prefix}b/group_norm/gamma'])) - block.norm3.weight.copy_(tf2th(weights[f'{block_prefix}c/group_norm/gamma'])) - block.norm1.bias.copy_(tf2th(weights[f'{block_prefix}a/group_norm/beta'])) - block.norm2.bias.copy_(tf2th(weights[f'{block_prefix}b/group_norm/beta'])) - block.norm3.bias.copy_(tf2th(weights[f'{block_prefix}c/group_norm/beta'])) - if block.downsample is not None: - w = weights[f'{block_prefix}a/proj/{convname}/kernel'] - block.downsample.conv.weight.copy_(tf2th(w)) + +def _init_weights(module: nn.Module, name: str = '', zero_init_last_bn=True): + if isinstance(module, nn.Linear) or ('head.fc' in name and isinstance(module, nn.Conv2d)): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.BatchNorm2d, nn.LayerNorm, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif zero_init_last_bn and hasattr(module, 'zero_init_last_bn'): + module.zero_init_last_bn() + + +@torch.no_grad() +def _load_weights(model: nn.Module, checkpoint_path: str, prefix: str = 'resnet/'): + import numpy as np + + def t2p(conv_weights): + """Possibly convert HWIO to OIHW.""" + if conv_weights.ndim == 4: + conv_weights = conv_weights.transpose([3, 2, 0, 1]) + return torch.from_numpy(conv_weights) + + weights = np.load(checkpoint_path) + stem_conv_w = adapt_input_conv( + model.stem.conv.weight.shape[1], t2p(weights[f'{prefix}root_block/standardized_conv2d/kernel'])) + model.stem.conv.weight.copy_(stem_conv_w) + model.norm.weight.copy_(t2p(weights[f'{prefix}group_norm/gamma'])) + model.norm.bias.copy_(t2p(weights[f'{prefix}group_norm/beta'])) + if model.head.fc.weight.shape[0] == weights[f'{prefix}head/conv2d/kernel'].shape[-1]: + model.head.fc.weight.copy_(t2p(weights[f'{prefix}head/conv2d/kernel'])) + model.head.fc.bias.copy_(t2p(weights[f'{prefix}head/conv2d/bias'])) + for i, (sname, stage) in enumerate(model.stages.named_children()): + for j, (bname, block) in enumerate(stage.blocks.named_children()): + cname = 'standardized_conv2d' + block_prefix = f'{prefix}block{i + 1}/unit{j + 1:02d}/' + block.conv1.weight.copy_(t2p(weights[f'{block_prefix}a/{cname}/kernel'])) + block.conv2.weight.copy_(t2p(weights[f'{block_prefix}b/{cname}/kernel'])) + block.conv3.weight.copy_(t2p(weights[f'{block_prefix}c/{cname}/kernel'])) + block.norm1.weight.copy_(t2p(weights[f'{block_prefix}a/group_norm/gamma'])) + block.norm2.weight.copy_(t2p(weights[f'{block_prefix}b/group_norm/gamma'])) + block.norm3.weight.copy_(t2p(weights[f'{block_prefix}c/group_norm/gamma'])) + block.norm1.bias.copy_(t2p(weights[f'{block_prefix}a/group_norm/beta'])) + block.norm2.bias.copy_(t2p(weights[f'{block_prefix}b/group_norm/beta'])) + block.norm3.bias.copy_(t2p(weights[f'{block_prefix}c/group_norm/beta'])) + if block.downsample is not None: + w = weights[f'{block_prefix}a/proj/{cname}/kernel'] + block.downsample.conv.weight.copy_(t2p(w)) def _create_resnetv2(variant, pretrained=False, **kwargs): @@ -425,130 +438,99 @@ def _create_resnetv2(variant, pretrained=False, **kwargs): **kwargs) +def _create_resnetv2_bit(variant, pretrained=False, **kwargs): + return _create_resnetv2( + variant, pretrained=pretrained, stem_type='fixed', conv_layer=partial(StdConv2d, eps=1e-8), **kwargs) + + @register_model def resnetv2_50x1_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50x1_bitm', pretrained=pretrained, - layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_50x1_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=1, **kwargs) @register_model def resnetv2_50x3_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50x3_bitm', pretrained=pretrained, - layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_50x3_bitm', pretrained=pretrained, layers=[3, 4, 6, 3], width_factor=3, **kwargs) @register_model def resnetv2_101x1_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101x1_bitm', pretrained=pretrained, - layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_101x1_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=1, **kwargs) @register_model def resnetv2_101x3_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101x3_bitm', pretrained=pretrained, - layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_101x3_bitm', pretrained=pretrained, layers=[3, 4, 23, 3], width_factor=3, **kwargs) @register_model def resnetv2_152x2_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152x2_bitm', pretrained=pretrained, - layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_152x2_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=2, **kwargs) @register_model def resnetv2_152x4_bitm(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152x4_bitm', pretrained=pretrained, - layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + return _create_resnetv2_bit( + 'resnetv2_152x4_bitm', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs) @register_model def resnetv2_50x1_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_50x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) + layers=[3, 4, 6, 3], width_factor=1, **kwargs) @register_model def resnetv2_50x3_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_50x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) + layers=[3, 4, 6, 3], width_factor=3, **kwargs) @register_model def resnetv2_101x1_bitm_in21k(pretrained=False, **kwargs): return _create_resnetv2( 'resnetv2_101x1_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) + layers=[3, 4, 23, 3], width_factor=1, **kwargs) @register_model def resnetv2_101x3_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_101x3_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) + layers=[3, 4, 23, 3], width_factor=3, **kwargs) @register_model def resnetv2_152x2_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_152x2_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) + layers=[3, 8, 36, 3], width_factor=2, **kwargs) @register_model def resnetv2_152x4_bitm_in21k(pretrained=False, **kwargs): - return _create_resnetv2( + return _create_resnetv2_bit( 'resnetv2_152x4_bitm_in21k', pretrained=pretrained, num_classes=kwargs.pop('num_classes', 21843), - layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) + layers=[3, 8, 36, 3], width_factor=4, **kwargs) -# NOTE the 'S' versions of the model weights arent as interesting as original 21k or transfer to 1K M. +@register_model +def resnetv2_50(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, **kwargs) -# @register_model -# def resnetv2_50x1_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_50x1_bits', pretrained=pretrained, -# layers=[3, 4, 6, 3], width_factor=1, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_50x3_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_50x3_bits', pretrained=pretrained, -# layers=[3, 4, 6, 3], width_factor=3, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_101x1_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_101x1_bits', pretrained=pretrained, -# layers=[3, 4, 23, 3], width_factor=1, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_101x3_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_101x3_bits', pretrained=pretrained, -# layers=[3, 4, 23, 3], width_factor=3, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_152x2_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_152x2_bits', pretrained=pretrained, -# layers=[3, 8, 36, 3], width_factor=2, stem_type='fixed', **kwargs) -# -# -# @register_model -# def resnetv2_152x4_bits(pretrained=False, **kwargs): -# return _create_resnetv2( -# 'resnetv2_152x4_bits', pretrained=pretrained, -# layers=[3, 8, 36, 3], width_factor=4, stem_type='fixed', **kwargs) -# + +@register_model +def resnetv2_50d(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=nn.BatchNorm2d, + stem_type='deep', avg_down=True, **kwargs) diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index a845f505..2ee106d2 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -126,19 +126,18 @@ class WindowAttention(nn.Module): window_size (tuple[int]): The height and width of the window. num_heads (int): Number of attention heads. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 proj_drop (float, optional): Dropout ratio of output. Default: 0.0 """ - def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.): super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 # define a parameter table of relative position bias self.relative_position_bias_table = nn.Parameter( @@ -210,7 +209,6 @@ class SwinTransformerBlock(nn.Module): shift_size (int): Shift size for SW-MSA. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float, optional): Stochastic depth rate. Default: 0.0 @@ -219,7 +217,7 @@ class SwinTransformerBlock(nn.Module): """ def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim @@ -236,8 +234,8 @@ class SwinTransformerBlock(nn.Module): self.norm1 = norm_layer(dim) self.attn = WindowAttention( - dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, - qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, qkv_bias=qkv_bias, + attn_drop=attn_drop, proj_drop=drop) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) @@ -369,7 +367,6 @@ class BasicLayer(nn.Module): window_size (int): Local window size. mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. drop (float, optional): Dropout rate. Default: 0.0 attn_drop (float, optional): Attention dropout rate. Default: 0.0 drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 @@ -379,7 +376,7 @@ class BasicLayer(nn.Module): """ def __init__(self, dim, input_resolution, depth, num_heads, window_size, - mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False): super().__init__() @@ -390,14 +387,11 @@ class BasicLayer(nn.Module): # build blocks self.blocks = nn.ModuleList([ - SwinTransformerBlock(dim=dim, input_resolution=input_resolution, - num_heads=num_heads, window_size=window_size, - shift_size=0 if (i % 2 == 0) else window_size // 2, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop, attn_drop=attn_drop, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, - norm_layer=norm_layer) + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, norm_layer=norm_layer) for i in range(depth)]) # patch merging layer @@ -436,7 +430,6 @@ class SwinTransformer(nn.Module): window_size (int): Window size. Default: 7 mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True - qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None drop_rate (float): Dropout rate. Default: 0 attn_drop_rate (float): Attention dropout rate. Default: 0 drop_path_rate (float): Stochastic depth rate. Default: 0.1 @@ -448,7 +441,7 @@ class SwinTransformer(nn.Module): def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + window_size=7, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, ape=False, patch_norm=True, use_checkpoint=False, weight_init='', **kwargs): @@ -491,8 +484,9 @@ class SwinTransformer(nn.Module): num_heads=num_heads[i_layer], window_size=window_size, mlp_ratio=self.mlp_ratio, - qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], norm_layer=norm_layer, downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, @@ -520,6 +514,13 @@ class SwinTransformer(nn.Module): def no_weight_decay_keywords(self): return {'relative_position_bias_table'} + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + def forward_features(self, x): x = self.patch_embed(x) if self.absolute_pos_embed is not None: diff --git a/timm/models/twins.py b/timm/models/twins.py index 793d2ede..4aed09d9 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -278,6 +278,8 @@ class Twins(nn.Module): super().__init__() self.num_classes = num_classes self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] img_size = to_2tuple(img_size) prev_chs = in_chans @@ -303,10 +305,10 @@ class Twins(nn.Module): self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) - self.norm = norm_layer(embed_dims[-1]) + self.norm = norm_layer(self.num_features) # classification head - self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # init weights self.apply(self._init_weights) @@ -320,7 +322,7 @@ class Twins(nn.Module): def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def _init_weights(self, m): if isinstance(m, nn.Linear): diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 5583ea3c..16631027 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d, create_classifier from .registry import register_model @@ -140,14 +140,14 @@ class Visformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', - vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + vit_stem=False, group=8, global_pool='avg', conv_init=False, embed_norm=None): super().__init__() + img_size = to_2tuple(img_size) self.num_classes = num_classes - self.num_features = self.embed_dim = embed_dim + self.embed_dim = embed_dim self.init_channels = init_channels self.img_size = img_size self.vit_stem = vit_stem - self.pool = pool self.conv_init = conv_init if isinstance(depth, (list, tuple)): self.stage_num1, self.stage_num2, self.stage_num3 = depth @@ -164,31 +164,31 @@ class Visformer(nn.Module): self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size //= 16 + img_size = [x // 16 for x in img_size] else: if self.init_channels is None: self.stem = None self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size //= 8 + img_size = [x // 8 for x in img_size] else: self.stem = nn.Sequential( nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), nn.BatchNorm2d(self.init_channels), nn.ReLU(inplace=True) ) - img_size //= 2 + img_size = [x // 2 for x in img_size] self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size //= 4 + img_size = [x // 4 for x in img_size] if self.pos_embed: if self.vit_stem: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) else: - self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, *img_size)) self.pos_drop = nn.Dropout(p=drop_rate) self.stage1 = nn.ModuleList([ Block( @@ -199,14 +199,14 @@ class Visformer(nn.Module): for i in range(self.stage_num1) ]) - #stage2 + # stage2 if not self.vit_stem: self.patch_embed2 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size //= 2 + img_size = [x // 2 for x in img_size] if self.pos_embed: - self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) self.stage2 = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, @@ -221,9 +221,9 @@ class Visformer(nn.Module): self.patch_embed3 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) - img_size //= 2 + img_size = [x // 2 for x in img_size] if self.pos_embed: - self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) self.stage3 = nn.ModuleList([ Block( dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, @@ -234,11 +234,10 @@ class Visformer(nn.Module): ]) # head - if self.pool: - self.global_pooling = nn.AdaptiveAvgPool2d(1) - head_dim = embed_dim if self.vit_stem else embed_dim * 2 - self.norm = norm_layer(head_dim) - self.head = nn.Linear(head_dim, num_classes) + self.num_features = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(self.num_features) + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + self.head = nn.Linear(self.num_features, num_classes) # weights init if self.pos_embed: @@ -267,7 +266,14 @@ class Visformer(nn.Module): if m.bias is not None: nn.init.constant_(m.bias, 0.) - def forward(self, x): + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool='avg'): + self.num_classes = num_classes + self.global_pool, self.head = create_classifier(self.num_features, self.num_classes, pool_type=global_pool) + + def forward_features(self, x): if self.stem is not None: x = self.stem(x) @@ -297,14 +303,13 @@ class Visformer(nn.Module): for b in self.stage3: x = b(x) - # head x = self.norm(x) - if self.pool: - x = self.global_pooling(x) - else: - x = x[:, :, 0, 0] + return x - x = self.head(x.view(x.size(0), -1)) + def forward(self, x): + x = self.forward_features(x) + x = self.global_pool(x) + x = self.head(x) return x @@ -321,7 +326,7 @@ def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): @register_model def visformer_tiny(pretrained=False, **kwargs): model_cfg = dict( - img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) @@ -331,7 +336,7 @@ def visformer_tiny(pretrained=False, **kwargs): @register_model def visformer_small(pretrained=False, **kwargs): model_cfg = dict( - img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index ff74d836..c44358df 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -28,7 +28,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from .registry import register_model @@ -47,9 +47,18 @@ def _cfg(url='', **kwargs): default_cfgs = { - # patch models (my experiments) + # FIXME weights coming + 'vit_tiny_patch16_224': _cfg( + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), 'vit_small_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), + 'vit_small_patch32_224': _cfg( + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), # patch models (weights ported from official Google JAX impl) @@ -97,29 +106,29 @@ default_cfgs = { num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), # deit models (FB weights) - 'vit_deit_tiny_patch16_224': _cfg( + 'deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), - 'vit_deit_small_patch16_224': _cfg( + 'deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), - 'vit_deit_base_patch16_224': _cfg( + 'deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), - 'vit_deit_base_patch16_384': _cfg( + 'deit_base_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', input_size=(3, 384, 384), crop_pct=1.0), - 'vit_deit_tiny_distilled_patch16_224': _cfg( + 'deit_tiny_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', classifier=('head', 'head_dist')), - 'vit_deit_small_distilled_patch16_224': _cfg( + 'deit_small_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', classifier=('head', 'head_dist')), - 'vit_deit_base_distilled_patch16_224': _cfg( + 'deit_base_distilled_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', classifier=('head', 'head_dist')), - 'vit_deit_base_distilled_patch16_384': _cfg( + 'deit_base_distilled_patch16_384': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), - # ViT ImageNet-21K-P pretraining + # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_in21k_miil.pth', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221, @@ -133,11 +142,11 @@ default_cfgs = { class Attention(nn.Module): - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) @@ -161,12 +170,11 @@ class Attention(nn.Module): class Block(nn.Module): - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): super().__init__() self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) @@ -190,7 +198,7 @@ class VisionTransformer(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None, weight_init=''): """ @@ -204,7 +212,6 @@ class VisionTransformer(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True - qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set distilled (bool): model includes a distillation token and head as in DeiT models drop_rate (float): dropout rate @@ -233,8 +240,8 @@ class VisionTransformer(nn.Module): dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule self.blocks = nn.Sequential(*[ Block( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) @@ -254,16 +261,17 @@ class VisionTransformer(nn.Module): if distilled: self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() - # Weight init - assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) if self.dist_token is not None: trunc_normal_(self.dist_token, std=.02) - if weight_init.startswith('jax'): + if mode.startswith('jax'): # leave cls token as zeros to match jax impl - for n, m in self.named_modules(): - _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) else: trunc_normal_(self.cls_token, std=.02) self.apply(_init_vit_weights) @@ -272,6 +280,10 @@ class VisionTransformer(nn.Module): # this fn left here for compat with downstream users _init_vit_weights(m) + @torch.jit.ignore() + def load_pretrained(self, checkpoint_path, prefix=''): + _load_weights(self, checkpoint_path, prefix) + @torch.jit.ignore def no_weight_decay(self): return {'pos_embed', 'cls_token', 'dist_token'} @@ -317,39 +329,92 @@ class VisionTransformer(nn.Module): return x -def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): +def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False): """ ViT weight initialization * When called without n, head_bias, jax_impl args it will behave exactly the same as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl """ - if isinstance(m, nn.Linear): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.startswith('pre_logits'): - lecun_normal_(m.weight) - nn.init.zeros_(m.bias) + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + elif name.startswith('pre_logits'): + lecun_normal_(module.weight) + nn.init.zeros_(module.bias) else: if jax_impl: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, std=1e-6) + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) else: - nn.init.zeros_(m.bias) + nn.init.zeros_(module.bias) else: - trunc_normal_(m.weight, std=.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif jax_impl and isinstance(m, nn.Conv2d): + trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif jax_impl and isinstance(module, nn.Conv2d): # NOTE conv was left to pytorch default in my original init - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)): + nn.init.zeros_(module.bias) + nn.init.ones_(module.weight) + + +@torch.no_grad() +def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''): + """ Load weights from .npz checkpoints for official Google Brain Flax implementation + """ + import numpy as np + + def _n2p(w, t=True): + if t and w.ndim == 4: + w = w.transpose([3, 2, 0, 1]) + elif t and w.ndim == 3: + w = w.transpose([2, 0, 1]) + elif t and w.ndim == 2: + w = w.transpose([1, 0]) + return torch.from_numpy(w) + + w = np.load(checkpoint_path) + if not prefix: + prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix + + input_conv_w = adapt_input_conv( + model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) + model.patch_embed.proj.weight.copy_(input_conv_w) + model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) + model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) + model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)) + model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) + model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) + if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: + model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) + model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) + for i, block in enumerate(model.blocks.children()): + block_prefix = f'{prefix}Transformer/encoderblock_{i}/' + block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) + block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) + mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' + block.attn.qkv.weight.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T, + _n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T, + _n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T])) + block.attn.qkv.bias.copy_(torch.cat([ + _n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1), + _n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1), + _n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)])) + block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) + block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) + block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel'])) + block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias'])) + block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel'])) + block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias'])) + block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) + block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): @@ -417,23 +482,34 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw return model +@register_model +def vit_tiny_patch16_224(pretrained=False, **kwargs): + """ ViT-Tiny (Vit-Ti/16) + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_patch16_224(pretrained=False, **kwargs): - """ My custom 'small' ViT model. embed_dim=768, depth=8, num_heads=8, mlp_ratio=3. - NOTE: - * this differs from the DeiT based 'small' definitions with embed_dim=384, depth=12, num_heads=6 - * this model does not have a bias for QKV (unlike the official ViT and DeiT models) + """ ViT-Small (ViT-S/16) + NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., - qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) - if pretrained: - # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model +@register_model +def vit_small_patch32_224(pretrained=False, **kwargs): + """ ViT-Small (ViT-S/32) + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). @@ -569,86 +645,86 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): @register_model -def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): +def deit_tiny_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_small_patch16_224(pretrained=False, **kwargs): +def deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_base_patch16_224(pretrained=False, **kwargs): +def deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_base_patch16_384(pretrained=False, **kwargs): +def deit_base_patch16_384(pretrained=False, **kwargs): """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) model = _create_vision_transformer( - 'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs): +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) model = _create_vision_transformer( - 'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs): +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( - 'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) return model @register_model -def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs): +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). ImageNet-1k weights from https://github.com/facebookresearch/deit. """ model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer( - 'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 7fc0cc88..c807ee9a 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -46,8 +46,8 @@ default_cfgs = { input_size=(3, 384, 384), crop_pct=1.0), # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) - 'vit_tiny_r_s16_p8_224': _cfg(), - 'vit_small_r_s16_p8_224': _cfg(), + 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), + 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_224': _cfg(), 'vit_small_r26_s32_224': _cfg(), @@ -57,10 +57,14 @@ default_cfgs = { 'vit_large_r50_s32_224': _cfg(), # hybrid models (using timm resnet backbones) - 'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_small_resnet50d_s16_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_base_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'vit_base_resnet50d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'vit_small_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_small_resnet50d_s16_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet26d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), + 'vit_base_resnet50d_224': _cfg( + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'), } @@ -140,12 +144,6 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): return model -@register_model -def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): - # NOTE this is forwarding to model def above for backwards compatibility - return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs) - - @register_model def vit_base_r50_s16_384(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). @@ -158,12 +156,6 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): return model -@register_model -def vit_base_resnet50_384(pretrained=False, **kwargs): - # NOTE this is forwarding to model def above for backwards compatibility - return vit_base_r50_s16_384(pretrained=pretrained, **kwargs) - - @register_model def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.