Merge pull request #1034 from zeakey/master

Fix hard-coded strides in VisFormer.
pull/1055/head
Ross Wightman 3 years ago committed by GitHub
commit f55c22bebf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -167,14 +167,14 @@ class Visformer(nn.Module):
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size = [x // 16 for x in img_size]
img_size = [x // patch_size for x in img_size]
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size = [x // 8 for x in img_size]
img_size = [x // (patch_size // 2) for x in img_size]
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
@ -185,7 +185,7 @@ class Visformer(nn.Module):
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size = [x // 4 for x in img_size]
img_size = [x // (patch_size // 4) for x in img_size]
if self.pos_embed:
if self.vit_stem:
@ -207,7 +207,7 @@ class Visformer(nn.Module):
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size = [x // 2 for x in img_size]
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size))
self.stage2 = nn.ModuleList([
@ -224,7 +224,7 @@ class Visformer(nn.Module):
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size = [x // 2 for x in img_size]
img_size = [x // (patch_size // 8) for x in img_size]
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size))
self.stage3 = nn.ModuleList([

Loading…
Cancel
Save