Update davit.py

pull/1630/head
Fredo Guan 2 years ago
parent 488c263a63
commit 24514efb4e

@ -64,8 +64,11 @@ class ConvPosEnc(nn.Module):
feat = feat.flatten(2).transpose(1, 2)
x = x + self.activation(feat)
return x
# reason: dim in control sequence
# FIXME reimplement in a way that allows tracing
@register_notrace_module
class PatchEmbed(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
@ -100,7 +103,7 @@ class PatchEmbed(nn.Module):
padding=to_2tuple(pad))
self.norm = nn.LayerNorm(in_chans)
@register_notrace_function # reason: dim in control sequence
def forward(self, x : Tensor, size: Tuple[int, int]):
H, W = size
dim = x.dim()

Loading…
Cancel
Save