From 447dd0d23fe1b7711d202ac36b10f903cc8b7d21 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 20:39:36 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 255 +++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 248 insertions(+), 7 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 18cf5af9..db73b903 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -45,7 +45,7 @@ class SequentialWithSize(nn.Sequential):
         return x, size
 
 
-class ConvPosEnc(nn.Module):
+class ConvPosEncOld(nn.Module):
     def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
 
         super(ConvPosEnc, self).__init__()
@@ -79,10 +79,44 @@ class ConvPosEnc(nn.Module):
         return x
 
 
+class ConvPosEnc(nn.Module):
+    def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
+
+        super(ConvPosEnc, self).__init__()
+        self.proj = nn.Conv2d(dim,
+                              dim,
+                              to_2tuple(k),
+                              to_2tuple(1),
+                              to_2tuple(k // 2),
+                              groups=dim)
+        self.normtype = normtype
+        self.norm = nn.Identity()
+        if self.normtype == 'batch':
+            self.norm = nn.BatchNorm2d(dim)
+        elif self.normtype == 'layer':
+            self.norm = nn.LayerNorm(dim)
+        self.activation = nn.GELU() if act else nn.Identity()
+
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+
+        # input is already NCHW, apply the depthwise conv directly
+        feat = self.proj(x)
+        if self.normtype == 'batch':
+            feat = self.norm(feat).flatten(2).transpose(1, 2)
+        elif self.normtype == 'layer':
+            feat = self.norm(feat.flatten(2).transpose(1, 2))
+        else:
+            feat = feat.flatten(2).transpose(1, 2)
+        x = x + self.activation(feat).transpose(1, 2).view(B, C, H, W)
+        return x
+
+
 # reason: dim in control sequence
 # FIXME reimplement to allow tracing
 @register_notrace_module
-class PatchEmbed(nn.Module):
+class PatchEmbedOld(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation
     """
@@ -141,6 +175,60 @@ class PatchEmbed(nn.Module):
         x = self.norm(x)
         return x, newsize
 
+@register_notrace_module
+class PatchEmbed(nn.Module):
+    """ Size-agnostic implementation of 2D image to patch embedding,
+        allowing input size to be adjusted during model forward operation
+    """
+
+    def __init__(
+            self,
+            patch_size=4,
+            in_chans=3,
+            embed_dim=96,
+            overlapped=False):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        if patch_size[0] == 4:
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=(7, 7),
+                stride=patch_size,
+                padding=(3, 3))
+            self.norm = nn.LayerNorm(embed_dim)
+        if patch_size[0] == 2:
+            kernel = 3 if overlapped else 2
+            pad = 1 if overlapped else 0
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=to_2tuple(kernel),
+                stride=patch_size,
+                padding=to_2tuple(pad))
+            self.norm = nn.LayerNorm(in_chans)
+
+
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        if self.norm.normalized_shape[0] == self.in_chans:
+            x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
+
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)
+        B, C, H, W = x.shape  # channel and spatial dims change after the strided projection
+        if self.norm.normalized_shape[0] == self.embed_dim:
+            x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
+        return x
+
 
 class ChannelAttention(nn.Module):
 
@@ -166,9 +254,37 @@ class ChannelAttention(nn.Module):
         x = x.transpose(1, 2).reshape(B, N, C)
         x = self.proj(x)
         return x
+
+class ChannelAttentionNew(nn.Module):
+    def __init__(self, dim, num_heads=8, qkv_bias=False):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
 
 
-class ChannelBlock(nn.Module):
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
+        B, N, C = x.shape
+
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv[0], qkv[1], qkv[2]
+
+        k = k * self.scale
+        attention = k.transpose(-1, -2) @ v
+        attention = attention.softmax(dim=-1)
+        x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = x.transpose(1, 2).view(B, C, H, W)
+        return x
+
+
+class ChannelBlockOld(nn.Module):
 
     def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
                  drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
@@ -204,6 +320,46 @@ class ChannelBlock(nn.Module):
         x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
+class ChannelBlock(nn.Module):
+
+    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False,
+                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 ffn=True, cpe_act=False):
+        super().__init__()
+
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
+        self.ffn = ffn
+        self.norm1 = norm_layer(dim)
+        self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
+
+        if self.ffn:
+            self.norm2 = norm_layer(dim)
+            mlp_hidden_dim = int(dim * mlp_ratio)
+            self.mlp = Mlp(
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
+                act_layer=act_layer)
+
+
+    def forward(self, x : Tensor):
+
+        B, C, H, W = x.shape
+
+        x = self.cpe1(x).flatten(2).transpose(1, 2)
+
+        cur = self.norm1(x)
+        cur = self.attn(cur)
+        x = x + self.drop_path(cur)
+
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
+        if self.ffn:
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        x = x.transpose(1, 2).view(B, C, H, W)
+
+        return x
 
 
 def window_partition(x : Tensor, window_size: int):
     """
@@ -319,6 +475,91 @@ class SpatialBlock(nn.Module):
                 act_layer=act_layer)
 
+
+    def forward(self, x : Tensor, size: Tuple[int, int]):
+        B, C, H, W = x.shape
+
+
+        shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
+        x = self.norm1(shortcut)
+        x = x.view(B, H, W, C)
+
+        pad_l = pad_t = 0
+        pad_r = (self.window_size - W % self.window_size) % self.window_size
+        pad_b = (self.window_size - H % self.window_size) % self.window_size
+        x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+        _, Hp, Wp, _ = x.shape
+
+        x_windows = window_partition(x, self.window_size)
+        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
+
+        # W-MSA/SW-MSA
+        attn_windows = self.attn(x_windows)
+
+        # merge windows
+        attn_windows = attn_windows.view(-1,
+                                         self.window_size,
+                                         self.window_size,
+                                         C)
+        x = window_reverse(attn_windows, self.window_size, Hp, Wp)
+
+        #if pad_r > 0 or pad_b > 0:
+        x = x[:, :H, :W, :].contiguous()
+
+        x = x.view(B, H * W, C)
+        x = shortcut + self.drop_path(x)
+
+        x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)).flatten(2).transpose(1, 2)
+        if self.ffn:
+            x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+        x = x.transpose(1, 2).view(B, C, H, W)
+
+        return x, size
+
+class SpatialBlockOld(nn.Module):
+    r""" Windows Block.
+    Args:
+        dim (int): Number of input channels.
+        num_heads (int): Number of attention heads.
+        window_size (int): Window size.
+        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        drop_path (float, optional): Stochastic depth rate. Default: 0.0
+        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+    """
+
+    def __init__(self, dim, num_heads, window_size=7,
+                 mlp_ratio=4., qkv_bias=True, drop_path=0.,
+                 act_layer=nn.GELU, norm_layer=nn.LayerNorm,
+                 ffn=True, cpe_act=False):
+        super().__init__()
+        self.dim = dim
+        self.ffn = ffn
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.mlp_ratio = mlp_ratio
+
+        self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
+        self.norm1 = norm_layer(dim)
+        self.attn = WindowAttention(
+            dim,
+            window_size=to_2tuple(self.window_size),
+            num_heads=num_heads,
+            qkv_bias=qkv_bias)
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+        self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
+
+        if self.ffn:
+            self.norm2 = norm_layer(dim)
+            mlp_hidden_dim = int(dim * mlp_ratio)
+            self.mlp = Mlp(
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
+                act_layer=act_layer)
+
+
     def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
@@ -357,7 +598,7 @@ class SpatialBlock(nn.Module):
         if self.ffn:
             x = x + self.drop_path(self.mlp(self.norm2(x)))
         return x, size
-    
+
 class DaViTStage(nn.Module):
     def __init__(
         self,
@@ -424,9 +665,9 @@ class DaViTStage(nn.Module):
                 cpe_act=cpe_act
             ))
 
-            stage_blocks.append(SequentialWithSize(*dual_attention_block))
+            stage_blocks.append(nn.Sequential(*dual_attention_block))
 
-        self.blocks = SequentialWithSize(*stage_blocks)
+        self.blocks = nn.Sequential(*stage_blocks)
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
         x, size = self.patch_embed(x, size)
@@ -519,7 +760,7 @@ class DaViT(nn.Module):
 
             self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
 
-        self.stages = SequentialWithSize(*stages)
+        self.stages = nn.Sequential(*stages)
         self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)