@@ -68,8 +68,8 @@ class ConvPosEnc(nn.Module):
        return x


class PatchEmbed(nn.Module):
@register_notrace_module
class PatchEmbedOld(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
        allowing input size to be adjusted during model forward operation
    """
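
(Aside, not part of the patch: register_notrace_module marks a module as a leaf for timm's torch.fx feature extraction, which cannot symbolically trace shape-dependent logic like the padding in these embed layers. The toy below, with a made-up module name, only illustrates why such modules are kept untraced; the helper itself lives in timm's fx utilities and is not imported here.)

import torch
import torch.nn as nn
import torch.nn.functional as F


class ShapeDependentPad(nn.Module):
    # Toy stand-in (not from the patch): pads W up to a multiple of 4,
    # branching on the runtime shape much like the patch embed's forward.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] % 4 != 0:
            x = F.pad(x, (0, 4 - x.shape[-1] % 4))
        return x


# Symbolic tracing rejects the data-dependent branch; modules like this are the
# kind timm registers as leaves so FX feature extraction does not trace into them.
try:
    torch.fx.symbolic_trace(ShapeDependentPad())
except Exception as err:
    print(type(err).__name__, err)
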
@@ -117,12 +117,65 @@ class PatchEmbed(nn.Module):
        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)
        #x = x.flatten(2).transpose(1, 2)

        if self.norm.normalized_shape[0] == self.embed_dim:
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
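
(Aside, not part of the patch: the F.pad calls above pad the right and bottom edges so W and H become multiples of the patch size before the strided projection; F.pad lists pads innermost dimension first, i.e. (left, right, top, bottom) for an NCHW tensor. A standalone sketch with toy sizes follows; the outer modulo is an addition of this note, showing how a no-op pad falls out when a side is already divisible.)

import torch
import torch.nn.functional as F

patch_h, patch_w = 4, 4
x = torch.randn(1, 3, 30, 30)                # H = W = 30, not multiples of 4
B, C, H, W = x.shape

pad_w = (patch_w - W % patch_w) % patch_w    # 2; outer % gives 0 when already divisible
pad_h = (patch_h - H % patch_h) % patch_h    # 2
x = F.pad(x, (0, pad_w))                     # pad right edge:  W 30 -> 32
x = F.pad(x, (0, 0, 0, pad_h))               # pad bottom edge: H 30 -> 32
print(x.shape)                               # torch.Size([1, 3, 32, 32])
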
@register_notrace_module
class PatchEmbed(nn.Module):
    """ Size-agnostic implementation of 2D image to patch embedding,
        allowing input size to be adjusted during model forward operation
    """

    def __init__(
            self,
            patch_size=4,
            in_chans=3,
            embed_dim=96,
            overlapped=False):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        if patch_size[0] == 4:
            self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=(7, 7),
                stride=patch_size,
                padding=(3, 3))
            self.norm = nn.LayerNorm(embed_dim)
        if patch_size[0] == 2:
            kernel = 3 if overlapped else 2
            pad = 1 if overlapped else 0
            self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=to_2tuple(kernel),
                stride=patch_size,
                padding=to_2tuple(pad))
            self.norm = nn.LayerNorm(in_chans)

    def forward(self, x: Tensor):
        B, C, H, W = x.shape
        if self.norm.normalized_shape[0] == self.in_chans:
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))

        x = self.proj(x)
        if self.norm.normalized_shape[0] == self.embed_dim:
            x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x
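
(Aside, not part of the patch: the two branches of __init__ above are the stride-4 stem, a 7x7 conv with padding 3, and the stride-2 downsample between stages, a 2x2 conv or an overlapping 3x3 conv with padding 1. A quick shape check with made-up channel counts shows each branch reduces the spatial dims by exactly its stride.)

import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)
stem = nn.Conv2d(3, 96, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
print(stem(x).shape)          # torch.Size([1, 96, 56, 56])   -> 224 / 4

y = torch.randn(1, 96, 56, 56)
down_overlap = nn.Conv2d(96, 192, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))  # overlapped=True
print(down_overlap(y).shape)  # torch.Size([1, 192, 28, 28])  -> 56 / 2

down_plain = nn.Conv2d(96, 192, kernel_size=(2, 2), stride=(2, 2))                    # overlapped=False
print(down_plain(y).shape)    # torch.Size([1, 192, 28, 28])
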
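
(Aside, not part of the patch: forward above picks between the pre-projection norm of the stride-2 case, LayerNorm over in_chans, and the post-projection norm of the stride-4 case, LayerNorm over embed_dim, by inspecting norm.normalized_shape, and it applies LayerNorm to an NCHW tensor by permuting channels last and back. A standalone illustration of both pieces:)

import torch
import torch.nn as nn

norm = nn.LayerNorm(96)
print(norm.normalized_shape)    # (96,) -> this is what gets compared to in_chans / embed_dim

x = torch.randn(2, 96, 56, 56)                         # NCHW feature map
y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)    # normalize over C, return to NCHW
print(y.shape)                                         # torch.Size([2, 96, 56, 56])
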

class ChannelAttention(nn.Module):

    def __init__(self, dim, num_heads=8, qkv_bias=False):
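
(Aside, not part of the patch, which is truncated here: channel attention in DaViT-style blocks forms the attention matrix over channels rather than spatial tokens, so its size is set by the embedding width and not by the image resolution. A toy single-head comparison with hypothetical stage-1 sizes:)

import torch

N, C = 56 * 56, 96                              # tokens, channels (hypothetical sizes)
q = torch.randn(N, C)
k = torch.randn(N, C)

spatial_attn = q @ k.transpose(0, 1)            # (N, N): grows quadratically with resolution
channel_attn = q.transpose(0, 1) @ k            # (C, C): fixed by channel width
print(spatial_attn.shape, channel_attn.shape)   # torch.Size([3136, 3136]) torch.Size([96, 96])
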