|
|
|
@ -539,7 +539,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_network(self, x):
|
|
|
|
|
def forward_network(self, x : Tensor):
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
features = [x]
|
|
|
|
|
sizes = [size]
|
|
|
|
|