diff --git a/timm/models/davit.py b/timm/models/davit.py index 72c11679..e9957f19 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -70,6 +70,7 @@ class ConvPosEnc(nn.Module): to_2tuple(k // 2), groups=dim) self.normtype = normtype + self.norm = nn.Identity() if self.normtype == 'batch': self.norm = nn.BatchNorm2d(dim) elif self.normtype == 'layer':