From 8164c0aa304a385be6057508d31895a70c466610 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 21:13:44 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 59 +++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 56 insertions(+), 3 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 0ccadd79..2e9d3ab9 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -68,8 +68,8 @@ class ConvPosEnc(nn.Module):
         return x
 
 
-
-class PatchEmbed(nn.Module):
+@register_notrace_module
+class PatchEmbedOld(nn.Module):
     """ Size-agnostic implementation of 2D image to patch embedding,
         allowing input size to be adjusted during model forward operation """
 
@@ -117,12 +117,65 @@ class PatchEmbed(nn.Module):
             x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
 
         x = self.proj(x)
-        #x = x.flatten(2).transpose(1, 2)
+
         if self.norm.normalized_shape[0] == self.embed_dim:
             x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
 
 
+@register_notrace_module
+class PatchEmbed(nn.Module):
+    """ Size-agnostic implementation of 2D image to patch embedding,
+        allowing input size to be adjusted during model forward operation
+    """
+    def __init__(
+            self,
+            patch_size=4,
+            in_chans=3,
+            embed_dim=96,
+            overlapped=False):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        self.patch_size = patch_size
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        if patch_size[0] == 4:
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=(7, 7),
+                stride=patch_size,
+                padding=(3, 3))
+            self.norm = nn.LayerNorm(embed_dim)
+        if patch_size[0] == 2:
+            kernel = 3 if overlapped else 2
+            pad = 1 if overlapped else 0
+            self.proj = nn.Conv2d(
+                in_chans,
+                embed_dim,
+                kernel_size=to_2tuple(kernel),
+                stride=patch_size,
+                padding=to_2tuple(pad))
+            self.norm = nn.LayerNorm(in_chans)
+
+
+    def forward(self, x : Tensor):
+        B, C, H, W = x.shape
+        if self.norm.normalized_shape[0] == self.in_chans:
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+
+
+        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+
+        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+        x = self.proj(x)
+
+        if self.norm.normalized_shape[0] == self.embed_dim:
+            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+        return x
+
+
 class ChannelAttention(nn.Module):
 
     def __init__(self, dim, num_heads=8, qkv_bias=False):
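
A minimal usage sketch, not part of the patch: it assumes the patched timm/models/davit.py from this branch is importable, and shows the size-agnostic stem embedding (patch_size=4) accepting an input whose spatial size is not a multiple of the patch size.

# Illustrative sketch only; assumes the patched timm/models/davit.py is on the import path.
import torch
from timm.models.davit import PatchEmbed

# Stem configuration from the patch: 7x7 conv, stride 4, LayerNorm over embed_dim after projection.
stem = PatchEmbed(patch_size=4, in_chans=3, embed_dim=96)
x = torch.randn(1, 3, 230, 230)   # 230 is not a multiple of 4
out = stem(x)                     # forward pads 230 -> 232 before the stride-4 projection
print(out.shape)                  # torch.Size([1, 96, 58, 58])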
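
A second illustrative sketch for the patch_size=2 downsampling configuration, where the class builds nn.LayerNorm(in_chans) and the forward therefore normalizes before the projection rather than after it. The 96 -> 192 channel step is an assumed stage transition, not taken from the patch.

# Illustrative sketch only; assumes the patched timm/models/davit.py is on the import path,
# and that a 96 -> 192 channel downsample stage is a representative configuration.
import torch
from timm.models.davit import PatchEmbed

down = PatchEmbed(patch_size=2, in_chans=96, embed_dim=192, overlapped=True)
x = torch.randn(1, 96, 57, 57)    # odd spatial size, not a multiple of 2
out = down(x)                     # pre-norm over 96 channels, pad 57 -> 58, then 3x3/stride-2 conv
print(out.shape)                  # torch.Size([1, 192, 29, 29])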