Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent d03e8cc94d
commit 38ea27001f

@ -421,7 +421,7 @@ class DaViT(nn.Module):
for stage_id, stage_param in enumerate(self.architecture): for stage_id, stage_param in enumerate(self.architecture):
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id]))) layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
stage = nn.ModuleList([ stage = MySequential(*[
MySequential(*[ MySequential(*[
ChannelBlock( ChannelBlock(
dim=self.embed_dims[item], dim=self.embed_dims[item],

Loading…
Cancel
Save