@ -534,7 +534,7 @@ class DaViT(nn.Module):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stem(x)
x = self.stages(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x