diff --git a/timm/models/davit.py b/timm/models/davit.py index 8c740300..722dac42 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -539,7 +539,7 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - def forward_network(self, x): + def forward_network(self, x : Tensor): size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] sizes = [size]