|
|
@ -490,7 +490,7 @@ class CoaT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# Serial blocks 1.
|
|
|
|
# Serial blocks 1.
|
|
|
|
x1 = self.patch_embed1(x0)
|
|
|
|
x1 = self.patch_embed1(x0)
|
|
|
|
H1, W1 = self.patch_embed1.out_size
|
|
|
|
H1, W1 = self.patch_embed1.grid_size
|
|
|
|
x1 = self.insert_cls(x1, self.cls_token1)
|
|
|
|
x1 = self.insert_cls(x1, self.cls_token1)
|
|
|
|
for blk in self.serial_blocks1:
|
|
|
|
for blk in self.serial_blocks1:
|
|
|
|
x1 = blk(x1, size=(H1, W1))
|
|
|
|
x1 = blk(x1, size=(H1, W1))
|
|
|
@ -499,7 +499,7 @@ class CoaT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# Serial blocks 2.
|
|
|
|
# Serial blocks 2.
|
|
|
|
x2 = self.patch_embed2(x1_nocls)
|
|
|
|
x2 = self.patch_embed2(x1_nocls)
|
|
|
|
H2, W2 = self.patch_embed2.out_size
|
|
|
|
H2, W2 = self.patch_embed2.grid_size
|
|
|
|
x2 = self.insert_cls(x2, self.cls_token2)
|
|
|
|
x2 = self.insert_cls(x2, self.cls_token2)
|
|
|
|
for blk in self.serial_blocks2:
|
|
|
|
for blk in self.serial_blocks2:
|
|
|
|
x2 = blk(x2, size=(H2, W2))
|
|
|
|
x2 = blk(x2, size=(H2, W2))
|
|
|
@ -508,7 +508,7 @@ class CoaT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# Serial blocks 3.
|
|
|
|
# Serial blocks 3.
|
|
|
|
x3 = self.patch_embed3(x2_nocls)
|
|
|
|
x3 = self.patch_embed3(x2_nocls)
|
|
|
|
H3, W3 = self.patch_embed3.out_size
|
|
|
|
H3, W3 = self.patch_embed3.grid_size
|
|
|
|
x3 = self.insert_cls(x3, self.cls_token3)
|
|
|
|
x3 = self.insert_cls(x3, self.cls_token3)
|
|
|
|
for blk in self.serial_blocks3:
|
|
|
|
for blk in self.serial_blocks3:
|
|
|
|
x3 = blk(x3, size=(H3, W3))
|
|
|
|
x3 = blk(x3, size=(H3, W3))
|
|
|
@ -517,7 +517,7 @@ class CoaT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# Serial blocks 4.
|
|
|
|
# Serial blocks 4.
|
|
|
|
x4 = self.patch_embed4(x3_nocls)
|
|
|
|
x4 = self.patch_embed4(x3_nocls)
|
|
|
|
H4, W4 = self.patch_embed4.out_size
|
|
|
|
H4, W4 = self.patch_embed4.grid_size
|
|
|
|
x4 = self.insert_cls(x4, self.cls_token4)
|
|
|
|
x4 = self.insert_cls(x4, self.cls_token4)
|
|
|
|
for blk in self.serial_blocks4:
|
|
|
|
for blk in self.serial_blocks4:
|
|
|
|
x4 = blk(x4, size=(H4, W4))
|
|
|
|
x4 = blk(x4, size=(H4, W4))
|
|
|
|