Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 6604ba73d7
commit b799ba95e2

@ -486,7 +486,7 @@ class DaViT(nn.Module):
stage_drop_rates = dpr[len(attention_types) * sum(depths[:stage_id]):len(attention_types) * sum(depths[:stage_id + 1])]
stage = DaViTStage(
in_chans if stage_id == 0 else embed_dims[i - 1],
in_chans if stage_id == 0 else embed_dims[stage_id - 1],
embed_dims[stage_id],
depth = 1,
patch_size = patch_size,

Loading…
Cancel
Save