diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index 73c2e42c..2df02f49 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -274,7 +274,9 @@ class ResNetStage(nn.Module): return x -def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None): +def create_resnetv2_stem( + in_chs, out_chs=64, stem_type='', preact=True, + conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)): stem = OrderedDict() assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same') @@ -322,7 +324,7 @@ class ResNetV2(nn.Module): self.feature_info = [] stem_chs = make_div(stem_chs * wf) - self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + self.stem = create_resnetv2_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) # NOTE no, reduction 2 feature if preact self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm')) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index acd4d18d..02c32cb7 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -28,9 +28,9 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained -from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_ +from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_ from .resnet import resnet26d, resnet50d -from .resnetv2 import ResNetV2 +from .resnetv2 import ResNetV2, create_resnetv2_stem from .registry import register_model _logger = logging.getLogger(__name__) @@ -97,17 +97,62 @@ default_cfgs = { url='', # FIXME I have weights for this but > 2GB limit for github release binaries num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - # hybrid models (weights ported from official Google JAX impl) - 'vit_base_resnet50_224_in21k': _cfg( + # hybrid in-21k models (weights ported from official Google JAX impl where they exist) + 'vit_base_r50_s16_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'), - 'vit_base_resnet50_384': _cfg( + + # hybrid in-1k models (weights ported from official Google JAX impl where they exist) + 'vit_small_r_s16_p8_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_p2_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r20_s16_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_small_r26_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r20_s16_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r26_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r50_s16_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_base_r50_s16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', - input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'), + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_224': _cfg( + input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), + 'vit_large_r50_s32_384': _cfg( + input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, + first_conv='patch_embed.backbone.stem.conv'), # hybrid models (my experiments) 'vit_small_resnet26d_224': _cfg(), - 'vit_small_resnet50d_s3_224': _cfg(), + 'vit_small_resnet50d_s16_224': _cfg(), 'vit_base_resnet26d_224': _cfg(), 'vit_base_resnet50d_224': _cfg(), @@ -227,11 +272,13 @@ class HybridEmbed(nn.Module): """ CNN Feature Map Embedding Extract feature map from CNN, flatten, project to embedding dim. """ - def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768): + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): super().__init__() assert isinstance(backbone, nn.Module) img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) self.img_size = img_size + self.patch_size = patch_size self.backbone = backbone if feature_size is None: with torch.no_grad(): @@ -253,8 +300,9 @@ class HybridEmbed(nn.Module): feature_dim = self.backbone.feature_info.channels()[-1] else: feature_dim = self.backbone.num_features - self.num_patches = feature_size[0] * feature_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, 1) + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): x = self.backbone(x) @@ -270,9 +318,10 @@ class VisionTransformer(nn.Module): A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + act_layer=None): """ Args: img_size (int, tuple): input image size @@ -296,10 +345,12 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + patch_size = patch_size or 1 if hybrid_backbone is not None else 16 if hybrid_backbone is not None: self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim) + hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) else: self.patch_embed = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) @@ -313,7 +364,7 @@ class VisionTransformer(nn.Module): self.blocks = nn.ModuleList([ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) self.norm = norm_layer(embed_dim) @@ -423,13 +474,15 @@ class DistilledVisionTransformer(VisionTransformer): return (x + x_dist) / 2 -def resize_pos_embed(posemb, posemb_new): +def resize_pos_embed(posemb, posemb_new, token='class'): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if True: - posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + if token: + assert token in ('class', 'distill') + token_idx = 2 if token == 'distill' else 1 + posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:] ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] @@ -633,33 +686,190 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): return model +def _resnetv2(layers=(3, 4, 9), **kwargs): + """ ResNet-V2 backbone helper""" + padding_same = kwargs.get('padding_same', True) + if padding_same: + stem_type = 'same' + conv_layer = StdConv2dSame + else: + stem_type = '' + conv_layer = StdConv2d + if len(layers): + backbone = ResNetV2( + layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), + preact=False, stem_type=stem_type, conv_layer=conv_layer) + else: + backbone = create_resnetv2_stem( + kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer) + return backbone + + @register_model -def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): +def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ - # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - backbone = ResNetV2( - layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), - preact=False, stem_type='same', conv_layer=StdConv2dSame) + backbone = _resnetv2(layers=(3, 4, 9), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r_s16_p8_224(pretrained=False, **kwargs): + """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2(layers=(), **kwargs) + model_kwargs = dict( + patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): + """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. + """ + backbone = _resnetv2((2, 4), **kwargs) model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, - representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs) + patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): + """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. + """ + backbone = _resnetv2((2, 4), **kwargs) + model_kwargs = dict( + embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r20_s16_224(pretrained=False, **kwargs): + """ R20+ViT-S/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def vit_base_resnet50_384(pretrained=False, **kwargs): +def vit_small_r20_s16_384(pretrained=False, **kwargs): + """ R20+ViT-S/S16 hybrid @ 384x384. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_small_r26_s32_384(pretrained=False, **kwargs): + """ R26+ViT-S/S32 hybrid @ 384x384. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r20_s16_224(pretrained=False, **kwargs): + """ R20+ViT-B/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs) + model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r20_s16_384(pretrained=False, **kwargs): + """ R20+ViT-B/S16 hybrid. + """ + backbone = _resnetv2((2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r26_s32_224(pretrained=False, **kwargs): + """ R26+ViT-B/S32 hybrid. + """ + backbone = _resnetv2((2, 2, 2, 2), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_224(pretrained=False, **kwargs): + """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). + """ + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_r50_s16_384(pretrained=False, **kwargs): """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - backbone = ResNetV2( - layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3), - preact=False, stem_type='same', conv_layer=StdConv2dSame) + backbone = _resnetv2((3, 4, 9), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_224(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_large_r50_s32_384(pretrained=False, **kwargs): + """ R50+ViT-L/S32 hybrid. + """ + backbone = _resnetv2((3, 4, 6, 3), **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs) return model @@ -674,12 +884,12 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): @register_model -def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): +def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs) return model