Finish CaiT cleanup

pull/609/head
Ross Wightman 4 years ago
parent 1daa15ecc3
commit 3db12b4b6a

@@ -1,19 +1,78 @@
""" Class-Attention in Image Transformers (CaiT)

Paper: 'Going deeper with Image Transformers' - https://arxiv.org/abs/2103.17239

Original code and weights from https://github.com/facebookresearch/deit, copyright below
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from copy import deepcopy

import torch
import torch.nn as nn
from functools import partial

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import trunc_normal_, DropPath
from .vision_transformer import Mlp, PatchEmbed
from .registry import register_model

__all__ = ['Cait', 'ClassAttn', 'LayerScaleBlockClassAttn', 'LayerScaleBlock', 'TalkingHeadAttn']


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = dict(
    cait_xxs24_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/XXS24_224.pth',
        input_size=(3, 224, 224),
    ),
    cait_xxs24_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/XXS24_384.pth',
    ),
    cait_xxs36_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/XXS36_224.pth',
        input_size=(3, 224, 224),
    ),
    cait_xxs36_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/XXS36_384.pth',
    ),
    cait_xs24_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/XS24_384.pth',
    ),
    cait_s24_224=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/S24_224.pth',
        input_size=(3, 224, 224),
    ),
    cait_s24_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/S24_384.pth',
    ),
    cait_s36_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/S36_384.pth',
    ),
    cait_m36_384=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/M36_384.pth',
    ),
    cait_m48_448=_cfg(
        url='https://dl.fbaipublicfiles.com/deit/M48_448.pth',
        input_size=(3, 448, 448),
    ),
)
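For illustration, the kwargs passed to _cfg() override its base defaults, so the 224-resolution entries end up with input_size=(3, 224, 224) instead of the 384 default. A worked example of that resolution (not part of the diff; it simply expands the code above):

# default_cfgs['cait_xxs24_224'] from above resolves to this dict
assert default_cfgs['cait_xxs24_224'] == {
    'url': 'https://dl.fbaipublicfiles.com/deit/XXS24_224.pth',
    'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
    'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
    'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
    'first_conv': 'patch_embed.proj', 'classifier': 'head',
}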
class ClassAttn(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to do CA
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):

@@ -48,12 +107,12 @@ class Class_Attention(nn.Module):
        return x_cls
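The body of ClassAttn is elided by the hunk above. For orientation, a condensed sketch of the class-attention mechanism it implements: only the CLS token forms queries, while keys and values come from CLS plus patch tokens, so each head produces a 1xN attention map. This is an illustrative rewrite, not the exact elided code (dropout and the qk_scale argument are omitted):

import torch
import torch.nn as nn

class ClassAttnSketch(nn.Module):
    """Condensed class-attention: queries are computed from the CLS token only."""
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.k = nn.Linear(dim, dim, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                                            # x: (B, 1 + N_patches, C), CLS first
        B, N, C = x.shape
        H = self.num_heads
        q = self.q(x[:, :1]).reshape(B, 1, H, C // H).transpose(1, 2)   # (B, H, 1, d)
        k = self.k(x).reshape(B, N, H, C // H).transpose(1, 2)          # (B, H, N, d)
        v = self.v(x).reshape(B, N, H, C // H).transpose(1, 2)          # (B, H, N, d)
        attn = (q * self.scale) @ k.transpose(-2, -1)                   # (B, H, 1, N)
        attn = attn.softmax(dim=-1)
        x_cls = (attn @ v).transpose(1, 2).reshape(B, 1, C)             # updated CLS token
        return self.proj(x_cls)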
class LayerScaleBlockClassAttn(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to add CA and LayerScale
    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=ClassAttn,
            mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)

@@ -68,15 +127,12 @@ class LayerScale_Block_CA(nn.Module):

    def forward(self, x, x_cls):
        u = torch.cat((x_cls, x), dim=1)
        x_cls = x_cls + self.drop_path(self.gamma_1 * self.attn(self.norm1(u)))
        x_cls = x_cls + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x_cls)))
        return x_cls
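gamma_1 and gamma_2 in the forward above are the LayerScale parameters: learnable per-channel scales on each residual branch, initialised from the small init_values so very deep stacks start close to the identity. A minimal illustration of that idea (a hypothetical standalone module, not code from this diff):

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Per-channel scaling of a residual branch, as used via gamma_1 / gamma_2."""
    def __init__(self, dim, init_values=1e-4):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):            # x: (..., dim)
        return self.gamma * x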
class TalkingHeadAttn(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to add Talking Heads Attention (https://arxiv.org/pdf/2003.02436v1.pdf)
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):

@@ -118,12 +174,12 @@ class Attention_talking_head(nn.Module):
        return x
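The body of TalkingHeadAttn is likewise elided. Its modification over standard multi-head attention is two small linear layers that mix information across the heads dimension, once on the logits before the softmax and once on the weights after it (the talking-heads trick from the linked paper). A condensed, self-contained sketch of that mechanism, not the exact elided code (dropout and qk_scale are omitted):

import torch
import torch.nn as nn

class TalkingHeadAttnSketch(nn.Module):
    """Condensed talking-heads attention: head mixing before and after softmax."""
    def __init__(self, dim, num_heads=8, qkv_bias=False):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj_l = nn.Linear(num_heads, num_heads)   # mixes pre-softmax logits across heads
        self.proj_w = nn.Linear(num_heads, num_heads)   # mixes post-softmax weights across heads
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):                               # x: (B, N, C)
        B, N, C = x.shape
        H = self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, H, C // H).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]   # each (B, H, N, d)
        attn = q @ k.transpose(-2, -1)                  # (B, H, N, N)
        attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        attn = attn.softmax(dim=-1)
        attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)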
class LayerScaleBlock(nn.Module):
    # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
    # with slight modifications to add layerScale
    def __init__(
            self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
            drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, attn_block=TalkingHeadAttn,
            mlp_block=Mlp, init_values=1e-4):
        super().__init__()
        self.norm1 = norm_layer(dim)
@@ -147,17 +203,22 @@ class Cait(nn.Module):
    # with slight modifications to adapt to our cait models
    def __init__(
            self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
            num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            global_pool=None,
            block_layers=LayerScaleBlock,
            block_layers_token=LayerScaleBlockClassAttn,
            patch_layer=PatchEmbed,
            act_layer=nn.GELU,
            attn_block=TalkingHeadAttn,
            mlp_block=Mlp,
            init_scale=1e-4,
            attn_block_token_only=ClassAttn,
            mlp_block_token_only=Mlp,
            depth_token_only=2,
            mlp_ratio_clstk=4.0
    ):
        super().__init__()

        self.num_classes = num_classes

@@ -237,211 +298,103 @@ class Cait(nn.Module):
        return x
def checkpoint_filter_fn(state_dict, model=None):
    if 'model' in state_dict:
        state_dict = state_dict['model']
    checkpoint_no_module = {}
    for k, v in state_dict.items():
        checkpoint_no_module[k.replace('module.', '')] = v
    return checkpoint_no_module


def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs):
    if default_cfg is None:
        default_cfg = deepcopy(default_cfgs[variant])
    overlay_external_default_cfg(default_cfg, kwargs)
    default_num_classes = default_cfg['num_classes']
    default_img_size = default_cfg['input_size'][-2:]
    num_classes = kwargs.pop('num_classes', default_num_classes)
    img_size = kwargs.pop('img_size', default_img_size)

    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    model = build_model_with_cfg(
        Cait, variant, pretrained,
        default_cfg=default_cfg,
        img_size=img_size,
        num_classes=num_classes,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model
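checkpoint_filter_fn replaces the per-model download-and-rename loops removed by this hunk: the released DeiT/CaiT checkpoints nest the weights under a 'model' key and carry a DataParallel 'module.' prefix, and the filter strips both before build_model_with_cfg loads the state dict. A toy check of that behaviour (dummy tensors, purely illustrative):

import torch

raw = {'model': {
    'module.cls_token': torch.zeros(1, 1, 192),
    'module.patch_embed.proj.weight': torch.zeros(192, 3, 16, 16),
}}
print(sorted(checkpoint_filter_fn(raw).keys()))
# ['cls_token', 'patch_embed.proj.weight']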
@register_model
def cait_xxs24_224(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_xxs24_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_xxs36_224(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args)
    return model
@register_model
def cait_xxs36_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_xs24_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_s24_224(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args)
    return model
@register_model
def cait_s24_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs)
    model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_s36_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs)
    model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args)
    return model


@register_model
def cait_m36_384(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs)
    model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args)
    return model
@register_model
def cait_m48_448(pretrained=False, **kwargs):
    model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs)
    model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args)
    return model
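After this cleanup the registered variants are created through timm's standard factory rather than by hand-loading checkpoints in each function. A minimal usage sketch (assumes a timm install that includes this module; pretrained=True would pull the weights listed in default_cfgs):

import torch
import timm

# register_model makes the cait_* factories visible to timm.create_model
model = timm.create_model('cait_xxs24_224', pretrained=False, num_classes=10)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 10])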