From 3db12b4b6a7b862698b6fd85ee26b9c924d1c4d3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 5 May 2021 17:28:19 -0700 Subject: [PATCH] Finish CaiT cleanup --- timm/models/cait.py | 331 +++++++++++++++++++------------------------- 1 file changed, 142 insertions(+), 189 deletions(-) diff --git a/timm/models/cait.py b/timm/models/cait.py index b82add71..c16bf86a 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -1,19 +1,78 @@ +""" Class-Attention in Image Transformers (CaiT) + +Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239 + +Original code and weights from https://github.com/facebookresearch/deit, copyright below + +""" # Copyright (c) 2015-present, Facebook, Inc. # All rights reserved. +from copy import deepcopy import torch import torch.nn as nn from functools import partial +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg from .layers import trunc_normal_, DropPath -from .vision_transformer import Mlp, PatchEmbed, _cfg +from .vision_transformer import Mlp, PatchEmbed from .registry import register_model -__all__ = ['Cait', 'Class_Attention', 'LayerScale_Block_CA', 'LayerScale_Block', 'Attention_talking_head'] - - -class Class_Attention(nn.Module): +__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + cait_xxs24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth', + input_size=(3, 224, 224), + ), + cait_xxs24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth', + ), + cait_xxs36_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth', + input_size=(3, 224, 224), + ), + cait_xxs36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth', + ), + cait_xs24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth', + ), + cait_s24_224=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S24_224.pth', + input_size=(3, 224, 224), + ), + cait_s24_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S24_384.pth', + ), + cait_s36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/S36_384.pth', + ), + cait_m36_384=_cfg( + url='https://dl.fbaipublicfiles.com/deit/M36_384.pth', + ), + cait_m48_448=_cfg( + url='https://dl.fbaipublicfiles.com/deit/M48_448.pth', + input_size=(3, 448, 448), + ), +) + + +class ClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to do CA def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): @@ -48,12 +107,12 @@ class Class_Attention(nn.Module): return x_cls -class LayerScale_Block_CA(nn.Module): +class LayerScaleBlockClassAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add CA and LayerScale def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=Class_Attention, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) @@ -68,15 +127,12 @@ class LayerScale_Block_CA(nn.Module): def forward(self, x, x_cls): u = torch.cat((x_cls, x), dim=1) - x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u))) - x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls))) - return x_cls -class Attention_talking_head(nn.Module): +class TalkingHeadAttn(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf) def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): @@ -118,12 +174,12 @@ class Attention_talking_head(nn.Module): return x -class LayerScale_Block(nn.Module): +class LayerScaleBlock(nn.Module): # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py # with slight modifications to add layerScale def __init__( self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=Attention_talking_head, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn, mlp_block=Mlp, init_values=1e-4): super().__init__() self.norm1 = norm_layer(dim) @@ -147,17 +203,22 @@ class Cait(nn.Module): # with slight modifications to adapt to our cait models def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=nn.LayerNorm, global_pool=None, - block_layers=LayerScale_Block, - block_layers_token=LayerScale_Block_CA, - patch_layer=PatchEmbed, act_layer=nn.GELU, - attn_block=Attention_talking_head, mlp_block=Mlp, + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), + global_pool=None, + block_layers=LayerScaleBlock, + block_layers_token=LayerScaleBlockClassAttn, + patch_layer=PatchEmbed, + act_layer=nn.GELU, + attn_block=TalkingHeadAttn, + mlp_block=Mlp, init_scale=1e-4, - attn_block_token_only=Class_Attention, + attn_block_token_only=ClassAttn, mlp_block_token_only=Mlp, depth_token_only=2, - mlp_ratio_clstk=4.0): + mlp_ratio_clstk=4.0 + ): super().__init__() self.num_classes = num_classes @@ -237,211 +298,103 @@ class Cait(nn.Module): return x -@register_model -def cait_xxs24_224(pretrained=False, **kwargs): - model = Cait( - img_size=224, patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/XXS24_224.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) +def checkpoint_filter_fn(state_dict, model=None): + if 'model' in state_dict: + state_dict = state_dict['model'] + checkpoint_no_module = {} + for k, v in state_dict.items(): + checkpoint_no_module[k.replace('module.', '')] = v + return checkpoint_no_module + + +def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + Cait, variant, pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) return model @register_model -def cait_xxs24(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=192, depth=24, num_heads=4, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/XXS24_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) +def cait_xxs24_224(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args) + return model + +@register_model +def cait_xxs24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args) return model @register_model def cait_xxs36_224(pretrained=False, **kwargs): - model = Cait( - img_size=224, patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/XXS36_224.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args) return model @register_model -def cait_xxs36(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=192, depth=36, num_heads=4, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/XXS36_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - +def cait_xxs36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args) return model @register_model -def cait_xs24(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=288, depth=24, num_heads=6, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/XS24_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - +def cait_xs24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs) + model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args) return model @register_model def cait_s24_224(pretrained=False, **kwargs): - model = Cait( - img_size=224, patch_size=16, embed_dim=384, depth=24, num_heads=8, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/S24_224.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args) return model @register_model -def cait_s24(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=384, depth=24, num_heads=8, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-5, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/S24_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - +def cait_s24_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args) return model @register_model -def cait_s36(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=384, depth=36, num_heads=8, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-6, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/S36_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - +def cait_s36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs) + model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args) return model @register_model -def cait_m36(pretrained=False, **kwargs): - model = Cait( - img_size=384, patch_size=16, embed_dim=768, depth=36, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-6, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/M36_384.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - +def cait_m36_384(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs) + model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args) return model @register_model -def cait_m48(pretrained=False, **kwargs): - model = Cait( - img_size=448, patch_size=16, embed_dim=768, depth=48, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), init_scale=1e-6, depth_token_only=2, **kwargs) - - model.default_cfg = _cfg() - if pretrained: - checkpoint = torch.hub.load_state_dict_from_url( - url="https://dl.fbaipublicfiles.com/deit/M48_448.pth", - map_location="cpu", check_hash=True - ) - checkpoint_no_module = {} - for k in model.state_dict().keys(): - checkpoint_no_module[k] = checkpoint["model"]['module.' + k] - - model.load_state_dict(checkpoint_no_module) - - return model \ No newline at end of file +def cait_m48_448(pretrained=False, **kwargs): + model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs) + model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args) + return model