|
|
|
@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
|
|
|
|
|
|
|
|
|
|
from .features import FeatureInfo
|
|
|
|
|
from .fx_features import register_notrace_function, register_notrace_module
|
|
|
|
|
from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features
|
|
|
|
|
from .helpers import build_model_with_cfg, pretrained_cfg_for_features
|
|
|
|
|
from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
|
|
|
|
|
from .pretrained import generate_default_cfgs
|
|
|
|
|
from .registry import register_model
|
|
|
|
@ -421,7 +421,7 @@ class DaViTStage(nn.Module):
|
|
|
|
|
|
|
|
|
|
stage_blocks.append(nn.ModuleList(dual_attention_block))
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.ModuleList(stage_blocks)
|
|
|
|
|
self.blocks = SequentialWithSize(*stage_blocks)
|
|
|
|
|
|
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
|
x, size = self.patch_embed(x, size)
|
|
|
|
@ -516,7 +516,6 @@ class DaViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.stages = SequentialWithSize(*stages)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
@ -579,7 +578,7 @@ class DaViT(nn.Module):
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
#x, sizes = self.forward_network(x)
|
|
|
|
|
size: Tuple[int, int] = (x.size(2), x.size(3))
|
|
|
|
|
x, size = self.stages(x, size)
|
|
|
|
|
x, size = stages(x, size)
|
|
|
|
|
|
|
|
|
|
# take final feature and norm
|
|
|
|
|
x = self.norms(x)
|
|
|
|
|