Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 3a2dc4fbd8
commit f5da145a76

@ -419,7 +419,7 @@ class DaViT(nn.Module):
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id]))) layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
stage = MySequential(*[ stage = MySequential(*[
Sequential(*[ MySequential(*[
ChannelBlock( ChannelBlock(
dim=self.embed_dims[item], dim=self.embed_dims[item],
num_heads=self.num_heads[item], num_heads=self.num_heads[item],

Loading…
Cancel
Save