From b4b8d1ec1860124970a4ec480684b96f66951da6 Mon Sep 17 00:00:00 2001 From: KAI ZHAO Date: Tue, 14 Dec 2021 17:22:54 +0800 Subject: [PATCH] fix hard-coded strides --- timm/models/visformer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 6e832cd0..37284c9d 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -167,14 +167,14 @@ class Visformer(nn.Module): self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size = [x // 16 for x in img_size] + img_size = [x // patch_size for x in img_size] else: if self.init_channels is None: self.stem = None self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size = [x // 8 for x in img_size] + img_size = [x // (patch_size // 2) for x in img_size] else: self.stem = nn.Sequential( nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), @@ -185,7 +185,7 @@ class Visformer(nn.Module): self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) - img_size = [x // 4 for x in img_size] + img_size = [x // (patch_size // 4) for x in img_size] if self.pos_embed: if self.vit_stem: @@ -207,7 +207,7 @@ class Visformer(nn.Module): self.patch_embed2 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) - img_size = [x // 2 for x in img_size] + img_size = [x // (patch_size // 8) for x in img_size] if self.pos_embed: self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, *img_size)) self.stage2 = nn.ModuleList([ @@ -224,7 +224,7 @@ class Visformer(nn.Module): self.patch_embed3 = PatchEmbed( img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) - img_size = [x // 2 for x in img_size] + img_size = [x // (patch_size // 8) for x in img_size] if self.pos_embed: self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, *img_size)) self.stage3 = nn.ModuleList([