From 76739a7589ebde1fc6b015e5f9f3e2dc8a73299e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 28 Apr 2021 16:31:35 -0700 Subject: [PATCH] CoaT merge. Bit of formatting, fix torchscript (for non features), remove einops/einsum dep, add pretrained weight hub (url) support. --- timm/models/coat.py | 230 +++++++++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 86 deletions(-) diff --git a/timm/models/coat.py b/timm/models/coat.py index 40dd5e8c..99357fda 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -1,20 +1,25 @@ """ CoaT architecture. +Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399 + +Official CoaT code at: https://github.com/mlpc-ucsd/CoaT + Modified from timm/models/vision_transformer.py """ +from typing import Tuple, Dict, Any, Optional import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.registry import register_model -from einops import rearrange from functools import partial -from torch import nn, einsum +from torch import nn __all__ = [ "coat_tiny", @@ -36,6 +41,19 @@ def _cfg_coat(url='', **kwargs): } +default_cfgs = { + 'coat_tiny': _cfg_coat(), + 'coat_mini': _cfg_coat(), + 'coat_lite_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' + ), + 'coat_lite_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' + ), + 'coat_lite_small': _cfg_coat(), +} + + class Mlp(nn.Module): """ Feed-forward network (FFN, a.k.a. MLP) class. """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): @@ -64,14 +82,17 @@ class ConvRelPosEnc(nn.Module): Ch: Channels per head. h: Number of heads. window: Window size(s) in convolutional relative positional encoding. It can have two forms: - 1. An integer of window size, which assigns all attention heads with the same window size in ConvRelPosEnc. - 2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) - It will apply different window size to the attention head splits. + 1. An integer of window size, which assigns all attention heads with the same window s + size in ConvRelPosEnc. + 2. A dict mapping window size to #attention head splits ( + e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) + It will apply different window size to the attention head splits. """ super().__init__() if isinstance(window, int): - window = {window: h} # Set the same window size for all attention heads. + # Set the same window size for all attention heads. + window = {window: h} self.window = window elif isinstance(window, dict): self.window = window @@ -81,8 +102,10 @@ class ConvRelPosEnc(nn.Module): self.conv_list = nn.ModuleList() self.head_splits = [] for cur_window, cur_head_split in window.items(): - dilation = 1 # Use dilation=1 at default. - padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 # Determine padding size. Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + dilation = 1 + # Determine padding size. + # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch, kernel_size=(cur_window, cur_window), padding=(padding_size, padding_size), @@ -93,25 +116,25 @@ class ConvRelPosEnc(nn.Module): self.head_splits.append(cur_head_split) self.channel_splits = [x*Ch for x in self.head_splits] - def forward(self, q, v, size): + def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size assert N == 1 + H * W # Convolutional relative position encoding. - q_img = q[:,:,1:,:] # Shape: [B, h, H*W, Ch]. - v_img = v[:,:,1:,:] # Shape: [B, h, H*W, Ch]. - - v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W) # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W]. - v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels. - conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)] + q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] + v_img = v[:, :, 1:, :] # [B, h, H*W, Ch] + + v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W) + v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels + conv_v_img_list = [] + for i, conv in enumerate(self.conv_list): + conv_v_img_list.append(conv(v_img_list[i])) conv_v_img = torch.cat(conv_v_img_list, dim=1) - conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h) # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch]. - - EV_hat_img = q_img * conv_v_img - zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device) - EV_hat = torch.cat((zero, EV_hat_img), dim=2) # Shape: [B, h, N, Ch]. + conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2) + EV_hat = q_img * conv_v_img + EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch]. return EV_hat @@ -124,37 +147,37 @@ class FactorAtt_ConvRelPosEnc(nn.Module): self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. + self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) # Shared convolutional relative position encoding. self.crpe = shared_crpe - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape # Generate Q, K, V. - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # Shape: [3, B, h, N, Ch]. - q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch]. + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch] # Factorized attention. - k_softmax = k.softmax(dim=2) # Softmax on dim N. - k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v) # Shape: [B, h, Ch, Ch]. - factor_att = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v) # Shape: [B, h, N, Ch]. + k_softmax = k.softmax(dim=2) + factor_att = k_softmax.transpose(-1, -2) @ v + factor_att = q @ factor_att # Convolutional relative position encoding. - crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch]. + crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] # Merge and reshape. x = self.scale * factor_att + crpe - x = x.transpose(1, 2).reshape(B, N, C) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]. + x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C] # Output projection. x = self.proj(x) x = self.proj_drop(x) - return x # Shape: [B, N, C]. + return x class ConvPosEnc(nn.Module): @@ -165,13 +188,13 @@ class ConvPosEnc(nn.Module): super(ConvPosEnc, self).__init__() self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size assert N == 1 + H * W # Extract CLS token and image tokens. - cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C]. + cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] # Depthwise convolution. feat = img_tokens.transpose(1, 2).view(B, C, H, W) @@ -206,11 +229,11 @@ class SerialBlock(nn.Module): mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): # Conv-Attention. - x = self.cpe(x, size) # Apply convolutional position encoding. + x = self.cpe(x, size) cur = self.norm1(x) - cur = self.factoratt_crpe(cur, size) # Apply factorized attention and convolutional relative position encoding. + cur = self.factoratt_crpe(cur, size) x = x + self.drop_path(cur) # MLP. @@ -252,10 +275,12 @@ class ParallelBlock(nn.Module): self.norm22 = norm_layer(dims[1]) self.norm23 = norm_layer(dims[2]) self.norm24 = norm_layer(dims[3]) - assert dims[1] == dims[2] == dims[3] # In parallel block, we assume dimensions are the same and share the linear transformation. + # In parallel block, we assume dimensions are the same and share the linear transformation. + assert dims[1] == dims[2] == dims[3] assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] mlp_hidden_dim = int(dims[1] * mlp_ratios[1]) - self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp2 = self.mlp3 = self.mlp4 = Mlp( + in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def upsample(self, x, factor, size): """ Feature map up-sampling. """ @@ -271,7 +296,7 @@ class ParallelBlock(nn.Module): H, W = size assert N == 1 + H * W - cls_token = x[:, :1, :] + cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) @@ -293,18 +318,18 @@ class ParallelBlock(nn.Module): cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2,W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3,W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4,W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3,W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4,W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4,W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2,W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3,W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2,W2)) - cur2 = cur2 + upsample3_2 + upsample4_2 - cur3 = cur3 + upsample4_3 + downsample2_3 - cur4 = cur4 + downsample3_4 + downsample2_4 + cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) + cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) + cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) + upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) + upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) + upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) + downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) + downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) + downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = cur2 + upsample3_2 + upsample4_2 + cur3 = cur3 + upsample4_3 + downsample2_3 + cur4 = cur4 + downsample3_4 + downsample2_4 x2 = x2 + self.drop_path(cur2) x3 = x3 + self.drop_path(cur3) x4 = x4 + self.drop_path(cur4) @@ -334,8 +359,10 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ f"img_size {img_size} should be divided by patch_size {patch_size}." - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # Note: self.H, self.W and self.num_patches are not used - self.num_patches = self.H * self.W # since the image size may change on the fly. + # Note: self.H, self.W and self.num_patches are not used + # since the image size may change on the fly. + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(embed_dim) @@ -355,18 +382,22 @@ class CoaT(nn.Module): serial_depths=[0, 0, 0, 0], parallel_depth=0, num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3}, - **kwargs): + return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): super().__init__() + crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers self.out_features = out_features self.num_classes = num_classes # Patch embeddings. - self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) - self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) - self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) - self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.patch_embed2 = PatchEmbed( + img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) + self.patch_embed3 = PatchEmbed( + img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) + self.patch_embed4 = PatchEmbed( + img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) # Class tokens. self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) @@ -442,6 +473,8 @@ class CoaT(nn.Module): ) for _ in range(parallel_depth)] ) + else: + self.parallel_blocks = None # Classification head(s). if not self.return_interm_layers: @@ -450,12 +483,14 @@ class CoaT(nn.Module): self.norm3 = norm_layer(embed_dims[2]) self.norm4 = norm_layer(embed_dims[3]) - if self.parallel_depth > 0: # CoaT series: Aggregate features of last three scales for classification. + if self.parallel_depth > 0: + # CoaT series: Aggregate features of last three scales for classification. assert embed_dims[1] == embed_dims[2] == embed_dims[3] self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) self.head = nn.Linear(embed_dims[3], num_classes) else: - self.head = nn.Linear(embed_dims[3], num_classes) # CoaT-Lite series: Use feature of last scale for classification. + # CoaT-Lite series: Use feature of last scale for classification. + self.head = nn.Linear(embed_dims[3], num_classes) # Initialize weights. trunc_normal_(self.cls_token1, std=.02) @@ -530,8 +565,9 @@ class CoaT(nn.Module): x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() # Only serial blocks: Early return. - if self.parallel_depth == 0: - if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + if self.parallel_blocks is None: + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). feat_out = {} if 'x1_nocls' in self.out_features: feat_out['x1_nocls'] = x1_nocls @@ -542,7 +578,8 @@ class CoaT(nn.Module): if 'x4_nocls' in self.out_features: feat_out['x4_nocls'] = x4_nocls return feat_out - else: # Return features for classification. + else: + # Return features for classification. x4 = self.norm4(x4) x4_cls = x4[:, 0] return x4_cls @@ -551,7 +588,8 @@ class CoaT(nn.Module): for blk in self.parallel_blocks: x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) - if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). feat_out = {} if 'x1_nocls' in self.out_features: x1_nocls = self.remove_cls(x1) @@ -574,50 +612,70 @@ class CoaT(nn.Module): x2 = self.norm2(x2) x3 = self.norm3(x3) x4 = self.norm4(x4) - x2_cls = x2[:, :1] # Shape: [B, 1, C]. + x2_cls = x2[:, :1] # [B, 1, C] x3_cls = x3[:, :1] x4_cls = x4[:, :1] - merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # Shape: [B, 3, C]. - merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]. + merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C] + merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C] return merged_cls def forward(self, x): - if self.return_interm_layers: # Return intermediate features (for down-stream tasks). + if self.return_interm_layers: + # Return intermediate features (for down-stream tasks). return self.forward_features(x) - else: # Return features for classification. + else: + # Return features for classification. x = self.forward_features(x) x = self.head(x) return x -# CoaT. @register_model -def coat_tiny(**kwargs): - model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_tiny(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_tiny'] return model + @register_model -def coat_mini(**kwargs): - model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_mini(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_mini'] return model -# CoaT-Lite. + @register_model -def coat_lite_tiny(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_tiny(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + # FIXME use builder + model.default_cfg = default_cfgs['coat_lite_mini'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model + @register_model -def coat_lite_mini(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_mini(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + # FIXME use builder + model.default_cfg = default_cfgs['coat_lite_mini'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model + @register_model -def coat_lite_small(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_small(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_lite_small'] return model \ No newline at end of file