Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent facaec52e9
commit d8c8c4f4f2

@ -539,7 +539,7 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_network(self, x): def forward_network(self, x : Tensor):
size: Tuple[int, int] = (x.size(2), x.size(3)) size: Tuple[int, int] = (x.size(2), x.size(3))
features = [x] features = [x]
sizes = [size] sizes = [size]

Loading…
Cancel
Save