Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent facaec52e9
commit d8c8c4f4f2

@ -539,7 +539,7 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_network(self, x):
def forward_network(self, x : Tensor):
size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x]
sizes = [size]

Loading…
Cancel
Save