|
|
|
@ -411,7 +411,7 @@ class DaViTStage(nn.Module):
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
))
|
|
|
|
|
|
|
|
|
|
stage_blocks.append(nn.ModuleList(*dual_attention_block))
|
|
|
|
|
stage_blocks.append(nn.ModuleList(dual_attention_block))
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.ModuleList(*stage_blocks)
|
|
|
|
|
|
|
|
|
@ -508,7 +508,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.stages = nn.ModuleList(*stages)
|
|
|
|
|
self.stages = nn.ModuleList(stages)
|
|
|
|
|
|
|
|
|
|
self.norms = norm_layer(self.num_features)
|
|
|
|
|
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
|
|
|
|