From c2d5087eae8dfaa70daeba0cd6b29243c31f0bbd Mon Sep 17 00:00:00 2001 From: morizin <50985248+morizin@users.noreply.github.com> Date: Sat, 24 Apr 2021 17:47:57 +0530 Subject: [PATCH 1/6] Add files via upload --- timm/models/coat.py | 623 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 623 insertions(+) create mode 100644 timm/models/coat.py diff --git a/timm/models/coat.py b/timm/models/coat.py new file mode 100644 index 00000000..40dd5e8c --- /dev/null +++ b/timm/models/coat.py @@ -0,0 +1,623 @@ +""" +CoaT architecture. + +Modified from timm/models/vision_transformer.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model + +from einops import rearrange +from functools import partial +from torch import nn, einsum + +__all__ = [ + "coat_tiny", + "coat_mini", + "coat_lite_tiny", + "coat_lite_mini", + "coat_lite_small" +] + + +def _cfg_coat(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +class Mlp(nn.Module): + """ Feed-forward network (FFN, a.k.a. MLP) class. """ + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ConvRelPosEnc(nn.Module): + """ Convolutional relative position encoding. """ + def __init__(self, Ch, h, window): + """ + Initialization. + Ch: Channels per head. + h: Number of heads. + window: Window size(s) in convolutional relative positional encoding. It can have two forms: + 1. An integer of window size, which assigns all attention heads with the same window size in ConvRelPosEnc. + 2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) + It will apply different window size to the attention head splits. + """ + super().__init__() + + if isinstance(window, int): + window = {window: h} # Set the same window size for all attention heads. + self.window = window + elif isinstance(window, dict): + self.window = window + else: + raise ValueError() + + self.conv_list = nn.ModuleList() + self.head_splits = [] + for cur_window, cur_head_split in window.items(): + dilation = 1 # Use dilation=1 at default. + padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 # Determine padding size. Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch, + kernel_size=(cur_window, cur_window), + padding=(padding_size, padding_size), + dilation=(dilation, dilation), + groups=cur_head_split*Ch, + ) + self.conv_list.append(cur_conv) + self.head_splits.append(cur_head_split) + self.channel_splits = [x*Ch for x in self.head_splits] + + def forward(self, q, v, size): + B, h, N, Ch = q.shape + H, W = size + assert N == 1 + H * W + + # Convolutional relative position encoding. + q_img = q[:,:,1:,:] # Shape: [B, h, H*W, Ch]. + v_img = v[:,:,1:,:] # Shape: [B, h, H*W, Ch]. + + v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W) # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W]. + v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels. + conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)] + conv_v_img = torch.cat(conv_v_img_list, dim=1) + conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h) # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch]. + + EV_hat_img = q_img * conv_v_img + zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device) + EV_hat = torch.cat((zero, EV_hat_img), dim=2) # Shape: [B, h, N, Ch]. + + return EV_hat + + +class FactorAtt_ConvRelPosEnc(nn.Module): + """ Factorized attention with convolutional relative position encoding class. """ + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + # Shared convolutional relative position encoding. + self.crpe = shared_crpe + + def forward(self, x, size): + B, N, C = x.shape + + # Generate Q, K, V. + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # Shape: [3, B, h, N, Ch]. + q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch]. + + # Factorized attention. + k_softmax = k.softmax(dim=2) # Softmax on dim N. + k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v) # Shape: [B, h, Ch, Ch]. + factor_att = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v) # Shape: [B, h, N, Ch]. + + # Convolutional relative position encoding. + crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch]. + + # Merge and reshape. + x = self.scale * factor_att + crpe + x = x.transpose(1, 2).reshape(B, N, C) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]. + + # Output projection. + x = self.proj(x) + x = self.proj_drop(x) + + return x # Shape: [B, N, C]. + + +class ConvPosEnc(nn.Module): + """ Convolutional Position Encoding. + Note: This module is similar to the conditional position encoding in CPVT. + """ + def __init__(self, dim, k=3): + super(ConvPosEnc, self).__init__() + self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) + + def forward(self, x, size): + B, N, C = x.shape + H, W = size + assert N == 1 + H * W + + # Extract CLS token and image tokens. + cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C]. + + # Depthwise convolution. + feat = img_tokens.transpose(1, 2).view(B, C, H, W) + x = self.proj(feat) + feat + x = x.flatten(2).transpose(1, 2) + + # Combine with CLS token. + x = torch.cat((cls_token, x), dim=1) + + return x + + +class SerialBlock(nn.Module): + """ Serial block class. + Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + shared_cpe=None, shared_crpe=None): + super().__init__() + + # Conv-Attention. + self.cpe = shared_cpe + + self.norm1 = norm_layer(dim) + self.factoratt_crpe = FactorAtt_ConvRelPosEnc( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpe) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # MLP. + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size): + # Conv-Attention. + x = self.cpe(x, size) # Apply convolutional position encoding. + cur = self.norm1(x) + cur = self.factoratt_crpe(cur, size) # Apply factorized attention and convolutional relative position encoding. + x = x + self.drop_path(cur) + + # MLP. + cur = self.norm2(x) + cur = self.mlp(cur) + x = x + self.drop_path(cur) + + return x + + +class ParallelBlock(nn.Module): + """ Parallel block class. """ + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, + shared_cpes=None, shared_crpes=None): + super().__init__() + + # Conv-Attention. + self.cpes = shared_cpes + + self.norm12 = norm_layer(dims[1]) + self.norm13 = norm_layer(dims[2]) + self.norm14 = norm_layer(dims[3]) + self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[1] + ) + self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[2] + ) + self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + shared_crpe=shared_crpes[3] + ) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + # MLP. + self.norm22 = norm_layer(dims[1]) + self.norm23 = norm_layer(dims[2]) + self.norm24 = norm_layer(dims[3]) + assert dims[1] == dims[2] == dims[3] # In parallel block, we assume dimensions are the same and share the linear transformation. + assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] + mlp_hidden_dim = int(dims[1] * mlp_ratios[1]) + self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def upsample(self, x, factor, size): + """ Feature map up-sampling. """ + return self.interpolate(x, scale_factor=factor, size=size) + + def downsample(self, x, factor, size): + """ Feature map down-sampling. """ + return self.interpolate(x, scale_factor=1.0/factor, size=size) + + def interpolate(self, x, scale_factor, size): + """ Feature map interpolation. """ + B, N, C = x.shape + H, W = size + assert N == 1 + H * W + + cls_token = x[:, :1, :] + img_tokens = x[:, 1:, :] + + img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) + img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') + img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) + + out = torch.cat((cls_token, img_tokens), dim=1) + + return out + + def forward(self, x1, x2, x3, x4, sizes): + _, (H2, W2), (H3, W3), (H4, W4) = sizes + + # Conv-Attention. + x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored. + x3 = self.cpes[2](x3, size=(H3, W3)) + x4 = self.cpes[3](x4, size=(H4, W4)) + + cur2 = self.norm12(x2) + cur3 = self.norm13(x3) + cur4 = self.norm14(x4) + cur2 = self.factoratt_crpe2(cur2, size=(H2,W2)) + cur3 = self.factoratt_crpe3(cur3, size=(H3,W3)) + cur4 = self.factoratt_crpe4(cur4, size=(H4,W4)) + upsample3_2 = self.upsample(cur3, factor=2, size=(H3,W3)) + upsample4_3 = self.upsample(cur4, factor=2, size=(H4,W4)) + upsample4_2 = self.upsample(cur4, factor=4, size=(H4,W4)) + downsample2_3 = self.downsample(cur2, factor=2, size=(H2,W2)) + downsample3_4 = self.downsample(cur3, factor=2, size=(H3,W3)) + downsample2_4 = self.downsample(cur2, factor=4, size=(H2,W2)) + cur2 = cur2 + upsample3_2 + upsample4_2 + cur3 = cur3 + upsample4_3 + downsample2_3 + cur4 = cur4 + downsample3_4 + downsample2_4 + x2 = x2 + self.drop_path(cur2) + x3 = x3 + self.drop_path(cur3) + x4 = x4 + self.drop_path(cur4) + + # MLP. + cur2 = self.norm22(x2) + cur3 = self.norm23(x3) + cur4 = self.norm24(x4) + cur2 = self.mlp2(cur2) + cur3 = self.mlp3(cur3) + cur4 = self.mlp4(cur4) + x2 = x2 + self.drop_path(cur2) + x3 = x3 + self.drop_path(cur3) + x4 = x4 + self.drop_path(cur4) + + return x1, x2, x3, x4 + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # Note: self.H, self.W and self.num_patches are not used + self.num_patches = self.H * self.W # since the image size may change on the fly. + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x): + _, _, H, W = x.shape + out_H, out_W = H // self.patch_size[0], W // self.patch_size[1] + + x = self.proj(x).flatten(2).transpose(1, 2) + out = self.norm(x) + + return out, (out_H, out_W) + + +class CoaT(nn.Module): + """ CoaT class. """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], + serial_depths=[0, 0, 0, 0], parallel_depth=0, + num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., + drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3}, + **kwargs): + super().__init__() + self.return_interm_layers = return_interm_layers + self.out_features = out_features + self.num_classes = num_classes + + # Patch embeddings. + self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) + self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) + self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) + + # Class tokens. + self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) + self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1])) + self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2])) + self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3])) + + # Convolutional position encodings. + self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3) + self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3) + self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3) + self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3) + + # Convolutional relative position encodings. + self.crpe1 = ConvRelPosEnc(Ch=embed_dims[0] // num_heads, h=num_heads, window=crpe_window) + self.crpe2 = ConvRelPosEnc(Ch=embed_dims[1] // num_heads, h=num_heads, window=crpe_window) + self.crpe3 = ConvRelPosEnc(Ch=embed_dims[2] // num_heads, h=num_heads, window=crpe_window) + self.crpe4 = ConvRelPosEnc(Ch=embed_dims[3] // num_heads, h=num_heads, window=crpe_window) + + # Disable stochastic depth. + dpr = drop_path_rate + assert dpr == 0.0 + + # Serial blocks 1. + self.serial_blocks1 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe1, shared_crpe=self.crpe1 + ) + for _ in range(serial_depths[0])] + ) + + # Serial blocks 2. + self.serial_blocks2 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe2, shared_crpe=self.crpe2 + ) + for _ in range(serial_depths[1])] + ) + + # Serial blocks 3. + self.serial_blocks3 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe3, shared_crpe=self.crpe3 + ) + for _ in range(serial_depths[2])] + ) + + # Serial blocks 4. + self.serial_blocks4 = nn.ModuleList([ + SerialBlock( + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpe=self.cpe4, shared_crpe=self.crpe4 + ) + for _ in range(serial_depths[3])] + ) + + # Parallel blocks. + self.parallel_depth = parallel_depth + if self.parallel_depth > 0: + self.parallel_blocks = nn.ModuleList([ + ParallelBlock( + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], + shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4] + ) + for _ in range(parallel_depth)] + ) + + # Classification head(s). + if not self.return_interm_layers: + self.norm1 = norm_layer(embed_dims[0]) + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + self.norm4 = norm_layer(embed_dims[3]) + + if self.parallel_depth > 0: # CoaT series: Aggregate features of last three scales for classification. + assert embed_dims[1] == embed_dims[2] == embed_dims[3] + self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) + self.head = nn.Linear(embed_dims[3], num_classes) + else: + self.head = nn.Linear(embed_dims[3], num_classes) # CoaT-Lite series: Use feature of last scale for classification. + + # Initialize weights. + trunc_normal_(self.cls_token1, std=.02) + trunc_normal_(self.cls_token2, std=.02) + trunc_normal_(self.cls_token3, std=.02) + trunc_normal_(self.cls_token4, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def insert_cls(self, x, cls_token): + """ Insert CLS token. """ + cls_tokens = cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_tokens, x), dim=1) + return x + + def remove_cls(self, x): + """ Remove CLS token. """ + return x[:, 1:, :] + + def forward_features(self, x0): + B = x0.shape[0] + + # Serial blocks 1. + x1, (H1, W1) = self.patch_embed1(x0) + x1 = self.insert_cls(x1, self.cls_token1) + for blk in self.serial_blocks1: + x1 = blk(x1, size=(H1, W1)) + x1_nocls = self.remove_cls(x1) + x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 2. + x2, (H2, W2) = self.patch_embed2(x1_nocls) + x2 = self.insert_cls(x2, self.cls_token2) + for blk in self.serial_blocks2: + x2 = blk(x2, size=(H2, W2)) + x2_nocls = self.remove_cls(x2) + x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 3. + x3, (H3, W3) = self.patch_embed3(x2_nocls) + x3 = self.insert_cls(x3, self.cls_token3) + for blk in self.serial_blocks3: + x3 = blk(x3, size=(H3, W3)) + x3_nocls = self.remove_cls(x3) + x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + + # Serial blocks 4. + x4, (H4, W4) = self.patch_embed4(x3_nocls) + x4 = self.insert_cls(x4, self.cls_token4) + for blk in self.serial_blocks4: + x4 = blk(x4, size=(H4, W4)) + x4_nocls = self.remove_cls(x4) + x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + + # Only serial blocks: Early return. + if self.parallel_depth == 0: + if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + feat_out = {} + if 'x1_nocls' in self.out_features: + feat_out['x1_nocls'] = x1_nocls + if 'x2_nocls' in self.out_features: + feat_out['x2_nocls'] = x2_nocls + if 'x3_nocls' in self.out_features: + feat_out['x3_nocls'] = x3_nocls + if 'x4_nocls' in self.out_features: + feat_out['x4_nocls'] = x4_nocls + return feat_out + else: # Return features for classification. + x4 = self.norm4(x4) + x4_cls = x4[:, 0] + return x4_cls + + # Parallel blocks. + for blk in self.parallel_blocks: + x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) + + if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + feat_out = {} + if 'x1_nocls' in self.out_features: + x1_nocls = self.remove_cls(x1) + x1_nocls = x1_nocls.reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x1_nocls'] = x1_nocls + if 'x2_nocls' in self.out_features: + x2_nocls = self.remove_cls(x2) + x2_nocls = x2_nocls.reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x2_nocls'] = x2_nocls + if 'x3_nocls' in self.out_features: + x3_nocls = self.remove_cls(x3) + x3_nocls = x3_nocls.reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x3_nocls'] = x3_nocls + if 'x4_nocls' in self.out_features: + x4_nocls = self.remove_cls(x4) + x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() + feat_out['x4_nocls'] = x4_nocls + return feat_out + else: + x2 = self.norm2(x2) + x3 = self.norm3(x3) + x4 = self.norm4(x4) + x2_cls = x2[:, :1] # Shape: [B, 1, C]. + x3_cls = x3[:, :1] + x4_cls = x4[:, :1] + merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # Shape: [B, 3, C]. + merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]. + return merged_cls + + def forward(self, x): + if self.return_interm_layers: # Return intermediate features (for down-stream tasks). + return self.forward_features(x) + else: # Return features for classification. + x = self.forward_features(x) + x = self.head(x) + return x + + +# CoaT. +@register_model +def coat_tiny(**kwargs): + model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = _cfg_coat() + return model + +@register_model +def coat_mini(**kwargs): + model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = _cfg_coat() + return model + +# CoaT-Lite. +@register_model +def coat_lite_tiny(**kwargs): + model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model.default_cfg = _cfg_coat() + return model + +@register_model +def coat_lite_mini(**kwargs): + model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model.default_cfg = _cfg_coat() + return model + +@register_model +def coat_lite_small(**kwargs): + model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model.default_cfg = _cfg_coat() + return model \ No newline at end of file From 06841427cd308157f631b418ff87cb6076fc4e8f Mon Sep 17 00:00:00 2001 From: morizin <50985248+morizin@users.noreply.github.com> Date: Sat, 24 Apr 2021 18:11:28 +0530 Subject: [PATCH 2/6] Add files via upload --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index dcc9ab89..1dee97e7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -15,7 +15,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*'] +NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*','coat_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures From fd022fd6a2e152191f78b59e7f6aeae1c4864816 Mon Sep 17 00:00:00 2001 From: morizin <50985248+morizin@users.noreply.github.com> Date: Sat, 24 Apr 2021 18:22:36 +0530 Subject: [PATCH 3/6] Update __init__.py --- timm/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 16ced3da..7714def8 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -9,6 +9,7 @@ from .ghostnet import * from .gluon_resnet import * from .gluon_xception import * from .hardcorenas import * +from .coat import * from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * From 1e3b6d4dfc559da8364795b50f149285158dc7fd Mon Sep 17 00:00:00 2001 From: morizin <50985248+morizin@users.noreply.github.com> Date: Sat, 24 Apr 2021 18:22:59 +0530 Subject: [PATCH 4/6] Update __init__.py --- timm/models/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 7714def8..400e1f64 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -1,5 +1,6 @@ from .byoanet import * from .byobnet import * +from .coat import * from .cspnet import * from .densenet import * from .dla import * @@ -9,7 +10,6 @@ from .ghostnet import * from .gluon_resnet import * from .gluon_xception import * from .hardcorenas import * -from .coat import * from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * From 76739a7589ebde1fc6b015e5f9f3e2dc8a73299e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 28 Apr 2021 16:31:35 -0700 Subject: [PATCH 5/6] CoaT merge. Bit of formatting, fix torchscript (for non features), remove einops/einsum dep, add pretrained weight hub (url) support. --- timm/models/coat.py | 230 +++++++++++++++++++++++++++----------------- 1 file changed, 144 insertions(+), 86 deletions(-) diff --git a/timm/models/coat.py b/timm/models/coat.py index 40dd5e8c..99357fda 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -1,20 +1,25 @@ """ CoaT architecture. +Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399 + +Official CoaT code at: https://github.com/mlpc-ucsd/CoaT + Modified from timm/models/vision_transformer.py """ +from typing import Tuple, Dict, Any, Optional import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.helpers import load_pretrained from timm.models.layers import DropPath, to_2tuple, trunc_normal_ from timm.models.registry import register_model -from einops import rearrange from functools import partial -from torch import nn, einsum +from torch import nn __all__ = [ "coat_tiny", @@ -36,6 +41,19 @@ def _cfg_coat(url='', **kwargs): } +default_cfgs = { + 'coat_tiny': _cfg_coat(), + 'coat_mini': _cfg_coat(), + 'coat_lite_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' + ), + 'coat_lite_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' + ), + 'coat_lite_small': _cfg_coat(), +} + + class Mlp(nn.Module): """ Feed-forward network (FFN, a.k.a. MLP) class. """ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): @@ -64,14 +82,17 @@ class ConvRelPosEnc(nn.Module): Ch: Channels per head. h: Number of heads. window: Window size(s) in convolutional relative positional encoding. It can have two forms: - 1. An integer of window size, which assigns all attention heads with the same window size in ConvRelPosEnc. - 2. A dict mapping window size to #attention head splits (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) - It will apply different window size to the attention head splits. + 1. An integer of window size, which assigns all attention heads with the same window s + size in ConvRelPosEnc. + 2. A dict mapping window size to #attention head splits ( + e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}) + It will apply different window size to the attention head splits. """ super().__init__() if isinstance(window, int): - window = {window: h} # Set the same window size for all attention heads. + # Set the same window size for all attention heads. + window = {window: h} self.window = window elif isinstance(window, dict): self.window = window @@ -81,8 +102,10 @@ class ConvRelPosEnc(nn.Module): self.conv_list = nn.ModuleList() self.head_splits = [] for cur_window, cur_head_split in window.items(): - dilation = 1 # Use dilation=1 at default. - padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 # Determine padding size. Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + dilation = 1 + # Determine padding size. + # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338 + padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2 cur_conv = nn.Conv2d(cur_head_split*Ch, cur_head_split*Ch, kernel_size=(cur_window, cur_window), padding=(padding_size, padding_size), @@ -93,25 +116,25 @@ class ConvRelPosEnc(nn.Module): self.head_splits.append(cur_head_split) self.channel_splits = [x*Ch for x in self.head_splits] - def forward(self, q, v, size): + def forward(self, q, v, size: Tuple[int, int]): B, h, N, Ch = q.shape H, W = size assert N == 1 + H * W # Convolutional relative position encoding. - q_img = q[:,:,1:,:] # Shape: [B, h, H*W, Ch]. - v_img = v[:,:,1:,:] # Shape: [B, h, H*W, Ch]. - - v_img = rearrange(v_img, 'B h (H W) Ch -> B (h Ch) H W', H=H, W=W) # Shape: [B, h, H*W, Ch] -> [B, h*Ch, H, W]. - v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels. - conv_v_img_list = [conv(x) for conv, x in zip(self.conv_list, v_img_list)] + q_img = q[:, :, 1:, :] # [B, h, H*W, Ch] + v_img = v[:, :, 1:, :] # [B, h, H*W, Ch] + + v_img = v_img.transpose(-1, -2).reshape(B, h * Ch, H, W) + v_img_list = torch.split(v_img, self.channel_splits, dim=1) # Split according to channels + conv_v_img_list = [] + for i, conv in enumerate(self.conv_list): + conv_v_img_list.append(conv(v_img_list[i])) conv_v_img = torch.cat(conv_v_img_list, dim=1) - conv_v_img = rearrange(conv_v_img, 'B (h Ch) H W -> B h (H W) Ch', h=h) # Shape: [B, h*Ch, H, W] -> [B, h, H*W, Ch]. - - EV_hat_img = q_img * conv_v_img - zero = torch.zeros((B, h, 1, Ch), dtype=q.dtype, layout=q.layout, device=q.device) - EV_hat = torch.cat((zero, EV_hat_img), dim=2) # Shape: [B, h, N, Ch]. + conv_v_img = conv_v_img.reshape(B, h, Ch, H * W).transpose(-1, -2) + EV_hat = q_img * conv_v_img + EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0)) # [B, h, N, Ch]. return EV_hat @@ -124,37 +147,37 @@ class FactorAtt_ConvRelPosEnc(nn.Module): self.scale = qk_scale or head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. + self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) # Shared convolutional relative position encoding. self.crpe = shared_crpe - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape # Generate Q, K, V. - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # Shape: [3, B, h, N, Ch]. - q, k, v = qkv[0], qkv[1], qkv[2] # Shape: [B, h, N, Ch]. + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [B, h, N, Ch] # Factorized attention. - k_softmax = k.softmax(dim=2) # Softmax on dim N. - k_softmax_T_dot_v = einsum('b h n k, b h n v -> b h k v', k_softmax, v) # Shape: [B, h, Ch, Ch]. - factor_att = einsum('b h n k, b h k v -> b h n v', q, k_softmax_T_dot_v) # Shape: [B, h, N, Ch]. + k_softmax = k.softmax(dim=2) + factor_att = k_softmax.transpose(-1, -2) @ v + factor_att = q @ factor_att # Convolutional relative position encoding. - crpe = self.crpe(q, v, size=size) # Shape: [B, h, N, Ch]. + crpe = self.crpe(q, v, size=size) # [B, h, N, Ch] # Merge and reshape. x = self.scale * factor_att + crpe - x = x.transpose(1, 2).reshape(B, N, C) # Shape: [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]. + x = x.transpose(1, 2).reshape(B, N, C) # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C] # Output projection. x = self.proj(x) x = self.proj_drop(x) - return x # Shape: [B, N, C]. + return x class ConvPosEnc(nn.Module): @@ -165,13 +188,13 @@ class ConvPosEnc(nn.Module): super(ConvPosEnc, self).__init__() self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): B, N, C = x.shape H, W = size assert N == 1 + H * W # Extract CLS token and image tokens. - cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C]. + cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C] # Depthwise convolution. feat = img_tokens.transpose(1, 2).view(B, C, H, W) @@ -206,11 +229,11 @@ class SerialBlock(nn.Module): mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def forward(self, x, size): + def forward(self, x, size: Tuple[int, int]): # Conv-Attention. - x = self.cpe(x, size) # Apply convolutional position encoding. + x = self.cpe(x, size) cur = self.norm1(x) - cur = self.factoratt_crpe(cur, size) # Apply factorized attention and convolutional relative position encoding. + cur = self.factoratt_crpe(cur, size) x = x + self.drop_path(cur) # MLP. @@ -252,10 +275,12 @@ class ParallelBlock(nn.Module): self.norm22 = norm_layer(dims[1]) self.norm23 = norm_layer(dims[2]) self.norm24 = norm_layer(dims[3]) - assert dims[1] == dims[2] == dims[3] # In parallel block, we assume dimensions are the same and share the linear transformation. + # In parallel block, we assume dimensions are the same and share the linear transformation. + assert dims[1] == dims[2] == dims[3] assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3] mlp_hidden_dim = int(dims[1] * mlp_ratios[1]) - self.mlp2 = self.mlp3 = self.mlp4 = Mlp(in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp2 = self.mlp3 = self.mlp4 = Mlp( + in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def upsample(self, x, factor, size): """ Feature map up-sampling. """ @@ -271,7 +296,7 @@ class ParallelBlock(nn.Module): H, W = size assert N == 1 + H * W - cls_token = x[:, :1, :] + cls_token = x[:, :1, :] img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) @@ -293,18 +318,18 @@ class ParallelBlock(nn.Module): cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2,W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3,W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4,W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3,W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4,W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4,W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2,W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3,W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2,W2)) - cur2 = cur2 + upsample3_2 + upsample4_2 - cur3 = cur3 + upsample4_3 + downsample2_3 - cur4 = cur4 + downsample3_4 + downsample2_4 + cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) + cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) + cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) + upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) + upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) + upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) + downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) + downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) + downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = cur2 + upsample3_2 + upsample4_2 + cur3 = cur3 + upsample4_3 + downsample2_3 + cur4 = cur4 + downsample3_4 + downsample2_4 x2 = x2 + self.drop_path(cur2) x3 = x3 + self.drop_path(cur3) x4 = x4 + self.drop_path(cur4) @@ -334,8 +359,10 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ f"img_size {img_size} should be divided by patch_size {patch_size}." - self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] # Note: self.H, self.W and self.num_patches are not used - self.num_patches = self.H * self.W # since the image size may change on the fly. + # Note: self.H, self.W and self.num_patches are not used + # since the image size may change on the fly. + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = nn.LayerNorm(embed_dim) @@ -355,18 +382,22 @@ class CoaT(nn.Module): serial_depths=[0, 0, 0, 0], parallel_depth=0, num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features=None, crpe_window={3:2, 5:3, 7:3}, - **kwargs): + return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): super().__init__() + crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers self.out_features = out_features self.num_classes = num_classes # Patch embeddings. - self.patch_embed1 = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) - self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) - self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) - self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0]) + self.patch_embed2 = PatchEmbed( + img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1]) + self.patch_embed3 = PatchEmbed( + img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2]) + self.patch_embed4 = PatchEmbed( + img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3]) # Class tokens. self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0])) @@ -442,6 +473,8 @@ class CoaT(nn.Module): ) for _ in range(parallel_depth)] ) + else: + self.parallel_blocks = None # Classification head(s). if not self.return_interm_layers: @@ -450,12 +483,14 @@ class CoaT(nn.Module): self.norm3 = norm_layer(embed_dims[2]) self.norm4 = norm_layer(embed_dims[3]) - if self.parallel_depth > 0: # CoaT series: Aggregate features of last three scales for classification. + if self.parallel_depth > 0: + # CoaT series: Aggregate features of last three scales for classification. assert embed_dims[1] == embed_dims[2] == embed_dims[3] self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1) self.head = nn.Linear(embed_dims[3], num_classes) else: - self.head = nn.Linear(embed_dims[3], num_classes) # CoaT-Lite series: Use feature of last scale for classification. + # CoaT-Lite series: Use feature of last scale for classification. + self.head = nn.Linear(embed_dims[3], num_classes) # Initialize weights. trunc_normal_(self.cls_token1, std=.02) @@ -530,8 +565,9 @@ class CoaT(nn.Module): x4_nocls = x4_nocls.reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous() # Only serial blocks: Early return. - if self.parallel_depth == 0: - if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + if self.parallel_blocks is None: + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). feat_out = {} if 'x1_nocls' in self.out_features: feat_out['x1_nocls'] = x1_nocls @@ -542,7 +578,8 @@ class CoaT(nn.Module): if 'x4_nocls' in self.out_features: feat_out['x4_nocls'] = x4_nocls return feat_out - else: # Return features for classification. + else: + # Return features for classification. x4 = self.norm4(x4) x4_cls = x4[:, 0] return x4_cls @@ -551,7 +588,8 @@ class CoaT(nn.Module): for blk in self.parallel_blocks: x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) - if self.return_interm_layers: # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). + if not torch.jit.is_scripting() and self.return_interm_layers: + # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2). feat_out = {} if 'x1_nocls' in self.out_features: x1_nocls = self.remove_cls(x1) @@ -574,50 +612,70 @@ class CoaT(nn.Module): x2 = self.norm2(x2) x3 = self.norm3(x3) x4 = self.norm4(x4) - x2_cls = x2[:, :1] # Shape: [B, 1, C]. + x2_cls = x2[:, :1] # [B, 1, C] x3_cls = x3[:, :1] x4_cls = x4[:, :1] - merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # Shape: [B, 3, C]. - merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C]. + merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C] + merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C] return merged_cls def forward(self, x): - if self.return_interm_layers: # Return intermediate features (for down-stream tasks). + if self.return_interm_layers: + # Return intermediate features (for down-stream tasks). return self.forward_features(x) - else: # Return features for classification. + else: + # Return features for classification. x = self.forward_features(x) x = self.head(x) return x -# CoaT. @register_model -def coat_tiny(**kwargs): - model = CoaT(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_tiny(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_tiny'] return model + @register_model -def coat_mini(**kwargs): - model = CoaT(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_mini(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, + num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_mini'] return model -# CoaT-Lite. + @register_model -def coat_lite_tiny(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_tiny(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + # FIXME use builder + model.default_cfg = default_cfgs['coat_lite_mini'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model + @register_model -def coat_lite_mini(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_mini(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + # FIXME use builder + model.default_cfg = default_cfgs['coat_lite_mini'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model + @register_model -def coat_lite_small(**kwargs): - model = CoaT(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = _cfg_coat() +def coat_lite_small(pretrained=False, **kwargs): + model = CoaT( + patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, + num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) + model.default_cfg = default_cfgs['coat_lite_small'] return model \ No newline at end of file From e5e15754c9e5d7deae65b874dbc918b788ce88b9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 28 Apr 2021 18:09:23 -0700 Subject: [PATCH 6/6] Fix coat first conv ident --- timm/models/coat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/coat.py b/timm/models/coat.py index 99357fda..7b364dae 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -36,7 +36,7 @@ def _cfg_coat(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embed1.proj', 'classifier': 'head', **kwargs } @@ -654,7 +654,7 @@ def coat_lite_tiny(pretrained=False, **kwargs): patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_mini'] + model.default_cfg = default_cfgs['coat_lite_tiny'] if pretrained: load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model