|
|
@ -421,7 +421,7 @@ class DaViT(nn.Module):
|
|
|
|
for stage_id, stage_param in enumerate(self.architecture):
|
|
|
|
for stage_id, stage_param in enumerate(self.architecture):
|
|
|
|
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
|
|
|
|
layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id])))
|
|
|
|
|
|
|
|
|
|
|
|
stage = nn.ModuleList([
|
|
|
|
stage = MySequential(*[
|
|
|
|
MySequential(*[
|
|
|
|
MySequential(*[
|
|
|
|
ChannelBlock(
|
|
|
|
ChannelBlock(
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|
dim=self.embed_dims[item],
|
|
|
|