Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent a828ccaf88
commit 8164c0aa30

@ -68,8 +68,8 @@ class ConvPosEnc(nn.Module):
return x
class PatchEmbed(nn.Module):
@register_notrace_module
class PatchEmbedOld(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
@ -117,12 +117,65 @@ class PatchEmbed(nn.Module):
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x)
#x = x.flatten(2).transpose(1, 2)
if self.norm.normalized_shape[0] == self.embed_dim:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
@register_notrace_module
class PatchEmbed(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
def __init__(
self,
patch_size=4,
in_chans=3,
embed_dim=96,
overlapped=False):
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.in_chans = in_chans
self.embed_dim = embed_dim
if patch_size[0] == 4:
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=(7, 7),
stride=patch_size,
padding=(3, 3))
self.norm = nn.LayerNorm(embed_dim)
if patch_size[0] == 2:
kernel = 3 if overlapped else 2
pad = 1 if overlapped else 0
self.proj = nn.Conv2d(
in_chans,
embed_dim,
kernel_size=to_2tuple(kernel),
stride=patch_size,
padding=to_2tuple(pad))
self.norm = nn.LayerNorm(in_chans)
def forward(self, x : Tensor):
B, C, H, W = x.shape
if self.norm.normalized_shape[0] == self.in_chans:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
x = self.proj(x)
if self.norm.normalized_shape[0] == self.embed_dim:
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):

Loading…
Cancel
Save