diff --git a/tests/test_models.py b/tests/test_models.py index 5ff9fb33..570b49db 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -26,29 +26,41 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*'] else: - EXCLUDE_FILTERS = NON_STD_FILTERS + EXCLUDE_FILTERS = [] -MAX_FWD_SIZE = 384 -MAX_BWD_SIZE = 128 +TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 +TARGET_BWD_SIZE = 128 +MAX_BWD_SIZE = 384 MAX_FWD_FEAT_SIZE = 448 +def _get_input_size(model, target=None): + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + if 'fixed_input_size' in default_cfg and default_cfg['fixed_input_size']: + return input_size + if 'min_input_size' in default_cfg: + if target and max(input_size) > target: + input_size = default_cfg['min_input_size'] + else: + if target and max(input_size) > target: + input_size = tuple([min(x, target) for x in input_size]) + return input_size + + @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False) model.eval() - input_size = model.default_cfg['input_size'] - if any([x > MAX_FWD_SIZE for x in input_size]): - if is_model_default_key(model_name, 'fixed_input_size'): - pytest.skip("Fixed input size model > limit.") - # cap forward test at max res 384 * 384 to keep resource down - input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) @@ -63,20 +75,16 @@ def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) - model.eval() + model.train() - input_size = model.default_cfg['input_size'] - if not is_model_default_key(model_name, 'fixed_input_size'): - min_input_size = get_model_default_value(model_name, 'min_input_size') - if min_input_size is not None: - input_size = min_input_size - else: - if any([x > MAX_BWD_SIZE for x in input_size]): - # cap backward test at 128 * 128 to keep resource usage down - input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + input_size = _get_input_size(model, TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) outputs.mean().backward() for n, x in model.named_parameters(): assert x.grad is not None, f'No gradient for {n}' @@ -168,12 +176,9 @@ def test_model_forward_torchscript(model_name, batch_size): model = create_model(model_name, pretrained=False) model.eval() - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 
'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 128) + if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) @@ -184,7 +189,7 @@ def test_model_forward_torchscript(model_name, batch_size): EXCLUDE_FEAT_FILTERS = [ '*pruned*', # hopefully fix at some point -] +] + NON_STD_FILTERS if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] @@ -200,12 +205,9 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model, 96) # jit compile is already a bit slow and we've tested normal res already... + if max(input_size) > MAX_FWD_SIZE: # NOTE using MAX_FWD_SIZE as the final limit is intentional + pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 1a21de09..788b7518 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,8 +16,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * -from .levitc import * from .levit import * +#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * diff --git a/timm/models/cait.py b/timm/models/cait.py index c5f7742f..aa2e5f07 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None): return checkpoint_no_module -def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_cait(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Cait, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/coat.py b/timm/models/coat.py index cb265522..9eb384d8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,19 +7,19 @@ Official CoaT code at: 
https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from typing import Tuple, Dict, Any, Optional +from copy import deepcopy +from functools import partial +from typing import Tuple, List import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained -from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model -from functools import partial -from torch import nn __all__ = [ "coat_tiny", @@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed1.proj', 'classifier': 'head', **kwargs @@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs): default_cfgs = { - 'coat_tiny': _cfg_coat(), - 'coat_mini': _cfg_coat(), + 'coat_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' + ), + 'coat_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' + ), 'coat_lite_tiny': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' ), 'coat_lite_mini': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' ), - 'coat_lite_small': _cfg_coat(), + 'coat_lite_small': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' + ), } @@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. @@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module): class SerialBlock(nn.Module): """ Serial block class. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpe=None, shared_crpe=None): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): super().__init__() # Conv-Attention. 
@@ -200,8 +205,7 @@ class SerialBlock(nn.Module): self.norm1 = norm_layer(dim) self.factoratt_crpe = FactorAtt_ConvRelPosEnc( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpe) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. @@ -226,27 +230,24 @@ class SerialBlock(nn.Module): class ParallelBlock(nn.Module): """ Parallel block class. """ - def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpes=None, shared_crpes=None): + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): super().__init__() # Conv-Attention. - self.cpes = shared_cpes - self.norm12 = norm_layer(dims[1]) self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( - dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1] ) self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( - dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2] ) self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( - dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3] ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -262,15 +263,15 @@ class ParallelBlock(nn.Module): self.mlp2 = self.mlp3 = self.mlp4 = Mlp( in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def upsample(self, x, factor, size): + def upsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map up-sampling. """ return self.interpolate(x, scale_factor=factor, size=size) - def downsample(self, x, factor, size): + def downsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map down-sampling. """ return self.interpolate(x, scale_factor=1.0/factor, size=size) - def interpolate(self, x, scale_factor, size): + def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): """ Feature map interpolation. """ B, N, C = x.shape H, W = size @@ -280,33 +281,28 @@ class ParallelBlock(nn.Module): img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) - img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') + img_tokens = F.interpolate( + img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) out = torch.cat((cls_token, img_tokens), dim=1) return out - def forward(self, x1, x2, x3, x4, sizes): - _, (H2, W2), (H3, W3), (H4, W4) = sizes - - # Conv-Attention. - x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored. 
- x3 = self.cpes[2](x3, size=(H3, W3)) - x4 = self.cpes[3](x4, size=(H4, W4)) - + def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): + _, S2, S3, S4 = sizes cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = self.factoratt_crpe2(cur2, size=S2) + cur3 = self.factoratt_crpe3(cur3, size=S3) + cur4 = self.factoratt_crpe4(cur4, size=S4) + upsample3_2 = self.upsample(cur3, factor=2., size=S3) + upsample4_3 = self.upsample(cur4, factor=2., size=S4) + upsample4_2 = self.upsample(cur4, factor=4., size=S4) + downsample2_3 = self.downsample(cur2, factor=2., size=S2) + downsample3_4 = self.downsample(cur3, factor=2., size=S3) + downsample2_4 = self.downsample(cur2, factor=4., size=S2) cur2 = cur2 + upsample3_2 + upsample4_2 cur3 = cur3 + upsample4_3 + downsample2_3 cur4 = cur4 + downsample3_4 + downsample2_4 @@ -330,11 +326,11 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], - serial_depths=[0, 0, 0, 0], parallel_depth=0, - num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): super().__init__() crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers @@ -342,17 +338,18 @@ class CoaT(nn.Module): self.num_classes = num_classes # Patch embeddings. + img_size = to_2tuple(img_size) self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) self.patch_embed2 = PatchEmbed( - img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) self.patch_embed3 = PatchEmbed( - img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) self.patch_embed4 = PatchEmbed( - img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) # Class tokens. @@ -380,7 +377,7 @@ class CoaT(nn.Module): # Serial blocks 1. 
self.serial_blocks1 = nn.ModuleList([ SerialBlock( - dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe1, shared_crpe=self.crpe1 ) @@ -390,7 +387,7 @@ class CoaT(nn.Module): # Serial blocks 2. self.serial_blocks2 = nn.ModuleList([ SerialBlock( - dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe2, shared_crpe=self.crpe2 ) @@ -400,7 +397,7 @@ class CoaT(nn.Module): # Serial blocks 3. self.serial_blocks3 = nn.ModuleList([ SerialBlock( - dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe3, shared_crpe=self.crpe3 ) @@ -410,7 +407,7 @@ class CoaT(nn.Module): # Serial blocks 4. self.serial_blocks4 = nn.ModuleList([ SerialBlock( - dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe4, shared_crpe=self.crpe4 ) @@ -422,10 +419,9 @@ class CoaT(nn.Module): if self.parallel_depth > 0: self.parallel_blocks = nn.ModuleList([ ParallelBlock( - dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], - shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4] + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) ) for _ in range(parallel_depth)] ) @@ -434,9 +430,11 @@ class CoaT(nn.Module): # Classification head(s). if not self.return_interm_layers: - self.norm1 = norm_layer(embed_dims[0]) - self.norm2 = norm_layer(embed_dims[1]) - self.norm3 = norm_layer(embed_dims[2]) + if self.parallel_blocks is not None: + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + else: + self.norm2 = self.norm3 = None self.norm4 = norm_layer(embed_dims[3]) if self.parallel_depth > 0: @@ -546,6 +544,7 @@ class CoaT(nn.Module): # Parallel blocks. 
for blk in self.parallel_blocks: + x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4)) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) if not torch.jit.is_scripting() and self.return_interm_layers: @@ -590,52 +589,70 @@ class CoaT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + for k, v in state_dict.items(): + # original model had unused norm layers, removing them requires filtering pretrained checkpoints + if k.startswith('norm1') or \ + (model.norm2 is None and k.startswith('norm2')) or \ + (model.norm3 is None and k.startswith('norm3')): + continue + out_dict[k] = v + return out_dict + + +def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CoaT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def coat_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_tiny'] + model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_mini'] + model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_tiny'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_mini'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_small(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_lite_small'] + model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg) return model \ No newline at end of file diff --git a/timm/models/convit.py b/timm/models/convit.py index f6ae3ec1..b15b46d8 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -39,7 +39,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } @@ -317,6 +317,9 @@ class ConViT(nn.Module): def _create_convit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg( ConViT, variant, pretrained, default_cfg=default_cfgs[variant], diff --git a/timm/models/helpers.py b/timm/models/helpers.py index e9ac7f00..dfb6b860 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): # Overlay default cfg values from `external_default_cfg` if it exists in kwargs overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) filter_kwargs(kwargs, names=kwargs_filter) diff --git a/timm/models/levit.py b/timm/models/levit.py index 997b44d7..96a0c85b 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -1,3 +1,22 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:2104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications by/copyright Copyright 2021 Ross Wightman +""" + # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. 
@@ -5,10 +24,15 @@ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # Copyright 2020 Ross Wightman, Apache-2.0 License import itertools +from copy import deepcopy +from functools import partial import torch +import torch.nn as nn from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple from .vision_transformer import trunc_normal_ from .registry import register_model @@ -19,70 +43,113 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), **kwargs } -specification = { - 'levit_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), +) __all__ = ['Levit'] @register_model -def levit_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model 
+def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) @register_model -def levit_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) +def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) -class ConvNorm(torch.nn.Sequential): +@register_model +def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +class ConvNorm(nn.Sequential): def __init__( self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = nn.BatchNorm2d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -91,7 +158,7 @@ class ConvNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var 
+ bn.eps) ** 0.5 w = c.weight * w[:, None, None, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Conv2d( + m = nn.Conv2d( w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) m.weight.data.copy_(w) @@ -99,13 +166,13 @@ class ConvNorm(torch.nn.Sequential): return m -class LinearNorm(torch.nn.Sequential): +class LinearNorm(nn.Sequential): def __init__(self, a, b, bn_weight_init=1, resolution=-100000): super().__init__() - self.add_module('c', torch.nn.Linear(a, b, bias=False)) - bn = torch.nn.BatchNorm1d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) + self.add_module('c', nn.Linear(a, b, bias=False)) + bn = nn.BatchNorm1d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) self.add_module('bn', bn) @torch.no_grad() @@ -114,25 +181,24 @@ class LinearNorm(torch.nn.Sequential): w = bn.weight / (bn.running_var + bn.eps) ** 0.5 w = l.weight * w[:, None] b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m def forward(self, x): - l, bn = self._modules.values() - x = l(x) - return bn(x.flatten(0, 1)).reshape_as(x) + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) -class NormLinear(torch.nn.Sequential): +class NormLinear(nn.Sequential): def __init__(self, a, b, bias=True, std=0.02): super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) + self.add_module('bn', nn.BatchNorm1d(a)) + l = nn.Linear(a, b, bias=bias) trunc_normal_(l.weight, std=std) if bias: - torch.nn.init.constant_(l.bias, 0) + nn.init.constant_(l.bias, 0) self.add_module('l', l) @torch.no_grad() @@ -145,24 +211,24 @@ class NormLinear(torch.nn.Sequential): b = b @ self.l.weight.T else: b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) + m = nn.Linear(w.size(1), w.size(0)) m.weight.data.copy_(w) m.bias.data.copy_(b) return m -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) -class Residual(torch.nn.Module): +class Residual(nn.Module): def __init__(self, m, drop): super().__init__() self.m = m @@ -176,10 +242,23 @@ class Residual(torch.nn.Module): return x + self.m(x) -class Attention(torch.nn.Module): +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): def __init__( - self, dim, key_dim, num_heads=8, attn_ratio=4, 
act_layer=None, resolution=14): + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): super().__init__() + self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim @@ -187,11 +266,13 @@ class Attention(torch.nn.Module): self.d = int(attn_ratio * key_dim) self.dh = int(attn_ratio * key_dim) * num_heads self.attn_ratio = attn_ratio + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm h = self.dh + nh_kd * 2 - self.qkv = LinearNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( + self.qkv = ln_layer(dim, h, resolution=resolution) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) + ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) points = list(itertools.product(range(resolution), range(resolution))) N = len(points) @@ -203,68 +284,68 @@ class Attention(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,N,C) - B, N, C = x.shape - qkv = self.qkv(x) - q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) - q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) - v = v.permute(0, 2, 1, 3) - - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - - attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) x = self.proj(x) return x -class Subsample(torch.nn.Module): - def __init__(self, stride, resolution): - super().__init__() - self.stride = stride - self.resolution = resolution - - def forward(self, x): - B, N, C = x.shape - x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] - return x.reshape(B, -1, C) - - -class AttentionSubsample(torch.nn.Module): - def __init__(self, in_dim, out_dim, key_dim, num_heads=8, - attn_ratio=2, act_layer=None, stride=2, resolution=14, resolution_=7): +class AttentionSubsample(nn.Module): + def 
__init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): super().__init__() self.num_heads = num_heads self.scale = key_dim ** -0.5 self.key_dim = key_dim self.nh_kd = nh_kd = key_dim * num_heads self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads + self.dh = self.d * self.num_heads self.attn_ratio = attn_ratio self.resolution_ = resolution_ self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = LinearNorm(in_dim, h, resolution=resolution) + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) - self.q = torch.nn.Sequential( - Subsample(stride, resolution), - LinearNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( act_layer(), - LinearNorm(self.dh, out_dim, resolution=resolution_)) + ln_layer(self.dh, out_dim, resolution=resolution_)) self.stride = stride self.resolution = resolution @@ -283,35 +364,43 @@ class AttentionSubsample(torch.nn.Module): if offset not in attention_offsets: attention_offsets[offset] = len(attention_offsets) idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - + self.ab = None @torch.no_grad() def train(self, mode=True): super().train(mode) - if mode and hasattr(self, 'ab'): - del self.ab - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] + self.ab = None if mode else self.attention_biases[:, self.attention_bias_idxs] def forward(self, x): - B, N, C = x.shape - k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) - k = k.permute(0, 2, 1, 3) # BHNC - v = v.permute(0, 2, 1, 3) # BHNC - q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = (q.transpose(-2, -1) @ k) * self.scale + ab + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = q @ k.transpose(-2, -1) * self.scale + ab - attn = attn.softmax(dim=-1) + ab = self.attention_biases[:, self.attention_bias_idxs] if self.ab is None else self.ab + attn = q @ k.transpose(-2, -1) * self.scale + ab + attn = attn.softmax(dim=-1) - x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = (attn @ v).transpose(1, 
2).reshape(B, -1, self.dh) x = self.proj(x) return x -class Levit(torch.nn.Module): +class Levit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ @@ -321,45 +410,63 @@ class Levit(torch.nn.Module): patch_size=16, in_chans=3, num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, + down_ops=None, + act_layer=nn.Hardswish, + attn_act_layer=nn.Hardswish, distillation=True, + use_conv=False, drop_path=0): super().__init__() - global FLOPS_COUNTER - + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] self.num_classes = num_classes self.num_features = embed_dim[-1] self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm - self.patch_embed = hybrid_backbone + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) self.blocks = [] - down_ops.append(['']) resolution = img_size // patch_size for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): for _ in range(dpth): self.blocks.append( Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), drop_path)) if mr > 0: h = int(ed * mr) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(ed, h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, ed, bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), ), drop_path)) if do[0] == 'Subsample': # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) @@ -368,22 +475,22 @@ class Levit(torch.nn.Module): AttentionSubsample( *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) resolution = resolution_ if do[4] > 0: # mlp_ratio h = int(embed_dim[i + 1] * do[4]) self.blocks.append( - Residual(torch.nn.Sequential( - LinearNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - LinearNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) + self.blocks = nn.Sequential(*self.blocks) # 
Classifier head - self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() if distillation: - self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() else: self.head_dist = None @@ -393,48 +500,44 @@ class Levit(torch.nn.Module): def forward(self, x): x = self.patch_embed(x) - x = x.flatten(2).transpose(1, 2) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) x = self.blocks(x) - x = x.mean(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + if self.head_dist is not None: + x, x_dist = self.head(x), self.head_dist(x) + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 else: x = self.head(x) return x -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url(weights, map_location='cpu') - model.load_state_dict(checkpoint['model']) +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) #if fuse: # utils.replace_batchnorm(model) - return model + diff --git a/timm/models/levitc.py b/timm/models/levitc.py deleted file mode 100644 index 1a422953..00000000 --- a/timm/models/levitc.py +++ /dev/null @@ -1,400 +0,0 @@ -# Copyright (c) 2015-present, Facebook, Inc. -# All rights reserved. 
- -# Modified from -# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py -# Copyright 2020 Ross Wightman, Apache-2.0 License -import itertools - -import torch - -from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN -from .vision_transformer import trunc_normal_ -from .registry import register_model - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', - **kwargs - } - - -specification = { - 'levit_c_128s': { - 'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'}, - 'levit_c_128': { - 'C': '128_256_384', 'D': 16, 'N': '4_8_12', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'}, - 'levit_c_192': { - 'C': '192_288_384', 'D': 32, 'N': '3_5_6', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'}, - 'levit_c_256': { - 'C': '256_384_512', 'D': 32, 'N': '4_6_8', 'X': '4_4_4', 'drop_path': 0, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'}, - 'levit_c_384': { - 'C': '384_512_768', 'D': 32, 'N': '6_9_12', 'X': '4_4_4', 'drop_path': 0.1, - 'weights': 'https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'}, -} - -__all__ = ['Levit'] - - -@register_model -def levit_c_128s(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128s'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_128(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_128'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_192(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_192'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_256(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_256'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -@register_model -def levit_c_384(num_classes=1000, distillation=True, pretrained=False, fuse=False, **kwargs): - return model_factory(**specification['levit_c_384'], num_classes=num_classes, - distillation=distillation, pretrained=pretrained, fuse=fuse) - - -class ConvNorm(torch.nn.Sequential): - def __init__( - self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): - super().__init__() - self.add_module('c', torch.nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) - bn = torch.nn.BatchNorm2d(b) - torch.nn.init.constant_(bn.weight, bn_weight_init) - torch.nn.init.constant_(bn.bias, 0) - self.add_module('bn', bn) - - @torch.no_grad() - def fuse(self): - c, bn = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = c.weight * w[:, None, None, None] - b = bn.bias - bn.running_mean * bn.weight / \ - (bn.running_var + 
bn.eps) ** 0.5 - m = torch.nn.Conv2d( - w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, - padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -class NormLinear(torch.nn.Sequential): - def __init__(self, a, b, bias=True, std=0.02): - super().__init__() - self.add_module('bn', torch.nn.BatchNorm1d(a)) - l = torch.nn.Linear(a, b, bias=bias) - trunc_normal_(l.weight, std=std) - if bias: - torch.nn.init.constant_(l.bias, 0) - self.add_module('l', l) - - @torch.no_grad() - def fuse(self): - bn, l = self._modules.values() - w = bn.weight / (bn.running_var + bn.eps) ** 0.5 - b = bn.bias - self.bn.running_mean * \ - self.bn.weight / (bn.running_var + bn.eps) ** 0.5 - w = l.weight * w[None, :] - if l.bias is None: - b = b @ self.l.weight.T - else: - b = (l.weight @ b[:, None]).view(-1) + self.l.bias - m = torch.nn.Linear(w.size(1), w.size(0)) - m.weight.data.copy_(w) - m.bias.data.copy_(b) - return m - - -def b16(n, activation, resolution=224): - return torch.nn.Sequential( - ConvNorm(3, n // 8, 3, 2, 1, resolution=resolution), - activation(), - ConvNorm(n // 8, n // 4, 3, 2, 1, resolution=resolution // 2), - activation(), - ConvNorm(n // 4, n // 2, 3, 2, 1, resolution=resolution // 4), - activation(), - ConvNorm(n // 2, n, 3, 2, 1, resolution=resolution // 8)) - - -class Residual(torch.nn.Module): - def __init__(self, m, drop): - super().__init__() - self.m = m - self.drop = drop - - def forward(self, x): - if self.training and self.drop > 0: - return x + self.m(x) * torch.rand( - x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() - else: - return x + self.m(x) - - -class Attention(torch.nn.Module): - def __init__(self, dim, key_dim, num_heads=8, - attn_ratio=4, act_layer=None, resolution=14): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * num_heads - self.attn_ratio = attn_ratio - h = self.dh + nh_kd * 2 - self.qkv = ConvNorm(dim, h, resolution=resolution) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.dh, dim, bn_weight_init=0, resolution=resolution)) - - points = list(itertools.product(range(resolution), range(resolution))) - N = len(points) - attention_offsets = {} - idxs = [] - for p1 in points: - for p2 in points: - offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): # x (B,C,H,W) - B, C, H, W = x.shape - q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) - x = self.proj(x) - return x - - -class AttentionSubsample(torch.nn.Module): - def __init__( - self, in_dim, out_dim, 
key_dim, num_heads=8, attn_ratio=2, - act_layer=None, stride=2, resolution=14, resolution_=7): - super().__init__() - self.num_heads = num_heads - self.scale = key_dim ** -0.5 - self.key_dim = key_dim - self.nh_kd = nh_kd = key_dim * num_heads - self.d = int(attn_ratio * key_dim) - self.dh = int(attn_ratio * key_dim) * self.num_heads - self.attn_ratio = attn_ratio - self.resolution_ = resolution_ - self.resolution_2 = resolution_ ** 2 - h = self.dh + nh_kd - self.kv = ConvNorm(in_dim, h, resolution=resolution) - self.q = torch.nn.Sequential( - torch.nn.AvgPool2d(1, stride, 0), - ConvNorm(in_dim, nh_kd, resolution=resolution_)) - self.proj = torch.nn.Sequential( - act_layer(), - ConvNorm(self.d * num_heads, out_dim, resolution=resolution_)) - - self.stride = stride - self.resolution = resolution - points = list(itertools.product(range(resolution), range(resolution))) - points_ = list(itertools.product(range(resolution_), range(resolution_))) - N = len(points) - N_ = len(points_) - attention_offsets = {} - idxs = [] - for p1 in points_: - for p2 in points: - size = 1 - offset = ( - abs(p1[0] * stride - p2[0] + (size - 1) / 2), - abs(p1[1] * stride - p2[1] + (size - 1) / 2)) - if offset not in attention_offsets: - attention_offsets[offset] = len(attention_offsets) - idxs.append(attention_offsets[offset]) - self.attention_biases = torch.nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) - self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) - self.ab = None - - @torch.no_grad() - def train(self, mode=True): - super().train(mode) - if mode and self.ab is not None: - self.ab = None - else: - self.ab = self.attention_biases[:, self.attention_bias_idxs] - - def forward(self, x): - B, C, H, W = x.shape - k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) - q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) - ab = self.attention_biases[:, self.attention_bias_idxs] if self.training else self.ab - attn = (q.transpose(-2, -1) @ k) * self.scale + ab - attn = attn.softmax(dim=-1) - - x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) - x = self.proj(x) - return x - - -class Levit(torch.nn.Module): - """ Vision Transformer with support for patch or hybrid CNN input stage - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dim=[192], - key_dim=[64], - depth=[12], - num_heads=[3], - attn_ratio=[2], - mlp_ratio=[2], - hybrid_backbone=None, - down_ops=[], - attn_act_layer=torch.nn.Hardswish, - mlp_act_layer=torch.nn.Hardswish, - distillation=True, - drop_path=0): - super().__init__() - self.num_classes = num_classes - self.num_features = embed_dim[-1] - self.embed_dim = embed_dim - self.distillation = distillation - - self.patch_embed = hybrid_backbone - - self.blocks = [] - down_ops.append(['']) - resolution = img_size // patch_size - for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( - zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): - for _ in range(dpth): - self.blocks.append( - Residual( - Attention(ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, resolution=resolution), - drop_path)) - if mr > 0: - h = int(ed * mr) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(ed, h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, ed, bn_weight_init=0, resolution=resolution), - ), drop_path)) - if do[0] == 'Subsample': - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, 
stride) - resolution_ = (resolution - 1) // do[5] + 1 - self.blocks.append( - AttentionSubsample( - *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], attn_ratio=do[3], - act_layer=attn_act_layer, stride=do[5], - resolution=resolution, resolution_=resolution_)) - resolution = resolution_ - if do[4] > 0: # mlp_ratio - h = int(embed_dim[i + 1] * do[4]) - self.blocks.append( - Residual(torch.nn.Sequential( - ConvNorm(embed_dim[i + 1], h, resolution=resolution), - mlp_act_layer(), - ConvNorm(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), - ), drop_path)) - self.blocks = torch.nn.Sequential(*self.blocks) - - # Classifier head - self.head = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - if distillation: - self.head_dist = NormLinear( - embed_dim[-1], num_classes) if num_classes > 0 else torch.nn.Identity() - - @torch.jit.ignore - def no_weight_decay(self): - return {x for x in self.state_dict().keys() if 'attention_biases' in x} - - def forward(self, x): - x = self.patch_embed(x) - x = self.blocks(x) - x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1) - if self.distillation: - x = self.head(x), self.head_dist(x) - if not self.training: - x = (x[0] + x[1]) / 2 - else: - x = self.head(x) - return x - - -def model_factory(C, D, X, N, drop_path, weights, num_classes, distillation, pretrained, fuse): - embed_dim = [int(x) for x in C.split('_')] - num_heads = [int(x) for x in N.split('_')] - depth = [int(x) for x in X.split('_')] - act = torch.nn.Hardswish - model = Levit( - patch_size=16, - embed_dim=embed_dim, - num_heads=num_heads, - key_dim=[D] * 3, - depth=depth, - attn_ratio=[2, 2, 2], - mlp_ratio=[2, 2, 2], - down_ops=[ - # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) - ['Subsample', D, embed_dim[0] // D, 4, 2, 2], - ['Subsample', D, embed_dim[1] // D, 4, 2, 2], - ], - attn_act_layer=act, - mlp_act_layer=act, - hybrid_backbone=b16(embed_dim[0], activation=act), - num_classes=num_classes, - drop_path=drop_path, - distillation=distillation - ) - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - weights, map_location='cpu') - d = checkpoint['model'] - D = model.state_dict() - for k in d.keys(): - if D[k].shape != d[k].shape: - d[k] = d[k][:, :, None, None] - model.load_state_dict(d) - #if fuse: - # utils.replace_batchnorm(model) - - return model - diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 92ca115b..5a6dce6f 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.): nn.init.ones_(m.weight) -def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_mixer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for MLP-Mixer models.') model = build_model_with_cfg( MlpMixer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/pit.py b/timm/models/pit.py index 040d96db..9c350861 100644 --- 
a/timm/models/pit.py +++ b/timm/models/pit.py @@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_pit(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - img_size = kwargs.pop('img_size', default_img_size) - num_classes = kwargs.pop('num_classes', default_num_classes) - if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( PoolingVisionTransformer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 8e038718..8186cc4a 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,7 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained +from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model): return state_dict +def _create_tnt(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + TNT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_s_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), - filter_fn=checkpoint_filter_fn) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) return model @register_model def tnt_b_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_b_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) return model diff --git a/timm/models/twins.py b/timm/models/twins.py index a534d174..793d2ede 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -33,7 +33,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', **kwargs } @@ -361,25 +361,14 @@ class Twins(nn.Module): return x -def 
_create_twins(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) +def _create_twins(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Twins, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/visformer.py b/timm/models/visformer.py index aa3bca57..df1d502a 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -1,3 +1,12 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +""" +from copy import deepcopy + import torch import torch.nn as nn import torch.nn.functional as F @@ -22,6 +31,12 @@ def _cfg(url='', **kwargs): } +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg(), +) + + class LayerNormBHWC(nn.LayerNorm): def __init__(self, dim): super().__init__(dim) @@ -300,87 +315,97 @@ class Visformer(nn.Module): return x +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + @register_model def visformer_tiny(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) return model @register_model def visformer_small(pretrained=False, **kwargs): - model = Visformer( + model_cfg = dict( img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, embed_norm=nn.BatchNorm2d, **kwargs) - model.default_cfg = _cfg() + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) return model -@register_model -def visformer_net1(pretrained=False, **kwargs): - model = Visformer( - init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net2(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net3(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, 
conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net4(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', - spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net5(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net6(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model - - -@register_model -def visformer_net7(pretrained=False, **kwargs): - model = Visformer( - init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', - pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) - model.default_cfg = _cfg() - return model +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bef6dfb0..ff74d836 100644 --- a/timm/models/vision_transformer.py +++ 
b/timm/models/vision_transformer.py @@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), - model.patch_embed.grid_size) + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: # Remove representation layer if fine-tuning. This may not always be the desired action, @@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw _logger.warning("Removing representation layer for fine-tuning.") repr_size = None - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg( VisionTransformer, variant, pretrained, default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1656559f..9e5a62b2 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -27,7 +27,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', **kwargs @@ -107,11 +107,10 @@ class HybridEmbed(nn.Module): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) embed_layer = partial(HybridEmbed, backbone=backbone) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return _create_vision_transformer( - variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) def _resnetv2(layers=(3, 4, 9), **kwargs):
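The ConvNorm blocks removed above carry a .fuse() method that folds the trailing BatchNorm2d into the preceding convolution for inference. A minimal standalone sketch of that folding with a numerical check; fuse_conv_bn is an illustrative name, not a timm function, and it assumes the conv has no bias, as in the deleted code.

import torch
import torch.nn as nn

@torch.no_grad()
def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    # per output channel: gamma / sqrt(running_var + eps)
    scale = bn.weight / (bn.running_var + bn.eps).sqrt()
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding,
        dilation=conv.dilation, groups=conv.groups, bias=True)
    fused.weight.copy_(conv.weight * scale[:, None, None, None])
    fused.bias.copy_(bn.bias - bn.running_mean * scale)
    return fused

# quick check: eval-mode BN over running stats must match the fused conv
conv, bn = nn.Conv2d(8, 16, 3, padding=1, bias=False), nn.BatchNorm2d(16)
bn.running_mean.uniform_(-1, 1); bn.running_var.uniform_(0.5, 1.5)
bn.weight.data.uniform_(0.5, 1.5); bn.bias.data.uniform_(-0.5, 0.5)
bn.eval()
x = torch.randn(2, 8, 32, 32)
assert torch.allclose(bn(conv(x)), fuse_conv_bn(conv, bn)(x), atol=1e-4)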
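Both Attention and AttentionSubsample in the deleted file learn one bias per head for every distinct (|dx|, |dy|) offset between positions on the feature-map grid, and index that table with a precomputed LongTensor so the bias added to the attention logits has shape (num_heads, N, N). A self-contained sketch of the table construction; the small resolution is chosen only for illustration.

import itertools
import torch

resolution, num_heads = 4, 2
points = list(itertools.product(range(resolution), range(resolution)))
attention_offsets, idxs = {}, []
for p1 in points:
    for p2 in points:
        offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
        if offset not in attention_offsets:
            attention_offsets[offset] = len(attention_offsets)
        idxs.append(attention_offsets[offset])

N = len(points)  # 16 tokens on a 4x4 grid
attention_biases = torch.zeros(num_heads, len(attention_offsets))  # learned in the model
attention_bias_idxs = torch.LongTensor(idxs).view(N, N)
ab = attention_biases[:, attention_bias_idxs]  # (num_heads, N, N), added to the scaled attention logits
print(len(attention_offsets), ab.shape)  # 16 distinct offsets, torch.Size([2, 16, 16])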
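The Residual wrapper above implements per-sample stochastic depth: during training each sample keeps its residual branch with probability 1 - drop, and the survivors are rescaled by 1 / (1 - drop) so the expectation is unchanged. A sketch of the same idea; the mask is shaped (B, 1, 1, 1) here so it broadcasts over NCHW feature maps, a small adjustment relative to the deleted code.

import torch
import torch.nn as nn

class ResidualDropPath(nn.Module):
    def __init__(self, block: nn.Module, drop: float = 0.):
        super().__init__()
        self.block = block
        self.drop = drop

    def forward(self, x):
        if self.training and self.drop > 0:
            # Bernoulli keep-mask per sample, rescaled so E[output] matches eval mode
            keep = torch.rand(x.size(0), 1, 1, 1, device=x.device).ge_(self.drop)
            return x + self.block(x) * keep.div(1 - self.drop).detach()
        return x + self.block(x)

layer = ResidualDropPath(nn.Conv2d(16, 16, 3, padding=1), drop=0.1)
layer.train()
out = layer(torch.randn(4, 16, 32, 32))  # some samples skip the conv branch entirely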
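The specification dict at the top of the deleted file packs per-stage hyper-parameters into underscore-separated strings ('C' for stage widths, 'N' for heads, 'X' for depths, 'D' for the shared key dim), and model_factory expands them before building Levit. A standalone illustration of just that parsing:

spec = {'C': '128_256_384', 'D': 16, 'N': '4_6_8', 'X': '2_3_4', 'drop_path': 0}
embed_dim = [int(x) for x in spec['C'].split('_')]  # [128, 256, 384] stage widths
num_heads = [int(x) for x in spec['N'].split('_')]  # [4, 6, 8]
depth = [int(x) for x in spec['X'].split('_')]      # [2, 3, 4] blocks per stage
key_dim = [spec['D']] * 3                           # same key dim at every stage
print(embed_dim, num_heads, depth, key_dim)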
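When pretrained=True, the deleted model_factory loads the upstream LeViT checkpoints, whose projections are nn.Linear weights, into this conv variant, whose projections are 1x1 ConvNorm kernels; any weight whose shape disagrees with the model is expanded with d[k][:, :, None, None]. A toy illustration of that reshape with made-up tensor sizes:

import torch

linear_w = torch.randn(256, 128)             # nn.Linear weight from the checkpoint
conv_w_shape = torch.Size([256, 128, 1, 1])  # matching 1x1 ConvNorm weight in the model
if linear_w.shape != conv_w_shape:
    linear_w = linear_w[:, :, None, None]
assert linear_w.shape == conv_w_shape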
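For mlp_mixer, pit, tnt, twins, visformer and vision_transformer this patch deletes the hand-rolled default_cfg / img_size / num_classes plumbing and lets build_model_with_cfg resolve all of it from default_cfgs[variant] plus the caller's kwargs. A usage sketch, assuming timm at the state of this patch; the attribute reads follow the usual timm conventions rather than anything added here.

import timm

# overrides flow through create_model kwargs into build_model_with_cfg the same
# way for every refactored family, with no per-model boilerplate left behind
tnt = timm.create_model('tnt_s_patch16_224', pretrained=False, num_classes=10)
vis = timm.create_model('visformer_small', pretrained=False)
print(type(tnt).__name__, type(vis).__name__)  # TNT Visformer
print(tnt.num_classes, vis.num_classes)        # 10 1000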
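The twins change of 'first_conv' from 'patch_embed.proj' to 'patch_embeds.0.proj' matters because default_cfg['first_conv'] has to name the module that actually sees the input, e.g. so pretrained-weight loading can adapt it for non-RGB in_chans. A minimal sketch of resolving such a dotted path; the helper name is illustrative, and newer PyTorch also offers nn.Module.get_submodule for the same lookup.

import torch.nn as nn

def get_submodule_by_name(model: nn.Module, name: str) -> nn.Module:
    mod = model
    for part in name.split('.'):
        mod = mod[int(part)] if part.isdigit() else getattr(mod, part)
    return mod

m = nn.Module()
m.patch_embeds = nn.ModuleList([nn.Module()])
m.patch_embeds[0].proj = nn.Conv2d(3, 64, kernel_size=4, stride=4)
print(get_submodule_by_name(m, 'patch_embeds.0.proj'))  # Conv2d(3, 64, ...)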
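The only per-model logic _create_vision_transformer keeps is the representation-size handling noted for in21k checkpoints: if the caller fine-tunes to a num_classes different from the default, the pre-logits representation layer is dropped with a warning. The conditional, lifted into a standalone function for illustration (the function name is not timm API):

import logging
_logger = logging.getLogger(__name__)

def resolve_repr_size(default_num_classes, num_classes, representation_size):
    if representation_size is not None and num_classes != default_num_classes:
        # fine-tuning to a new head: keep the backbone, drop the pre-logits layer
        _logger.warning("Removing representation layer for fine-tuning.")
        return None
    return representation_size

print(resolve_repr_size(21843, 1000, 768))    # None: new head, pre-logits removed
print(resolve_repr_size(21843, 21843, 768))   # 768: same classes, keep pre-logits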