"""
CoaT architecture.

Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399

Official CoaT code at: https://github.com/mlpc-ucsd/CoaT

Modified from timm/models/vision_transformer.py
"""
from functools import partial
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model

__all__ = [
    "coat_tiny",
    "coat_mini",
    "coat_lite_tiny",
    "coat_lite_mini",
    "coat_lite_small",
]


def _cfg_coat(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed1.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = {
    'coat_tiny': _cfg_coat(),
    'coat_mini': _cfg_coat(),
    'coat_lite_tiny': _cfg_coat(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
    ),
    'coat_lite_mini': _cfg_coat(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
    ),
    'coat_lite_small': _cfg_coat(),
}
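
# Example usage (a sketch, not part of the original file): with this module importable
# in a standard timm install, the registered variants can be built via the factory:
#   import timm
#   model = timm.create_model('coat_lite_mini', pretrained=True)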


class ConvRelPosEnc(nn.Module):
    """ Convolutional relative position encoding. """
    def __init__(self, Ch, h, window):
        """
        Initialization.
            Ch: Channels per head.
            h: Number of heads.
            window: Window size(s) in convolutional relative positional encoding. It can have two forms:
                1. An integer of window size, which assigns all attention heads the same window size
                   in ConvRelPosEnc.
                2. A dict mapping window size to #attention head splits
                   (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}).
                   It will apply different window sizes to the attention head splits.
        """
        super().__init__()

        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError('window must be an int or a dict mapping window size to head splits')

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1
            # Determine padding size.
            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * Ch, cur_head_split * Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * Ch,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * Ch for x in self.head_splits]

    def forward(self, q, v, size: Tuple[int, int]):
        B, h, N, Ch = q.shape
        H, W = size
        assert N == 1 + H * W

        # Convolutional relative position encoding, applied to image tokens only.
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]

        v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels.
        conv_v_img_list = [conv(v_img_list[i]) for i, conv in enumerate(self.conv_list)]
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)

        EV_hat = q_img * conv_v_img
        EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0))  # Zero-pad the CLS position back in: [B, h, N, Ch].
        return EV_hat
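
# Shape sketch for ConvRelPosEnc (illustrative only, not part of the original file):
# with 8 heads of 24 channels each, the default split {3: 2, 5: 3, 7: 3} gives 2 heads
# a 3x3 depthwise conv, 3 heads a 5x5, and 3 heads a 7x7.
#   crpe = ConvRelPosEnc(Ch=24, h=8, window={3: 2, 5: 3, 7: 3})
#   q = v = torch.randn(2, 8, 1 + 14 * 14, 24)  # CLS token + 14x14 grid
#   out = crpe(q, v, size=(14, 14))             # -> [2, 8, 197, 24]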


class FactorAtt_ConvRelPosEnc(nn.Module):
    """ Factorized attention with convolutional relative position encoding class. """
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape

        # Generate Q, K, V.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, h, N, Ch]

        # Factorized attention.
        k_softmax = k.softmax(dim=2)
        factor_att = k_softmax.transpose(-1, -2) @ v
        factor_att = q @ factor_att

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]

        # Merge and reshape.
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)  # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
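
# Note on the factorized attention above: instead of softmax(Q K^T / sqrt(d)) V, which
# is quadratic in the token count N, it computes scale * Q @ (softmax_N(K)^T @ V),
# which is linear in N. A minimal sanity sketch (hypothetical shapes):
#   q, k, v = [torch.randn(2, 8, 197, 24) for _ in range(3)]
#   out = q @ (k.softmax(dim=2).transpose(-1, -2) @ v)  # -> [2, 8, 197, 24]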


class ConvPosEnc(nn.Module):
    """ Convolutional Position Encoding.
        Note: This module is similar to the conditional position encoding in CPVT.
    """
    def __init__(self, dim, k=3):
        super(ConvPosEnc, self).__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W

        # Extract CLS token and image tokens.
        cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]

        # Depthwise convolution.
        feat = img_tokens.transpose(1, 2).view(B, C, H, W)
        x = self.proj(feat) + feat
        x = x.flatten(2).transpose(1, 2)

        # Combine with CLS token.
        x = torch.cat((cls_token, x), dim=1)

        return x
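
# Note: ConvPosEnc injects position information only through the depthwise conv over
# the H x W token grid; the CLS token bypasses the conv and is re-attached unchanged.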


class SerialBlock(nn.Module):
    """ Serial block class.
        Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 shared_cpe=None, shared_crpe=None):
        super().__init__()

        # Conv-Attention.
        self.cpe = shared_cpe

        self.norm1 = norm_layer(dim)
        self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            shared_crpe=shared_crpe)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, size: Tuple[int, int]):
        # Conv-Attention.
        x = self.cpe(x, size)
        cur = self.norm1(x)
        cur = self.factoratt_crpe(cur, size)
        x = x + self.drop_path(cur)

        # MLP.
        cur = self.norm2(x)
        cur = self.mlp(cur)
        x = x + self.drop_path(cur)

        return x
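
# Note: SerialBlock is a pre-norm transformer block with the shared ConvPosEnc applied
# first, i.e. x = x + DropPath(Attn(LN(CPE(x)))), then x = x + DropPath(MLP(LN(x))).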


class ParallelBlock(nn.Module):
    """ Parallel block class. """
    def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 shared_cpes=None, shared_crpes=None):
        super().__init__()

        # Conv-Attention.
        self.cpes = shared_cpes

        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
            dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            shared_crpe=shared_crpes[1])
        self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
            dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            shared_crpe=shared_crpes[2])
        self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
            dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
            shared_crpe=shared_crpes[3])
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP.
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        # In parallel block, we assume dimensions are the same and share the linear transformation.
        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
            in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def upsample(self, x, factor, size):
        """ Feature map up-sampling. """
        return self.interpolate(x, scale_factor=factor, size=size)

    def downsample(self, x, factor, size):
        """ Feature map down-sampling. """
        return self.interpolate(x, scale_factor=1.0 / factor, size=size)

    def interpolate(self, x, scale_factor, size):
        """ Feature map interpolation. """
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]

        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear')
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        out = torch.cat((cls_token, img_tokens), dim=1)

        return out

    def forward(self, x1, x2, x3, x4, sizes):
        _, (H2, W2), (H3, W3), (H4, W4) = sizes

        # Conv-Attention.
        x2 = self.cpes[1](x2, size=(H2, W2))  # Note: x1 is ignored.
        x3 = self.cpes[2](x3, size=(H3, W3))
        x4 = self.cpes[3](x4, size=(H4, W4))

        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
        cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
        cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
        upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3))
        upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4))
        upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4))
        downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2))
        downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3))
        downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2))
        cur2 = cur2 + upsample3_2 + upsample4_2
        cur3 = cur3 + upsample4_3 + downsample2_3
        cur4 = cur4 + downsample3_4 + downsample2_4
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        # MLP.
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4
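
# Note on ParallelBlock: after per-scale conv-attention, features at scales 2-4 are
# exchanged by bilinearly resampling the image-token grid (e.g. cur2 receives 2x- and
# 4x-upsampled cur3/cur4) and summed as residuals; scale 1 passes through untouched.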


class CoaT(nn.Module):
    """ CoaT class. """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],
                 serial_depths=[0, 0, 0, 0], parallel_depth=0,
                 num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
                 drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
                 return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
        super().__init__()
        crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.num_classes = num_classes
        self.embed_dim = embed_dims[3]  # Final feature dim, referenced by reset_classifier().

        # Patch embeddings.
        self.patch_embed1 = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
        self.patch_embed2 = PatchEmbed(
            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
            embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
        self.patch_embed3 = PatchEmbed(
            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
            embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
        self.patch_embed4 = PatchEmbed(
            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
            embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)

        # Class tokens.
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings.
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings.
        self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window)
        self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window)
        self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window)
        self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window)

        # Disable stochastic depth.
        dpr = drop_path_rate
        assert dpr == 0.0

        # Serial blocks 1.
        self.serial_blocks1 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                shared_cpe=self.cpe1, shared_crpe=self.crpe1
            )
            for _ in range(serial_depths[0])]
        )

        # Serial blocks 2.
        self.serial_blocks2 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                shared_cpe=self.cpe2, shared_crpe=self.crpe2
            )
            for _ in range(serial_depths[1])]
        )

        # Serial blocks 3.
        self.serial_blocks3 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                shared_cpe=self.cpe3, shared_crpe=self.crpe3
            )
            for _ in range(serial_depths[2])]
        )

        # Serial blocks 4.
        self.serial_blocks4 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                shared_cpe=self.cpe4, shared_crpe=self.crpe4
            )
            for _ in range(serial_depths[3])]
        )

        # Parallel blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList([
                ParallelBlock(
                    dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
                    drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
                    shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],
                    shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]
                )
                for _ in range(parallel_depth)]
            )
        else:
            self.parallel_blocks = None

        # Classification head(s).
        if not self.return_interm_layers:
            self.norm1 = norm_layer(embed_dims[0])
            self.norm2 = norm_layer(embed_dims[1])
            self.norm3 = norm_layer(embed_dims[2])
            self.norm4 = norm_layer(embed_dims[3])

            if self.parallel_depth > 0:
                # CoaT series: Aggregate features of last three scales for classification.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
                self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
                self.head = nn.Linear(embed_dims[3], num_classes)
            else:
                # CoaT-Lite series: Use feature of last scale for classification.
                self.head = nn.Linear(embed_dims[3], num_classes)

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=.02)
        trunc_normal_(self.cls_token2, std=.02)
        trunc_normal_(self.cls_token3, std=.02)
        trunc_normal_(self.cls_token4, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def insert_cls(self, x, cls_token):
        """ Insert CLS token. """
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        return x

    def remove_cls(self, x):
        """ Remove CLS token. """
        return x[:, 1:, :]

    def forward_features(self, x0):
        B = x0.shape[0]

        # Serial blocks 1.
        x1 = self.patch_embed1(x0)
        H1, W1 = self.patch_embed1.grid_size
        x1 = self.insert_cls(x1, self.cls_token1)
        for blk in self.serial_blocks1:
            x1 = blk(x1, size=(H1, W1))
        x1_nocls = self.remove_cls(x1)
        x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 2.
        x2 = self.patch_embed2(x1_nocls)
        H2, W2 = self.patch_embed2.grid_size
        x2 = self.insert_cls(x2, self.cls_token2)
        for blk in self.serial_blocks2:
            x2 = blk(x2, size=(H2, W2))
        x2_nocls = self.remove_cls(x2)
        x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 3.
        x3 = self.patch_embed3(x2_nocls)
        H3, W3 = self.patch_embed3.grid_size
        x3 = self.insert_cls(x3, self.cls_token3)
        for blk in self.serial_blocks3:
            x3 = blk(x3, size=(H3, W3))
        x3_nocls = self.remove_cls(x3)
        x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 4.
        x4 = self.patch_embed4(x3_nocls)
        H4, W4 = self.patch_embed4.grid_size
        x4 = self.insert_cls(x4, self.cls_token4)
        for blk in self.serial_blocks4:
            x4 = blk(x4, size=(H4, W4))
        x4_nocls = self.remove_cls(x4)
        x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()

        # Only serial blocks: Early return.
        if self.parallel_blocks is None:
            if not torch.jit.is_scripting() and self.return_interm_layers:
                # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
                feat_out = {}
                if 'x1_nocls' in self.out_features:
                    feat_out['x1_nocls'] = x1_nocls
                if 'x2_nocls' in self.out_features:
                    feat_out['x2_nocls'] = x2_nocls
                if 'x3_nocls' in self.out_features:
                    feat_out['x3_nocls'] = x3_nocls
                if 'x4_nocls' in self.out_features:
                    feat_out['x4_nocls'] = x4_nocls
                return feat_out
            else:
                # Return features for classification.
                x4 = self.norm4(x4)
                x4_cls = x4[:, 0]
                return x4_cls
        # Parallel blocks.
        for blk in self.parallel_blocks:
            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])

        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                x1_nocls = self.remove_cls(x1)
                x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x1_nocls'] = x1_nocls
            if 'x2_nocls' in self.out_features:
                x2_nocls = self.remove_cls(x2)
                x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x2_nocls'] = x2_nocls
            if 'x3_nocls' in self.out_features:
                x3_nocls = self.remove_cls(x3)
                x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x3_nocls'] = x3_nocls
            if 'x4_nocls' in self.out_features:
                x4_nocls = self.remove_cls(x4)
                x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x4_nocls'] = x4_nocls
            return feat_out
        else:
            x2 = self.norm2(x2)
            x3 = self.norm3(x3)
            x4 = self.norm4(x4)
            x2_cls = x2[:, :1]  # [B, 1, C]
            x3_cls = x3[:, :1]
            x4_cls = x4[:, :1]
            merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)  # [B, 3, C]
            merged_cls = self.aggregate(merged_cls).squeeze(dim=1)  # Shape: [B, C]
            return merged_cls

    def forward(self, x):
        if self.return_interm_layers:
            # Return intermediate features (for down-stream tasks).
            return self.forward_features(x)
        else:
            # Return features for classification.
            x = self.forward_features(x)
            x = self.head(x)
            return x


@register_model
def coat_tiny(pretrained=False, **kwargs):
    model = CoaT(
        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
    model.default_cfg = default_cfgs['coat_tiny']
    return model


@register_model
def coat_mini(pretrained=False, **kwargs):
    model = CoaT(
        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
    model.default_cfg = default_cfgs['coat_mini']
    return model


@register_model
def coat_lite_tiny(pretrained=False, **kwargs):
    model = CoaT(
        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
    # FIXME use builder
    model.default_cfg = default_cfgs['coat_lite_tiny']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model


@register_model
def coat_lite_mini(pretrained=False, **kwargs):
    model = CoaT(
        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
    # FIXME use builder
    model.default_cfg = default_cfgs['coat_lite_mini']
    if pretrained:
        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model


@register_model
def coat_lite_small(pretrained=False, **kwargs):
    model = CoaT(
        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
    model.default_cfg = default_cfgs['coat_lite_small']
    return model
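
# Minimal smoke test (a sketch, not part of the original file):
#   model = coat_lite_tiny()
#   logits = model(torch.randn(1, 3, 224, 224))  # -> [1, 1000]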