CoaT merge. A bit of formatting cleanup, fix torchscript compatibility (for the non-features path), remove the einops/einsum dependency, and add pretrained weight hub (URL) support.

pull/554/head
Ross Wightman 3 years ago
parent 026430c083
commit 76739a7589

@@ -1,20 +1,25 @@
"""
CoaT architecture.

+Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
+
+Official CoaT code at: https://github.com/mlpc-ucsd/CoaT

Modified from timm/models/vision_transformer.py
"""
+from typing import Tuple, Dict, Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
-from einops import rearrange
from functools import partial
-from torch import nn, einsum
+from torch import nn

__all__ = [
    "coat_tiny",
@@ -36,6 +41,19 @@ def _cfg_coat(url='', **kwargs):
    }

+default_cfgs = {
+    'coat_tiny': _cfg_coat(),
+    'coat_mini': _cfg_coat(),
+    'coat_lite_tiny': _cfg_coat(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
+    ),
+    'coat_lite_mini': _cfg_coat(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
+    ),
+    'coat_lite_small': _cfg_coat(),
+}


class Mlp(nn.Module):
    """ Feed-forward network (FFN, a.k.a. MLP) class. """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
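For context, a minimal usage sketch (not part of the diff): with the URLs registered in default_cfgs above, the two CoaT-Lite entrypoints can pull their checkpoints from the release hub at creation time. This assumes a timm build that includes this commit and the default 224x224, 1000-class config.

import torch
import timm

# pretrained=True resolves the URL stored in default_cfgs['coat_lite_mini']
# (coat_lite_mini-d7842000.pth) and loads it via load_pretrained().
model = timm.create_model('coat_lite_mini', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])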
@@ -64,14 +82,17 @@ class ConvRelPosEnc(nn.Module):
            Ch: Channels per head.
            h: Number of heads.
            window: Window size(s) in convolutional relative positional encoding. It can have two forms:
-                1. An integer of window size, which assigns all attention heads with the same window size in ConvRelPosEnc.
-                2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
-                   It will apply different window size to the attention head splits.
+                1. An integer of window size, which assigns all attention heads with the same window size
+                   in ConvRelPosEnc.
+                2. A dict mapping window size to #attention head splits (
+                   e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
+                   It will apply different window size to the attention head splits.
        """
        super().__init__()

        if isinstance(window, int):
-            window = {window: h}  # Set the same window size for all attention heads.
+            # Set the same window size for all attention heads.
+            window = {window: h}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
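To make the two window forms concrete, here is a tiny standalone illustration (values assumed, mirroring the default crpe_window used later in this file) of how a window dict is turned into per-window head and channel splits for the depthwise convolutions:

# Illustration only: how a window dict splits heads across depthwise convs.
Ch, h = 24, 8                                    # channels per head, number of heads
window = {3: 2, 5: 3, 7: 3}                      # kernel size -> number of heads using it
head_splits = list(window.values())              # [2, 3, 3], sums to h
channel_splits = [x * Ch for x in head_splits]   # [48, 72, 72] channels per conv group
assert sum(head_splits) == h
print(head_splits, channel_splits)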
@@ -81,8 +102,10 @@ class ConvRelPosEnc(nn.Module):
        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
-            dilation = 1  # Use dilation=1 at default.
-            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2  # Determine padding size. Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
+            dilation = 1
+            # Determine padding size.
+            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
+            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
@@ -93,25 +116,25 @@ class ConvRelPosEnc(nn.Module):
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x*Ch for x in self.head_splits]

-    def forward(self, q, v, size):
+    def forward(self, q, v, size: Tuple[int, int]):
        B, h, N, Ch = q.shape
        H, W = size
        assert N == 1 + H * W

        # Convolutional relative position encoding.
-        q_img = q[:,:,1:,:]  # Shape: [B, h, H*W, Ch].
-        v_img = v[:,:,1:,:]  # Shape: [B, h, H*W, Ch].
+        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
+        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]

-        v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)  # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
-        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels.
-        conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]
+        v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
+        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels
+        conv_v_img_list = []
+        for i, conv in enumerate(self.conv_list):
+            conv_v_img_list.append(conv(v_img_list[i]))
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
-        conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h)  # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
+        conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)

-        EV_hat_img = q_img * conv_v_img
-        zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device)
-        EV_hat = torch.cat((zero, EV_hat_img), dim=2)  # Shape: [B, h, N, Ch].
+        EV_hat = q_img * conv_v_img
+        EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0))  # [B, h, N, Ch]
        return EV_hat
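As a sanity check on dropping einops, here is a small standalone sketch (not part of the commit) showing that the transpose/reshape pair above matches the removed rearrange patterns; einops is imported only for the comparison, and the tensor sizes are arbitrary.

import torch
from einops import rearrange  # only needed for this comparison

B, h, H, W, Ch = 2, 8, 7, 7, 19
v_img = torch.randn(B, h, H * W, Ch)

# Old: 'B h (H W) Ch -> B (h Ch) H W'  vs  new: transpose + reshape
old_fwd = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W)
new_fwd = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
assert torch.equal(old_fwd, new_fwd)

# Old: 'B (h Ch) H W -> B h (H W) Ch'  vs  new: reshape + transpose
old_bwd = rearrange(old_fwd, 'B (h Ch) H W -> B h (H W) Ch', h=h)
new_bwd = new_fwd.reshape(B, h, Ch, H * W).transpose(-1, -2)
assert torch.equal(old_bwd, new_bwd)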
@@ -124,37 +147,37 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape

        # Generate Q, K, V.
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # Shape: [3, B, h, N, Ch].
-        q, k, v = qkv[0], qkv[1], qkv[2]  # Shape: [B, h, N, Ch].
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]  # [B, h, N, Ch]

        # Factorized attention.
-        k_softmax = k.softmax(dim=2)  # Softmax on dim N.
-        k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v)  # Shape: [B, h, Ch, Ch].
-        factor_att = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v)  # Shape: [B, h, N, Ch].
+        k_softmax = k.softmax(dim=2)
+        factor_att = k_softmax.transpose(-1, -2) @ v
+        factor_att = q @ factor_att

        # Convolutional relative position encoding.
-        crpe = self.crpe(q, v, size=size)  # Shape: [B, h, N, Ch].
+        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]

        # Merge and reshape.
        x = self.scale * factor_att + crpe
-        x = x.transpose(1, 2).reshape(B, N, C)  # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].
+        x = x.transpose(1, 2).reshape(B, N, C)  # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)

-        return x  # Shape: [B, N, C].
+        return x
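Likewise, a quick standalone check (an illustration, not from the diff) that the plain-matmul factorized attention above reproduces the removed einsum formulation up to floating-point tolerance; torch.einsum stands in here for the old `from torch import einsum`.

import torch

B, h, N, Ch = 2, 8, 50, 24
q, k, v = (torch.randn(B, h, N, Ch) for _ in range(3))
k_softmax = k.softmax(dim=2)

# Removed einsum path: (k_softmax^T @ v), then q @ (...)
kT_v = torch.einsum('b h n k, b h n v -> b h k v', k_softmax, v)   # [B, h, Ch, Ch]
old = torch.einsum('b h n k, b h k v -> b h n v', q, kT_v)         # [B, h, N, Ch]

# New matmul path from the diff above
new = q @ (k_softmax.transpose(-1, -2) @ v)

assert torch.allclose(old, new, atol=1e-5)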

class ConvPosEnc(nn.Module):
@@ -165,13 +188,13 @@ class ConvPosEnc(nn.Module):
        super(ConvPosEnc, self).__init__()
        self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim)

-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        assert N == 1 + H * W

        # Extract CLS token and image tokens.
-        cls_token, img_tokens = x[:, :1], x[:, 1:]  # Shape: [B, 1, C], [B, H*W, C].
+        cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]

        # Depthwise convolution.
        feat = img_tokens.transpose(1, 2).view(B, C, H, W)
@@ -206,11 +229,11 @@ class SerialBlock(nn.Module):
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
        # Conv-Attention.
-        x = self.cpe(x, size)  # Apply convolutional position encoding.
+        x = self.cpe(x, size)
        cur = self.norm1(x)
-        cur = self.factoratt_crpe(cur, size)  # Apply factorized attention and convolutional relative position encoding.
+        cur = self.factoratt_crpe(cur, size)
        x = x + self.drop_path(cur)

        # MLP.
@@ -252,10 +275,12 @@ class ParallelBlock(nn.Module):
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
-        assert dims[1] == dims[2] == dims[3]  # In parallel block, we assume dimensions are the same and share the linear transformation.
+        # In parallel block, we assume dimensions are the same and share the linear transformation.
+        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
-        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
+            in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def upsample(self, x, factor, size):
        """ Feature map up-sampling. """
@@ -271,7 +296,7 @@ class ParallelBlock(nn.Module):
        H, W = size
        assert N == 1 + H * W

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]
        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
@@ -293,18 +318,18 @@ class ParallelBlock(nn.Module):
        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
-        cur2 = self.factoratt_crpe2(cur2, size=(H2,W2))
-        cur3 = self.factoratt_crpe3(cur3, size=(H3,W3))
-        cur4 = self.factoratt_crpe4(cur4, size=(H4,W4))
-        upsample3_2 = self.upsample(cur3, factor=2, size=(H3,W3))
-        upsample4_3 = self.upsample(cur4, factor=2, size=(H4,W4))
-        upsample4_2 = self.upsample(cur4, factor=4, size=(H4,W4))
-        downsample2_3 = self.downsample(cur2, factor=2, size=(H2,W2))
-        downsample3_4 = self.downsample(cur3, factor=2, size=(H3,W3))
-        downsample2_4 = self.downsample(cur2, factor=4, size=(H2,W2))
+        cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
+        cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
+        cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
+        upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3))
+        upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4))
+        upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4))
+        downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2))
+        downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3))
+        downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2))
        cur2 = cur2 + upsample3_2 + upsample4_2
        cur3 = cur3 + upsample4_3 + downsample2_3
        cur4 = cur4 + downsample3_4 + downsample2_4
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)
@@ -334,8 +359,10 @@ class PatchEmbed(nn.Module):
        self.patch_size = patch_size
        assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
            f"img_size {img_size} should be divided by patch_size {patch_size}."
-        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]  # Note: self.H, self.W and self.num_patches are not used
-        self.num_patches = self.H * self.W  # since the image size may change on the fly.
+        # Note: self.H, self.W and self.num_patches are not used
+        # since the image size may change on the fly.
+        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
+        self.num_patches = self.H * self.W

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
@@ -355,18 +382,22 @@ class CoaT(nn.Module):
            serial_depths=[0, 0, 0, 0], parallel_depth=0,
            num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
            drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
-            return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3},
-            **kwargs):
+            return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
        super().__init__()
+        crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.num_classes = num_classes

        # Patch embeddings.
-        self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
-        self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
-        self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
-        self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
+        self.patch_embed1 = PatchEmbed(
+            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
+        self.patch_embed2 = PatchEmbed(
+            img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
+        self.patch_embed3 = PatchEmbed(
+            img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
+        self.patch_embed4 = PatchEmbed(
+            img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])

        # Class tokens.
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
@@ -442,6 +473,8 @@ class CoaT(nn.Module):
                )
                for _ in range(parallel_depth)]
            )
+        else:
+            self.parallel_blocks = None

        # Classification head(s).
        if not self.return_interm_layers:
@@ -450,12 +483,14 @@ class CoaT(nn.Module):
            self.norm3 = norm_layer(embed_dims[2])
            self.norm4 = norm_layer(embed_dims[3])

-            if self.parallel_depth > 0:  # CoaT series: Aggregate features of last three scales for classification.
+            if self.parallel_depth > 0:
+                # CoaT series: Aggregate features of last three scales for classification.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
                self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
                self.head = nn.Linear(embed_dims[3], num_classes)
            else:
-                self.head = nn.Linear(embed_dims[3], num_classes)  # CoaT-Lite series: Use feature of last scale for classification.
+                # CoaT-Lite series: Use feature of last scale for classification.
+                self.head = nn.Linear(embed_dims[3], num_classes)

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=.02)
@@ -530,8 +565,9 @@ class CoaT(nn.Module):
        x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()

        # Only serial blocks: Early return.
-        if self.parallel_depth == 0:
-            if self.return_interm_layers:  # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
+        if self.parallel_blocks is None:
+            if not torch.jit.is_scripting() and self.return_interm_layers:
+                # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
                feat_out = {}
                if 'x1_nocls' in self.out_features:
                    feat_out['x1_nocls'] = x1_nocls
@@ -542,7 +578,8 @@ class CoaT(nn.Module):
                if 'x4_nocls' in self.out_features:
                    feat_out['x4_nocls'] = x4_nocls
                return feat_out
-            else:  # Return features for classification.
+            else:
+                # Return features for classification.
                x4 = self.norm4(x4)
                x4_cls = x4[:, 0]
                return x4_cls
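The torch.jit.is_scripting() guards above hide the feature-dict branches from the TorchScript compiler, so only the tensor-returning classification path gets compiled. A hedged sketch of the intended effect (an assumption based on the commit message; it presumes a timm build with this commit and that scripting succeeds on the default, non-features config):

import torch
import timm

model = timm.create_model('coat_lite_tiny').eval()
scripted = torch.jit.script(model)  # assumed to succeed on the non-features path

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    assert torch.allclose(model(x), scripted(x), atol=1e-5)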
@@ -551,7 +588,8 @@ class CoaT(nn.Module):
        for blk in self.parallel_blocks:
            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])

-        if self.return_interm_layers:  # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
+        if not torch.jit.is_scripting() and self.return_interm_layers:
+            # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                x1_nocls = self.remove_cls(x1)
@@ -574,50 +612,70 @@ class CoaT(nn.Module):
            x2 = self.norm2(x2)
            x3 = self.norm3(x3)
            x4 = self.norm4(x4)
-            x2_cls = x2[:, :1]  # Shape: [B, 1, C].
+            x2_cls = x2[:, :1]  # [B, 1, C]
            x3_cls = x3[:, :1]
            x4_cls = x4[:, :1]
-            merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)  # Shape: [B, 3, C].
-            merged_cls = self.aggregate(merged_cls).squeeze(dim=1)  # Shape: [B, C].
+            merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)  # [B, 3, C]
+            merged_cls = self.aggregate(merged_cls).squeeze(dim=1)  # Shape: [B, C]
            return merged_cls

    def forward(self, x):
-        if self.return_interm_layers:  # Return intermediate features (for down-stream tasks).
+        if self.return_interm_layers:
+            # Return intermediate features (for down-stream tasks).
            return self.forward_features(x)
-        else:  # Return features for classification.
+        else:
+            # Return features for classification.
            x = self.forward_features(x)
            x = self.head(x)
            return x
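For clarity on the aggregation step above, a tiny standalone illustration (dimensions assumed, C=216 as in coat_mini) of how the Conv1d head mixes the three CLS tokens across scales:

import torch
import torch.nn as nn

B, C = 2, 216
merged_cls = torch.randn(B, 3, C)               # CLS tokens from scales 2, 3, 4
aggregate = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
out = aggregate(merged_cls).squeeze(dim=1)      # [B, 3, C] -> [B, 1, C] -> [B, C]
print(out.shape)  # torch.Size([2, 216])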
-# CoaT.
@register_model
-def coat_tiny(**kwargs):
-    model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model.default_cfg = _cfg_coat()
+def coat_tiny(pretrained=False, **kwargs):
+    model = CoaT(
+        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
+        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
+    model.default_cfg = default_cfgs['coat_tiny']
    return model


@register_model
-def coat_mini(**kwargs):
-    model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
-    model.default_cfg = _cfg_coat()
+def coat_mini(pretrained=False, **kwargs):
+    model = CoaT(
+        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
+        num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
+    model.default_cfg = default_cfgs['coat_mini']
    return model
-# CoaT-Lite.
@register_model
-def coat_lite_tiny(**kwargs):
-    model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model.default_cfg = _cfg_coat()
+def coat_lite_tiny(pretrained=False, **kwargs):
+    model = CoaT(
+        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
+        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+    # FIXME use builder
+    model.default_cfg = default_cfgs['coat_lite_tiny']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
-def coat_lite_mini(**kwargs):
-    model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model.default_cfg = _cfg_coat()
+def coat_lite_mini(pretrained=False, **kwargs):
+    model = CoaT(
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
+        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+    # FIXME use builder
+    model.default_cfg = default_cfgs['coat_lite_mini']
+    if pretrained:
+        load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
    return model
@register_model
-def coat_lite_small(**kwargs):
-    model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
-    model.default_cfg = _cfg_coat()
+def coat_lite_small(pretrained=False, **kwargs):
+    model = CoaT(
+        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
+        num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
+    model.default_cfg = default_cfgs['coat_lite_small']
    return model
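Finally, a short sketch of how the @register_model entrypoints surface through the timm registry. The names are as registered above; the exact list output and the 'url' key layout of _cfg_coat are assumptions based on the usual timm _cfg conventions.

import timm

# All five CoaT variants are registered; each entrypoint attaches its default_cfg by hand
# ("FIXME use builder"), so the hub URL can be inspected directly on the model.
print(timm.list_models('coat*'))
# e.g. ['coat_lite_mini', 'coat_lite_small', 'coat_lite_tiny', 'coat_mini', 'coat_tiny']

model = timm.create_model('coat_lite_tiny')
print(model.default_cfg.get('url', ''))  # the coat_lite_tiny-461b07a7.pth release URL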