|
|
|
@ -167,14 +167,14 @@ class Visformer(nn.Module):
|
|
|
|
|
self.patch_embed1 = PatchEmbed(
|
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
|
|
|
|
|
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
|
|
|
|
img_size = [x // 16 for x in img_size]
|
|
|
|
|
img_size = [x // patch_size for x in img_size]
|
|
|
|
|
else:
|
|
|
|
|
if self.init_channels is None:
|
|
|
|
|
self.stem = None
|
|
|
|
|
self.patch_embed1 = PatchEmbed(
|
|
|
|
|
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
|
|
|
|
|
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
|
|
|
|
img_size = [x // 8 for x in img_size]
|
|
|
|
|
img_size = [x // (patch_size // 2) for x in img_size]
|
|
|
|
|
else:
|
|
|
|
|
self.stem = nn.Sequential(
|
|
|
|
|
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
|
|
|
|
@ -185,7 +185,7 @@ class Visformer(nn.Module):
|
|
|
|
|
self.patch_embed1 = PatchEmbed(
|
|
|
|
|
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
|
|
|
|
|
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
|
|
|
|
|
img_size = [x // 4 for x in img_size]
|
|
|
|
|
img_size = [x // (patch_size // 4) for x in img_size]
|
|
|
|
|
|
|
|
|
|
if self.pos_embed:
|
|
|
|
|
if self.vit_stem:
|
|
|
|
@ -207,7 +207,7 @@ class Visformer(nn.Module):
|
|
|
|
|
self.patch_embed2 = PatchEmbed(
|
|
|
|
|
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
|
|
|
|
|
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
|
|
|
|
|
img_size = [x // 2 for x in img_size]
|
|
|
|
|
img_size = [x // (patch_size // 8) for x in img_size]
|
|
|
|
|
if self.pos_embed:
|
|
|
|
|
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
|
|
|
|
|
self.stage2 = nn.ModuleList([
|
|
|
|
@ -224,7 +224,7 @@ class Visformer(nn.Module):
|
|
|
|
|
self.patch_embed3 = PatchEmbed(
|
|
|
|
|
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
|
|
|
|
|
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
|
|
|
|
|
img_size = [x // 2 for x in img_size]
|
|
|
|
|
img_size = [x // (patch_size // 8) for x in img_size]
|
|
|
|
|
if self.pos_embed:
|
|
|
|
|
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
|
|
|
|
|
self.stage3 = nn.ModuleList([
|
|
|
|
|