From 715519a5eff9046e40958b4c222e0e96f75014e9 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 6 May 2021 14:08:20 -0700
Subject: [PATCH] Rethink name of patch embed grid info

---
 timm/models/coat.py               | 8 ++++----
 timm/models/layers/patch_embed.py | 4 ++--
 timm/models/swin_transformer.py   | 2 +-
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/timm/models/coat.py b/timm/models/coat.py
index 38bc93a3..cb265522 100644
--- a/timm/models/coat.py
+++ b/timm/models/coat.py
@@ -490,7 +490,7 @@ class CoaT(nn.Module):
 
         # Serial blocks 1.
         x1 = self.patch_embed1(x0)
-        H1, W1 = self.patch_embed1.out_size
+        H1, W1 = self.patch_embed1.grid_size
         x1 = self.insert_cls(x1, self.cls_token1)
         for blk in self.serial_blocks1:
             x1 = blk(x1, size=(H1, W1))
@@ -499,7 +499,7 @@ class CoaT(nn.Module):
 
         # Serial blocks 2.
         x2 = self.patch_embed2(x1_nocls)
-        H2, W2 = self.patch_embed2.out_size
+        H2, W2 = self.patch_embed2.grid_size
         x2 = self.insert_cls(x2, self.cls_token2)
         for blk in self.serial_blocks2:
             x2 = blk(x2, size=(H2, W2))
@@ -508,7 +508,7 @@ class CoaT(nn.Module):
 
         # Serial blocks 3.
         x3 = self.patch_embed3(x2_nocls)
-        H3, W3 = self.patch_embed3.out_size
+        H3, W3 = self.patch_embed3.grid_size
         x3 = self.insert_cls(x3, self.cls_token3)
         for blk in self.serial_blocks3:
             x3 = blk(x3, size=(H3, W3))
@@ -517,7 +517,7 @@ class CoaT(nn.Module):
 
         # Serial blocks 4.
         x4 = self.patch_embed4(x3_nocls)
-        H4, W4 = self.patch_embed4.out_size
+        H4, W4 = self.patch_embed4.grid_size
         x4 = self.insert_cls(x4, self.cls_token4)
         for blk in self.serial_blocks4:
             x4 = blk(x4, size=(H4, W4))
diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py
index f7a07e18..b06f9982 100644
--- a/timm/models/layers/patch_embed.py
+++ b/timm/models/layers/patch_embed.py
@@ -21,8 +21,8 @@ class PatchEmbed(nn.Module):
         patch_size = to_2tuple(patch_size)
         self.img_size = img_size
         self.patch_size = patch_size
-        self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-        self.num_patches = self.out_size[0] * self.out_size[1]
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        self.num_patches = self.grid_size[0] * self.grid_size[1]
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py
index 2880aa02..a845f505 100644
--- a/timm/models/swin_transformer.py
+++ b/timm/models/swin_transformer.py
@@ -467,7 +467,7 @@ class SwinTransformer(nn.Module):
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
             norm_layer=norm_layer if self.patch_norm else None)
         num_patches = self.patch_embed.num_patches
-        self.patch_grid = self.patch_embed.out_size
+        self.patch_grid = self.patch_embed.grid_size
 
         # absolute position embedding
         if self.ape: