Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8965410b55
commit 87abcafaa5

@ -38,7 +38,7 @@ class MySequential(nn.Sequential):
def forward(self, inputs : Tensor, size : Tuple[int, int]): def forward(self, inputs : Tensor, size : Tuple[int, int]):
for module in self: for module in self:
inputs = module(inputs, size) inputs, size = module(inputs, size)
return inputs return inputs
class ConvPosEnc(nn.Module): class ConvPosEnc(nn.Module):
@ -185,7 +185,7 @@ class ChannelBlock(nn.Module):
act_layer=act_layer) act_layer=act_layer)
def forward(self, x, size: Tuple[int, int]): def forward(self, x : Tensor, size: Tuple[int, int]):
x = self.cpe[0](x, size) x = self.cpe[0](x, size)
cur = self.norm1(x) cur = self.norm1(x)
cur = self.attn(cur) cur = self.attn(cur)
@ -310,7 +310,7 @@ class SpatialBlock(nn.Module):
act_layer=act_layer) act_layer=act_layer)
def forward(self, x, size: Tuple[int, int]): def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size H, W = size
B, L, C = x.shape B, L, C = x.shape

Loading…
Cancel
Save