Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent bf80d655ee
commit facaec52e9

@ -411,7 +411,7 @@ class DaViTStage(nn.Module):
window_size=window_size,
))
stage_blocks.append(nn.ModuleList(*dual_attention_block))
stage_blocks.append(nn.ModuleList(dual_attention_block))
self.blocks = nn.ModuleList(*stage_blocks)
@ -508,7 +508,7 @@ class DaViT(nn.Module):
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.stages = nn.ModuleList(*stages)
self.stages = nn.ModuleList(stages)
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

Loading…
Cancel
Save