Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent dd6531c525
commit f47f6fce31

@ -411,8 +411,8 @@ class DaViT(nn.Module):
for stage_id, stage_param in enumerate(self.architecture):
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
stage = nn.ModuleList([
nn.ModuleList([
stage = nn.Sequential([
nn.Sequential([
ChannelBlock(
dim=self.embed_dims[item],
num_heads=self.num_heads[item],

Loading…
Cancel
Save