diff --git a/timm/models/davit.py b/timm/models/davit.py index 51162c5a..0b8c219f 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -486,7 +486,7 @@ class DaViT(nn.Module): stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])] stage = DaViTStage( - in_chans if stage_id == 0 else embed_dims[i - 1], + in_chans if stage_id == 0 else embed_dims[stage_id - 1], embed_dims[stage_id], depth = 1, patch_size = patch_size,