Rethink name of patch embed grid info

pull/625/head
Ross Wightman 4 years ago
parent b2c305c2aa
commit 715519a5ef

@@ -490,7 +490,7 @@ class CoaT(nn.Module):
 # Serial blocks 1.
 x1 = self.patch_embed1(x0)
-H1, W1 = self.patch_embed1.out_size
+H1, W1 = self.patch_embed1.grid_size
 x1 = self.insert_cls(x1, self.cls_token1)
 for blk in self.serial_blocks1:
     x1 = blk(x1, size=(H1, W1))
@@ -499,7 +499,7 @@ class CoaT(nn.Module):
 # Serial blocks 2.
 x2 = self.patch_embed2(x1_nocls)
-H2, W2 = self.patch_embed2.out_size
+H2, W2 = self.patch_embed2.grid_size
 x2 = self.insert_cls(x2, self.cls_token2)
 for blk in self.serial_blocks2:
     x2 = blk(x2, size=(H2, W2))
@@ -508,7 +508,7 @@ class CoaT(nn.Module):
 # Serial blocks 3.
 x3 = self.patch_embed3(x2_nocls)
-H3, W3 = self.patch_embed3.out_size
+H3, W3 = self.patch_embed3.grid_size
 x3 = self.insert_cls(x3, self.cls_token3)
 for blk in self.serial_blocks3:
     x3 = blk(x3, size=(H3, W3))
@@ -517,7 +517,7 @@ class CoaT(nn.Module):
 # Serial blocks 4.
 x4 = self.patch_embed4(x3_nocls)
-H4, W4 = self.patch_embed4.out_size
+H4, W4 = self.patch_embed4.grid_size
 x4 = self.insert_cls(x4, self.cls_token4)
 for blk in self.serial_blocks4:
     x4 = blk(x4, size=(H4, W4))

@@ -21,8 +21,8 @@ class PatchEmbed(nn.Module):
 patch_size = to_2tuple(patch_size)
 self.img_size = img_size
 self.patch_size = patch_size
-self.out_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
-self.num_patches = self.out_size[0] * self.out_size[1]
+self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+self.num_patches = self.grid_size[0] * self.grid_size[1]
 self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
 self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

@@ -467,7 +467,7 @@ class SwinTransformer(nn.Module):
 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
 norm_layer=norm_layer if self.patch_norm else None)
 num_patches = self.patch_embed.num_patches
-self.patch_grid = self.patch_embed.out_size
+self.patch_grid = self.patch_embed.grid_size
 # absolute position embedding
 if self.ape:

Loading…
Cancel
Save