From dbef70fc3190b114811112e8516eb455abaceb14 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 05:26:27 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 55 +++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index b8a9bb66..648df649 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -411,32 +411,35 @@ class DaViT(nn.Module): for stage_id, stage_param in enumerate(self.architecture): layer_offset_id = len(list(itertools.chain(*self.architecture[:stage_id]))) - stage = nn.Sequential([ - nn.Sequential([ - ChannelBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], - norm_layer=nn.LayerNorm, - ffn=ffn, - cpe_act=cpe_act - ) if attention_type == 'channel' else - SpatialBlock( - dim=self.embed_dims[item], - num_heads=self.num_heads[item], - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], - norm_layer=nn.LayerNorm, - ffn=ffn, - cpe_act=cpe_act, - window_size=window_size, - ) if attention_type == 'spatial' else None - for attention_id, attention_type in enumerate(attention_types)] - ) for layer_id, item in enumerate(stage_param) - ]) + stage = nn.Sequential( + nn.ModuleList([ + nn.Sequential( + nn.ModuleList([ + ChannelBlock( + dim=self.embed_dims[item], + num_heads=self.num_heads[item], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act + ) if attention_type == 'channel' else + SpatialBlock( + dim=self.embed_dims[item], + num_heads=self.num_heads[item], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[2 * (layer_id + layer_offset_id) + attention_id], + norm_layer=nn.LayerNorm, + ffn=ffn, + cpe_act=cpe_act, + window_size=window_size, + ) if attention_type == 'spatial' else None + for attention_id, attention_type in enumerate(attention_types)]) + ) for layer_id, item in enumerate(stage_param) + ]) + ) self.stages.add_module(f'stage_{stage_id}', stage)