CoaT merge. Bit of formatting, fix torchscript (for non features), remove einops/einsum dep, add pretrained weight hub (url) support.

pull/554/head
Ross Wightman 3 years ago
parent 026430c083
commit 76739a7589

@@ -1,20 +1,25 @@
"""
CoaT architecture.
Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from typing import Tuple, Dict, Any, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from einops import rearrange
from functools import partial
from torch import nn, einsum
from torch import nn
__all__ = [
"coat_tiny",
@@ -36,6 +41,19 @@ def _cfg_coat(url='', **kwargs):
}
default_cfgs = {
'coat_tiny': _cfg_coat(),
'coat_mini': _cfg_coat(),
'coat_lite_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
),
'coat_lite_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
),
'coat_lite_small': _cfg_coat(),
}
class Mlp(nn.Module):
""" Feed-forward network (FFN, a.k.a. MLP) class. """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
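With release URLs now present in default_cfgs, the two pretrained CoaT-Lite models can be pulled straight from the hub. A minimal usage sketch, assuming a timm build that includes this commit (model names as registered below):

import timm
import torch

# coat_lite_tiny / coat_lite_mini carry a url in default_cfgs, so pretrained
# weights can be downloaded; the other variants still have an empty url.
model = timm.create_model('coat_lite_mini', pretrained=True).eval()
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000]) with the default 1000 classes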
@@ -64,14 +82,17 @@ class ConvRelPosEnc(nn.Module):
Ch: Channels per head.
h: Number of heads.
window: Window size(s) in convolutional relative positional encoding. It can have two forms:
1. An integer of window size, which assigns all attention heads with the same window size in ConvRelPosEnc.
2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
It will apply different window size to the attention head splits.
1. An integer of window size, which assigns all attention heads with the same window
size in ConvRelPosEnc.
2. A dict mapping window size to #attention head splits (
e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
It will apply different window size to the attention head splits.
"""
super().__init__()
if isinstance(window, int):
window = {window: h} # Set the same window size for all attention heads.
# Set the same window size for all attention heads.
window = {window: h}
self.window = window
elif isinstance(window, dict):
self.window = window
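For the default crpe_window of {3: 2, 5: 3, 7: 3} used by the CoaT models below, the dict form maps each window size to the number of attention heads that use it (2 + 3 + 3 = 8 heads). A small standalone sketch of how __init__ turns that dict into per-group channel splits and paddings (values here are illustrative):

Ch = 8                              # channels per head (example value)
window = {3: 2, 5: 3, 7: 3}         # crpe_window: window size -> head split

head_splits = list(window.values())                            # [2, 3, 3]
channel_splits = [s * Ch for s in head_splits]                 # [16, 24, 24]
paddings = {w: (w + (w - 1) * (1 - 1)) // 2 for w in window}   # {3: 1, 5: 2, 7: 3}
print(head_splits, channel_splits, paddings)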
@@ -81,8 +102,10 @@ class ConvRelPosEnc(nn.Module):
self.conv_list = nn.ModuleList()
self.head_splits = []
for cur_window, cur_head_split in window.items():
dilation = 1 # Use dilation=1 at default.
padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 # Determine padding size. Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
dilation = 1
# Determine padding size.
# Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch,
kernel_size=(cur_window, cur_window),
padding=(padding_size, padding_size),
@@ -93,25 +116,25 @@ class ConvRelPosEnc(nn.Module):
self.head_splits.append(cur_head_split)
self.channel_splits = [x*Ch for x in self.head_splits]
def forward(self, q, v, size):
def forward(self, q, v, size: Tuple[int, int]):
B, h, N, Ch = q.shape
H, W = size
assert N == 1 + H * W
# Convolutional relative position encoding.
q_img = q[:,:,1:,:] # Shape: [B, h, H*W, Ch].
v_img = v[:,:,1:,:] # Shape: [B, h, H*W, Ch].
v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W) # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W].
v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels.
conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)]
q_img = q[:, :, 1:, :] # [B, h, H*W, Ch]
v_img = v[:, :, 1:, :] # [B, h, H*W, Ch]
v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W)
v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels
conv_v_img_list = []
for i, conv in enumerate(self.conv_list):
conv_v_img_list.append(conv(v_img_list[i]))
conv_v_img = torch.cat(conv_v_img_list, dim=1)
conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h) # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch].
EV_hat_img = q_img * conv_v_img
zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device)
EV_hat = torch.cat((zero, EV_hat_img), dim=2) # Shape: [B, h, N, Ch].
conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2)
EV_hat = q_img * conv_v_img
EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch].
return EV_hat
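The forward hunk above removes the two rearrange calls in favor of transpose/reshape and swaps the zero-token concatenation for F.pad. A quick self-contained check that the new ops match the old ones (random tensors, arbitrary small sizes):

import torch
import torch.nn.functional as F

B, h, Ch, H, W = 2, 4, 8, 7, 7
x = torch.randn(B, h, H * W, Ch)

# 'B h (H W) Ch -> B (h Ch) H W' and back again, written with transpose/reshape.
y = x.transpose(-1, -2).reshape(B, h * Ch, H, W)
x_back = y.reshape(B, h, Ch, H * W).transpose(-1, -2)
print(torch.equal(x, x_back))  # True: the two reshapes are exact inverses

# Prepending the CLS slot: cat with zeros vs. padding the token dimension.
ev = torch.randn(B, h, H * W, Ch)
zero = torch.zeros(B, h, 1, Ch, dtype=ev.dtype, device=ev.device)
print(torch.equal(torch.cat((zero, ev), dim=2), F.pad(ev, (0, 0, 1, 0, 0, 0))))  # True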
@@ -124,37 +147,37 @@ class FactorAtt_ConvRelPosEnc(nn.Module):
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
# Shared convolutional relative position encoding.
self.crpe = shared_crpe
def forward(self, x, size):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
# Generate Q, K, V.
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # Shape: [3, B, h, N, Ch].
q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch].
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch]
# Factorized attention.
k_softmax = k.softmax(dim=2) # Softmax on dim N.
k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v) # Shape: [B, h, Ch, Ch].
factor_att = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v) # Shape: [B, h, N, Ch].
k_softmax = k.softmax(dim=2)
factor_att = k_softmax.transpose(-1, -2) @ v
factor_att = q @ factor_att
# Convolutional relative position encoding.
crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch].
crpe = self.crpe(q, v, size=size) # [B, h, N, Ch]
# Merge and reshape.
x = self.scale * factor_att + crpe
x = x.transpose(1, 2).reshape(B, N, C) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C].
x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
# Output projection.
x = self.proj(x)
x = self.proj_drop(x)
return x # Shape: [B, N, C].
return x
class ConvPosEnc(nn.Module):
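The factorized-attention hunk above replaces the two torch.einsum contractions with plain matrix multiplies, removing the last einsum use (einops rearrange was already dropped in ConvRelPosEnc). A minimal numeric check of the equivalence (random tensors, example shapes):

import torch

B, h, N, Ch = 2, 4, 50, 8
q = torch.randn(B, h, N, Ch)
k_softmax = torch.randn(B, h, N, Ch).softmax(dim=2)
v = torch.randn(B, h, N, Ch)

# Old formulation: two einsum contractions over the token dimension n.
old = torch.einsum('b h n k, b h k v -> b h n v',
                   q, torch.einsum('b h n k, b h n v -> b h k v', k_softmax, v))
# New formulation: the same contractions written as transposed matmuls.
new = q @ (k_softmax.transpose(-1, -2) @ v)
print(torch.allclose(old, new, atol=1e-5))  # True up to float rounding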
@@ -165,13 +188,13 @@ class ConvPosEnc(nn.Module):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim)
def forward(self, x, size):
def forward(self, x, size: Tuple[int, int]):
B, N, C = x.shape
H, W = size
assert N == 1 + H * W
# Extract CLS token and image tokens.
cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
# Depthwise convolution.
feat = img_tokens.transpose(1, 2).view(B, C, H, W)
@@ -206,11 +229,11 @@ class SerialBlock(nn.Module):
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, size):
def forward(self, x, size: Tuple[int, int]):
# Conv-Attention.
x = self.cpe(x, size) # Apply convolutional position encoding.
x = self.cpe(x, size)
cur = self.norm1(x)
cur = self.factoratt_crpe(cur, size) # Apply factorized attention and convolutional relative position encoding.
cur = self.factoratt_crpe(cur, size)
x = x + self.drop_path(cur)
# MLP.
@@ -252,10 +275,12 @@ class ParallelBlock(nn.Module):
self.norm22 = norm_layer(dims[1])
self.norm23 = norm_layer(dims[2])
self.norm24 = norm_layer(dims[3])
assert dims[1] == dims[2] == dims[3] # In parallel block, we assume dimensions are the same and share the linear transformation.
# In parallel block, we assume dimensions are the same and share the linear transformation.
assert dims[1] == dims[2] == dims[3]
assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def upsample(self, x, factor, size):
""" Feature map up-sampling. """
@@ -271,7 +296,7 @@ class ParallelBlock(nn.Module):
H, W = size
assert N == 1 + H * W
cls_token = x[:, :1, :]
cls_token = x[:, :1, :]
img_tokens = x[:, 1:, :]
img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
@@ -293,18 +318,18 @@ class ParallelBlock(nn.Module):
cur2 = self.norm12(x2)
cur3 = self.norm13(x3)
cur4 = self.norm14(x4)
cur2 = self.factoratt_crpe2(cur2, size=(H2,W2))
cur3 = self.factoratt_crpe3(cur3, size=(H3,W3))
cur4 = self.factoratt_crpe4(cur4, size=(H4,W4))
upsample3_2 = self.upsample(cur3, factor=2, size=(H3,W3))
upsample4_3 = self.upsample(cur4, factor=2, size=(H4,W4))
upsample4_2 = self.upsample(cur4, factor=4, size=(H4,W4))
downsample2_3 = self.downsample(cur2, factor=2, size=(H2,W2))
downsample3_4 = self.downsample(cur3, factor=2, size=(H3,W3))
downsample2_4 = self.downsample(cur2, factor=4, size=(H2,W2))
cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4
cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3))
upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4))
upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4))
downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2))
downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3))
downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2))
cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4
x2 = x2 + self.drop_path(cur2)
x3 = x3 + self.drop_path(cur3)
x4 = x4 + self.drop_path(cur4)
@@ -334,8 +359,10 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
f"img_size {img_size} should be divided by patch_size {patch_size}."
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # Note: self.H, self.W and self.num_patches are not used
self.num_patches = self.H * self.W # since the image size may change on the fly.
# Note: self.H, self.W and self.num_patches are not used
# since the image size may change on the fly.
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
@@ -355,18 +382,22 @@ class CoaT(nn.Module):
serial_depths=[0, 0, 0, 0], parallel_depth=0,
num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3},
**kwargs):
return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
super().__init__()
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers
self.out_features = out_features
self.num_classes = num_classes
# Patch embeddings.
self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0])
self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3])
# Class tokens.
self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
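The signature change above moves crpe_window's default from a dict literal to None and resolves it inside __init__ (crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}), sidestepping Python's shared-mutable-default pitfall. A tiny illustration with a hypothetical function (not from this file):

def bad(cfg={3: 2}):
    cfg['extra'] = True      # mutates the single dict shared by every call
    return cfg

def good(cfg=None):
    cfg = cfg or {3: 2}      # fresh dict per call, same pattern as CoaT.__init__
    cfg['extra'] = True
    return cfg

print(bad() is bad())        # True: same object returned both times
print(good() is good())      # False: independent dicts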
@@ -442,6 +473,8 @@ class CoaT(nn.Module):
)
for _ in range(parallel_depth)]
)
else:
self.parallel_blocks = None
# Classification head(s).
if not self.return_interm_layers:
@@ -450,12 +483,14 @@ class CoaT(nn.Module):
self.norm3 = norm_layer(embed_dims[2])
self.norm4 = norm_layer(embed_dims[3])
if self.parallel_depth > 0: # CoaT series: Aggregate features of last three scales for classification.
if self.parallel_depth > 0:
# CoaT series: Aggregate features of last three scales for classification.
assert embed_dims[1] == embed_dims[2] == embed_dims[3]
self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
self.head = nn.Linear(embed_dims[3], num_classes)
else:
self.head = nn.Linear(embed_dims[3], num_classes) # CoaT-Lite series: Use feature of last scale for classification.
# CoaT-Lite series: Use feature of last scale for classification.
self.head = nn.Linear(embed_dims[3], num_classes)
# Initialize weights.
trunc_normal_(self.cls_token1, std=.02)
@@ -530,8 +565,9 @@ class CoaT(nn.Module):
x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
# Only serial blocks: Early return.
if self.parallel_depth == 0:
if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
if self.parallel_blocks is None:
if not torch.jit.is_scripting() and self.return_interm_layers:
# Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
feat_out = {}
if 'x1_nocls' in self.out_features:
feat_out['x1_nocls'] = x1_nocls
@@ -542,7 +578,8 @@ class CoaT(nn.Module):
if 'x4_nocls' in self.out_features:
feat_out['x4_nocls'] = x4_nocls
return feat_out
else: # Return features for classification.
else:
# Return features for classification.
x4 = self.norm4(x4)
x4_cls = x4[:, 0]
return x4_cls
@@ -551,7 +588,8 @@ class CoaT(nn.Module):
for blk in self.parallel_blocks:
x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
if not torch.jit.is_scripting() and self.return_interm_layers:
# Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
feat_out = {}
if 'x1_nocls' in self.out_features:
x1_nocls = self.remove_cls(x1)
@@ -574,50 +612,70 @@ class CoaT(nn.Module):
x2 = self.norm2(x2)
x3 = self.norm3(x3)
x4 = self.norm4(x4)
x2_cls = x2[:, :1] # Shape: [B, 1, C].
x2_cls = x2[:, :1] # [B, 1, C]
x3_cls = x3[:, :1]
x4_cls = x4[:, :1]
merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # Shape: [B, 3, C].
merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C].
merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C]
merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]
return merged_cls
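For reference, the non-lite CoaT head kept above merges the three CLS tokens with a kernel-size-1 Conv1d over the scale dimension, i.e. a learned weighted combination across scales. A tiny shape sketch with example values:

import torch
import torch.nn as nn

B, C = 2, 216
x2_cls, x3_cls, x4_cls = (torch.randn(B, 1, C) for _ in range(3))
merged = torch.cat((x2_cls, x3_cls, x4_cls), dim=1)              # [B, 3, C]
aggregate = nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
print(aggregate(merged).squeeze(dim=1).shape)                    # torch.Size([2, 216])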
def forward(self, x):
if self.return_interm_layers: # Return intermediate features (for down-stream tasks).
if self.return_interm_layers:
# Return intermediate features (for down-stream tasks).
return self.forward_features(x)
else: # Return features for classification.
else:
# Return features for classification.
x = self.forward_features(x)
x = self.head(x)
return x
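With the Tuple[int, int] size annotations and the torch.jit.is_scripting() guards added above, the default classification path (return_interm_layers=False) is the one this commit targets for torchscript. A minimal sketch of the intended usage, assuming a timm build that includes this commit:

import torch
import timm

model = timm.create_model('coat_lite_mini', pretrained=False).eval()
scripted = torch.jit.script(model)        # the feature-dict branch is skipped when scripting
out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)                          # expected: torch.Size([1, 1000])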
# CoaT.
@register_model
def coat_tiny(**kwargs):
model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = _cfg_coat()
def coat_tiny(pretrained=False, **kwargs):
model = CoaT(
patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_tiny']
return model
@register_model
def coat_mini(**kwargs):
model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = _cfg_coat()
def coat_mini(pretrained=False, **kwargs):
model = CoaT(
patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_mini']
return model
# CoaT-Lite.
@register_model
def coat_lite_tiny(**kwargs):
model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = _cfg_coat()
def coat_lite_tiny(pretrained=False, **kwargs):
model = CoaT(
patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_tiny']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def coat_lite_mini(**kwargs):
model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = _cfg_coat()
def coat_lite_mini(pretrained=False, **kwargs):
model = CoaT(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_mini']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
return model
@register_model
def coat_lite_small(**kwargs):
model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = _cfg_coat()
def coat_lite_small(pretrained=False, **kwargs):
model = CoaT(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_lite_small']
return model