@ -111,7 +111,7 @@ class PatchEmbed(nn.Module):
self.norm = nn.LayerNorm(in_chans)
def forward(self, x, size: Tuple[int, int]):
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
dim = len(x.shape)
if dim == 3: