@@ -113,7 +113,7 @@ class PatchEmbed(nn.Module):
                 padding=to_2tuple(pad))
             self.norm = nn.LayerNorm(in_chans)
 
-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
         H, W = size
         dim = len(x.shape)
         if dim == 3:
@@ -186,7 +186,7 @@ class ChannelBlock(nn.Module):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer)
 
-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
         x = self.cpe[0](x, size)
         cur = self.norm1(x)
         cur = self.attn(cur)
@@ -310,7 +310,7 @@ class SpatialBlock(nn.Module):
             hidden_features=mlp_hidden_dim,
             act_layer=act_layer)
 
-    def forward(self, x, size):
+    def forward(self, x, size: Tuple[int, int]):
         H, W = size
         B, L, C = x.shape
         assert L == H * W, "input feature has wrong size"
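The only change in each hunk is the added Tuple[int, int] annotation on forward. A likely motivation, though the diff itself does not state one, is TorchScript compatibility: torch.jit.script assumes un-annotated arguments are Tensor, so passing an (H, W) tuple to the old forward(self, x, size) signature fails to compile, while the annotated signature type-checks. A minimal sketch of that behavior, using a hypothetical TinyBlock module that is not part of this file:

from typing import Tuple

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    # Without the Tuple[int, int] annotation, torch.jit.script infers `size`
    # as Tensor and the tuple unpacking below fails to compile.
    def forward(self, x: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
        H, W = size
        B = x.shape[0]
        C = x.shape[-1]
        assert x.shape[1] == H * W, "input feature has wrong size"
        # Fold the flattened token sequence back into a 2D feature map.
        return x.reshape(B, H, W, C)


scripted = torch.jit.script(TinyBlock())  # compiles only with the annotation
out = scripted(torch.randn(2, 12, 8), (3, 4))
print(out.shape)  # torch.Size([2, 3, 4, 8])

The same annotation applied uniformly across all three forward methods keeps every size-carrying call site scriptable in one pass.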